mirror of
https://github.com/ml-explore/mlx.git
synced 2024-09-15 10:04:00 +02:00
parent
40b6d67333
commit
43ffdab172
|
@ -43,20 +43,22 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
|||
auto half_size = grid_dim.y - odd;
|
||||
out += index.x * bytes_per_key;
|
||||
bool drop_last = odd && (index.y == half_size);
|
||||
auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
|
||||
auto bits = threefry2x32_hash(key, count);
|
||||
auto bits = threefry2x32_hash(
|
||||
key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y));
|
||||
size_t idx = size_t(index.y) << 2;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[4 * count.x + i] = bits.bytes[0][i];
|
||||
out[idx + i] = bits.bytes[0][i];
|
||||
}
|
||||
if (!drop_last) {
|
||||
idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2;
|
||||
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
int edge_bytes = (bytes_per_key % 4);
|
||||
for (int i = 0; i < edge_bytes; ++i) {
|
||||
out[4 * count.y + i] = bits.bytes[1][i];
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[4 * count.y + i] = bits.bytes[1][i];
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -77,22 +79,24 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
|||
auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim);
|
||||
auto key = uint2(keys[k1_elem], keys[k2_elem]);
|
||||
auto half_size = grid_dim.y - odd;
|
||||
out += index.x * bytes_per_key;
|
||||
out += size_t(index.x) * bytes_per_key;
|
||||
bool drop_last = odd && (index.y == half_size);
|
||||
auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
|
||||
auto bits = threefry2x32_hash(key, count);
|
||||
auto bits = threefry2x32_hash(
|
||||
key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y));
|
||||
size_t idx = size_t(index.y) << 2;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[4 * count.x + i] = bits.bytes[0][i];
|
||||
out[idx + i] = bits.bytes[0][i];
|
||||
}
|
||||
if (!drop_last) {
|
||||
idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2;
|
||||
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
int edge_bytes = (bytes_per_key % 4);
|
||||
for (int i = 0; i < edge_bytes; ++i) {
|
||||
out[4 * count.y + i] = bits.bytes[1][i];
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[4 * count.y + i] = bits.bytes[1][i];
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,36 +6,17 @@
|
|||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope(
|
||||
[[kernel]] void rope_single(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const int& offset,
|
||||
constant const float& base,
|
||||
constant const float& scale,
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Compute the input and output indices
|
||||
uint in_index_1, in_index_2;
|
||||
uint out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + grid.x * out_strides[2];
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
||||
}
|
||||
|
||||
constant const size_t& stride,
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
// Figure out L and d.
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
float L = scale * static_cast<float>(offset);
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
|
@ -43,6 +24,21 @@ template <typename T, bool traditional, bool forward>
|
|||
float costheta = metal::fast::cos(theta);
|
||||
float sintheta = metal::fast::sin(theta);
|
||||
|
||||
// Compute the input and output indices
|
||||
uint in_index_1, in_index_2;
|
||||
uint out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x + pos.y * stride;
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 = 2 * pos.x + pos.y * stride;
|
||||
in_index_2 = in_index_1 + 1;
|
||||
} else {
|
||||
out_index_1 = pos.x + pos.y * stride;
|
||||
out_index_2 = out_index_1 + grid.x;
|
||||
in_index_1 = pos.x + pos.y * stride;
|
||||
in_index_2 = in_index_1 + grid.x;
|
||||
}
|
||||
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
|
@ -59,19 +55,97 @@ template <typename T, bool traditional, bool forward>
|
|||
out[out_index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] [[kernel]] void \
|
||||
rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
// General rotary-position-embedding kernel over strided input/output.
//
// Template parameters:
//   T           - element type of `in` and `out`; math is done in float.
//   traditional - selects how the two elements of each rotated pair are
//                 located (see the index computation below).
//   forward     - true applies the rotation by theta, false applies the
//                 opposite rotation (used by the "vjp_" instantiations).
//   N           - number of consecutive batch entries (steps along
//                 strides[0] / out_strides[0]) handled by one thread.
//
// Grid layout as used by the code: pos.x picks the pair within the rotated
// dimensions, pos.y is added to `offset` to form the position, and pos.z
// selects a group of N batch entries.
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
// Added to pos.y to obtain the position that is scaled into L.
constant const int& offset,
|
||||
// Used as the multiplier of -d inside exp2().
// NOTE(review): presumably the host passes log2 of the RoPE base - confirm.
constant const float& base,
|
||||
constant const float& scale,
|
||||
// Element strides of `in` / `out` for the three indexed axes:
// [0] -> pos.z (batch), [1] -> pos.y, [2] -> pos.x.
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
// Total number of batch entries; bounds the per-thread loop below.
constant const size_t& n_batch,
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Figure out L and d.
// L is the scaled position; d is the pair index normalised to [0, 1).
// NOTE(review): pos.y is unsigned, so `pos.y + offset` is evaluated in
// unsigned arithmetic - confirm offset is never negative.
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
|
||||
// Compute costheta, sintheta
// theta does not depend on pos.z, so it is computed once and reused for
// every batch entry processed by this thread.
|
||||
float theta = L * metal::exp2(-d * base);
|
||||
float costheta = metal::fast::cos(theta);
|
||||
float sintheta = metal::fast::sin(theta);
|
||||
|
||||
// Compute the input and output indices
// Indices are size_t so the stride products are carried in 64 bits. They
// point at the first of the N batch entries owned by this thread.
|
||||
size_t in_index_1, in_index_2;
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
// Traditional layout: the pair is (2 * pos.x, 2 * pos.x + 1) along the
// last axis.
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
// NOTE(review): the output partner is at +1 while the input partner is at
// +strides[2]; this assumes out_strides[2] == 1 - confirm.
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
// Otherwise the pair is (pos.x, pos.x + grid.x) along the last axis.
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + grid.x * out_strides[2];
|
||||
in_index_1 =
|
||||
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
||||
}
|
||||
// Process up to N batch entries; the second condition stops the last group
// early when n_batch is not a multiple of N.
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
// Rotate (x1, x2) by +theta.
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
// Rotate (x1, x2) by -theta.
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
// Advance all four indices to the next batch entry.
in_index_1 += strides[0];
|
||||
in_index_2 += strides[0];
|
||||
out_index_1 += out_strides[0];
|
||||
out_index_2 += out_strides[0];
|
||||
}
|
||||
}
|
||||
|
||||
// Explicitly instantiates the general `rope` kernel for one
// (type, traditional, forward) combination and exposes it under the host
// name "rope_<name>". N is not specified, so the template default is used.
#define instantiate_rope_g(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] [[kernel]] void \
|
||||
rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const size_t& n_batch, \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
// Explicitly instantiates the `rope_single` kernel for one
// (type, traditional, forward) combination and exposes it under the host
// name "rope_single_<name>". It takes a single `stride` and a 2D grid
// instead of the two stride arrays, n_batch and 3D grid of `rope`.
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
template [[host_name("rope_single_" #name)]] [[kernel]] void \
|
||||
rope_single<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
uint2 grid [[threads_per_grid]]);
|
||||
|
||||
// Emits both kernel variants for a given name/type/flag combination:
// "rope_single_<name>" (via instantiate_rope_s) and "rope_<name>"
// (via instantiate_rope_g).
#define instantiate_rope(name, type, traditional, forward) \
|
||||
instantiate_rope_s(name, type, traditional, forward) \
|
||||
instantiate_rope_g(name, type, traditional, forward)
|
||||
|
||||
// clang-format off
|
||||
instantiate_rope(traditional_float16, half, true, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
||||
|
@ -84,4 +158,4 @@ instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
|
|||
instantiate_rope(vjp_traditional_float32, float, true, false)
|
||||
instantiate_rope(vjp_float16, half, false, false)
|
||||
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
|
||||
instantiate_rope(vjp_float32, float, false, false) // clang-format on
|
||||
instantiate_rope(vjp_float32, float, false, false) // clang-format on
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
constexpr int n_per_thread = 4;
|
||||
|
||||
void RoPE::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
|
@ -62,8 +64,11 @@ void RoPE::eval_gpu(
|
|||
out_strides[1] = out.strides()[ndim - 2];
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
// Special case for inference (single time step and contiguous)
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "rope_" << (forward_ ? "" : "vjp_")
|
||||
kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
|
||||
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@ -72,18 +77,28 @@ void RoPE::eval_gpu(
|
|||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&offset_, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&base, sizeof(float), 5);
|
||||
compute_encoder->setBytes(&scale_, sizeof(float), 6);
|
||||
compute_encoder->setBytes(&offset_, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&base, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&scale_, sizeof(float), 4);
|
||||
|
||||
int dim0 = dims_ / 2;
|
||||
int dim1 = in.shape(-2);
|
||||
int dim2 = in.size() / mat_size;
|
||||
auto group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
auto grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
size_t n_batch = in.size() / mat_size;
|
||||
if (single) {
|
||||
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
auto group_dims = get_block_dims(dim0, n_batch, 1);
|
||||
auto grid_dims = MTL::Size(dim0, n_batch, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
uint32_t dim1 = in.shape(-2);
|
||||
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
|
||||
auto group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
auto grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
|
Loading…
Reference in a new issue