mirror of
https://github.com/ml-explore/mlx.git
synced 2024-09-15 10:04:00 +02:00
feat: metal formatting and pre-commit bump (#1038)
* feat: metal formatting and pre-commit bump * add guards * update * more guards * more guards * smakk fix * Refactor instantiation of ternary types in ternary.metal * fix scan.metal
This commit is contained in:
parent
8db7161c94
commit
a30e7ed2da
|
@ -1,11 +1,11 @@
|
|||
repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v18.1.3
|
||||
rev: v18.1.4
|
||||
hooks:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.3.0
|
||||
rev: 24.4.2
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
|
|
|
@ -33,7 +33,7 @@ array axpby(
|
|||
class Axpby : public Primitive {
|
||||
public:
|
||||
explicit Axpby(Stream stream, float alpha, float beta)
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta){};
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
|
||||
/**
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
|
|
|
@ -19,7 +19,7 @@ template <typename T>
|
|||
uint index [[thread_position_in_grid]]) {
|
||||
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
||||
out[index] =
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
||||
}
|
||||
|
||||
|
@ -31,30 +31,30 @@ template <typename T>
|
|||
constant const float& alpha [[buffer(3)]],
|
||||
constant const float& beta [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
out[index] =
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
|
||||
}
|
||||
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] \
|
||||
[[kernel]] void axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("axpby_contiguous_" #type_name)]] \
|
||||
[[kernel]] void axpby_contiguous<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
|
||||
axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
|
||||
axpby_contiguous<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
instantiate_axpby(float32, float);
|
||||
|
|
|
@ -14,7 +14,7 @@ class Buffer {
|
|||
void* ptr_;
|
||||
|
||||
public:
|
||||
Buffer(void* ptr) : ptr_(ptr){};
|
||||
Buffer(void* ptr) : ptr_(ptr) {};
|
||||
|
||||
// Get the raw data pointer from the buffer
|
||||
void* raw_ptr();
|
||||
|
|
|
@ -209,7 +209,7 @@ class array {
|
|||
allocator::Buffer buffer;
|
||||
deleter_t d;
|
||||
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
||||
: buffer(buffer), d(d){};
|
||||
: buffer(buffer), d(d) {};
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
|
|
|
@ -38,7 +38,7 @@ using MTLFCList =
|
|||
|
||||
struct CommandEncoder {
|
||||
CommandEncoder(MTL::ComputeCommandEncoder* enc)
|
||||
: enc(enc), concurrent(false){};
|
||||
: enc(enc), concurrent(false) {};
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
|
|
|
@ -11,22 +11,22 @@ template <typename T>
|
|||
out[index] = start + index * step;
|
||||
}
|
||||
|
||||
#define instantiate_arange(tname, type) \
|
||||
template [[host_name("arange" #tname)]] \
|
||||
[[kernel]] void arange<type>( \
|
||||
constant const type& start, \
|
||||
constant const type& step, \
|
||||
device type* out, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
#define instantiate_arange(tname, type) \
|
||||
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
|
||||
constant const type& start, \
|
||||
constant const type& step, \
|
||||
device type* out, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
instantiate_arange(uint8, uint8_t)
|
||||
// clang-format off
|
||||
instantiate_arange(uint8, uint8_t)
|
||||
instantiate_arange(uint16, uint16_t)
|
||||
instantiate_arange(uint32, uint32_t)
|
||||
instantiate_arange(uint32, uint32_t)
|
||||
instantiate_arange(uint64, uint64_t)
|
||||
instantiate_arange(int8, int8_t)
|
||||
instantiate_arange(int8, int8_t)
|
||||
instantiate_arange(int16, int16_t)
|
||||
instantiate_arange(int32, int32_t)
|
||||
instantiate_arange(int64, int64_t)
|
||||
instantiate_arange(float16, half)
|
||||
instantiate_arange(float32, float)
|
||||
instantiate_arange(bfloat16, bfloat16_t)
|
||||
instantiate_arange(bfloat16, bfloat16_t) // clang-format on
|
|
@ -18,7 +18,8 @@ struct ArgMin {
|
|||
static constexpr constant U init = Limits<U>::max;
|
||||
|
||||
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
|
||||
if (best.val > current.val || (best.val == current.val && best.index > current.index)) {
|
||||
if (best.val > current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
|
@ -26,11 +27,12 @@ struct ArgMin {
|
|||
}
|
||||
|
||||
template <int N>
|
||||
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
||||
for (int i=0; i<N; i++) {
|
||||
IndexValPair<U>
|
||||
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] < best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset+i;
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
|
@ -42,7 +44,8 @@ struct ArgMax {
|
|||
static constexpr constant U init = Limits<U>::min;
|
||||
|
||||
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
|
||||
if (best.val < current.val || (best.val == current.val && best.index > current.index)) {
|
||||
if (best.val < current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
|
@ -50,11 +53,12 @@ struct ArgMax {
|
|||
}
|
||||
|
||||
template <int N>
|
||||
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
||||
for (int i=0; i<N; i++) {
|
||||
IndexValPair<U>
|
||||
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] > best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset+i;
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
|
@ -64,19 +68,16 @@ struct ArgMax {
|
|||
template <typename U>
|
||||
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
|
||||
return IndexValPair<U>{
|
||||
simd_shuffle_down(data.index, delta),
|
||||
simd_shuffle_down(data.val, delta)
|
||||
};
|
||||
simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename Op, int N_READS>
|
||||
[[kernel]] void arg_reduce_general(
|
||||
const device T *in [[buffer(0)]],
|
||||
device uint32_t *out [[buffer(1)]],
|
||||
const device int *shape [[buffer(2)]],
|
||||
const device size_t *in_strides [[buffer(3)]],
|
||||
const device size_t *out_strides [[buffer(4)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device uint32_t* out [[buffer(1)]],
|
||||
const device int* shape [[buffer(2)]],
|
||||
const device size_t* in_strides [[buffer(3)]],
|
||||
const device size_t* out_strides [[buffer(4)]],
|
||||
const device size_t& ndim [[buffer(5)]],
|
||||
const device size_t& axis_stride [[buffer(6)]],
|
||||
const device size_t& axis_size [[buffer(7)]],
|
||||
|
@ -86,7 +87,6 @@ template <typename T, typename Op, int N_READS>
|
|||
uint simd_size [[threads_per_simdgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
// Shapes and strides *do not* contain the reduction axis. The reduction size
|
||||
// and stride are provided in axis_stride and axis_size.
|
||||
//
|
||||
|
@ -113,13 +113,13 @@ template <typename T, typename Op, int N_READS>
|
|||
threadgroup IndexValPair<T> local_data[32];
|
||||
|
||||
// Loop over the reduction axis in lsize*N_READS buckets
|
||||
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
|
||||
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
|
||||
// Read the current value
|
||||
uint32_t current_index = r*lsize*N_READS + lid*N_READS;
|
||||
uint32_t current_index = r * lsize * N_READS + lid * N_READS;
|
||||
uint32_t offset = current_index;
|
||||
const device T * current_in = in + in_idx + current_index * axis_stride;
|
||||
const device T* current_in = in + in_idx + current_index * axis_stride;
|
||||
T vals[N_READS];
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
|
||||
current_index++;
|
||||
current_in += axis_stride;
|
||||
|
@ -130,7 +130,7 @@ template <typename T, typename Op, int N_READS>
|
|||
// need to reduce across the thread group.
|
||||
|
||||
// First per simd reduction.
|
||||
for (uint offset=simd_size/2; offset>0; offset/=2) {
|
||||
for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
|
||||
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
|
||||
best = op.reduce(best, neighbor);
|
||||
}
|
||||
|
@ -149,7 +149,7 @@ template <typename T, typename Op, int N_READS>
|
|||
if (simd_lane_id < simd_groups) {
|
||||
best = local_data[simd_lane_id];
|
||||
}
|
||||
for (uint offset=simd_size/2; offset>0; offset/=2) {
|
||||
for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
|
||||
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
|
||||
best = op.reduce(best, neighbor);
|
||||
}
|
||||
|
@ -161,24 +161,25 @@ template <typename T, typename Op, int N_READS>
|
|||
}
|
||||
|
||||
#define instantiate_arg_reduce_helper(name, itype, op) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device uint32_t * out [[buffer(1)]], \
|
||||
const device int *shape [[buffer(2)]], \
|
||||
const device size_t *in_strides [[buffer(3)]], \
|
||||
const device size_t *out_strides [[buffer(4)]], \
|
||||
const device size_t& ndim [[buffer(5)]], \
|
||||
const device size_t& axis_stride [[buffer(6)]], \
|
||||
const device size_t& axis_size [[buffer(7)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_size [[threads_per_simdgroup]], \
|
||||
template [[host_name(name)]] [[kernel]] void \
|
||||
arg_reduce_general<itype, op<itype>, 4>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device uint32_t* out [[buffer(1)]], \
|
||||
const device int* shape [[buffer(2)]], \
|
||||
const device size_t* in_strides [[buffer(3)]], \
|
||||
const device size_t* out_strides [[buffer(4)]], \
|
||||
const device size_t& ndim [[buffer(5)]], \
|
||||
const device size_t& axis_stride [[buffer(6)]], \
|
||||
const device size_t& axis_size [[buffer(7)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_size [[threads_per_simdgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_arg_reduce(name, itype) \
|
||||
// clang-format off
|
||||
#define instantiate_arg_reduce(name, itype) \
|
||||
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
|
||||
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
|
||||
|
||||
|
@ -193,4 +194,4 @@ instantiate_arg_reduce(int32, int32_t)
|
|||
instantiate_arg_reduce(int64, int64_t)
|
||||
instantiate_arg_reduce(float16, half)
|
||||
instantiate_arg_reduce(float32, float)
|
||||
instantiate_arg_reduce(bfloat16, bfloat16_t)
|
||||
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
|
|
@ -77,7 +77,8 @@ template <typename T, typename U, typename Op>
|
|||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
|
@ -92,7 +93,8 @@ template <typename T, typename U, typename Op, int DIM>
|
|||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
|
@ -112,114 +114,118 @@ template <typename T, typename U, typename Op>
|
|||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void binary_op_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template \
|
||||
[[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||
template [[host_name(name "_" #dims)]] \
|
||||
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||
binary_op_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||
template [[host_name(name "_1")]] \
|
||||
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] \
|
||||
[[kernel]] void binary_op_g_nd2<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] \
|
||||
[[kernel]] void binary_op_g_nd3<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||
template [[host_name(name "_1")]] [[kernel]] void \
|
||||
binary_op_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] [[kernel]] void \
|
||||
binary_op_g_nd2<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] [[kernel]] void \
|
||||
binary_op_g_nd3<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||
|
||||
|
||||
#define instantiate_binary_g(name, itype, otype, op) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void binary_op_g<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
#define instantiate_binary_g(name, itype, otype, op) \
|
||||
template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_binary_all(name, tname, itype, otype, op) \
|
||||
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
|
||||
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
|
||||
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
|
||||
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
|
||||
|
||||
#define instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
|
||||
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
|
||||
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_integer(name, op) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||
instantiate_binary_float(name, op)
|
||||
instantiate_binary_float(name, op) // clang-format on
|
||||
|
||||
#define instantiate_binary_types_bool(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, bool, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, bool, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, bool, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, bool, op) \
|
||||
instantiate_binary_all(name, float16, half, bool, op) \
|
||||
instantiate_binary_all(name, float32, float, bool, op) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_types_bool(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, bool, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, bool, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, bool, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, bool, op) \
|
||||
instantiate_binary_all(name, float16, half, bool, op) \
|
||||
instantiate_binary_all(name, float32, float, bool, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, bool, op)
|
||||
instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_binary_types(add, Add)
|
||||
instantiate_binary_types(div, Divide)
|
||||
instantiate_binary_types_bool(eq, Equal)
|
||||
|
@ -253,4 +259,4 @@ instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
|
|||
instantiate_binary_integer(bitwise_xor, BitwiseXor)
|
||||
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
|
||||
instantiate_binary_integer(left_shift, LeftShift)
|
||||
instantiate_binary_integer(right_shift, RightShift)
|
||||
instantiate_binary_integer(right_shift, RightShift) // clang-format on
|
||||
|
|
|
@ -3,28 +3,42 @@
|
|||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
struct FloorDivide {
|
||||
template <typename T> T operator()(T x, T y) { return x / y; }
|
||||
template <> float operator()(float x, float y) { return trunc(x / y); }
|
||||
template <> half operator()(half x, half y) { return trunc(x / y); }
|
||||
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); }
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
template <>
|
||||
float operator()(float x, float y) {
|
||||
return trunc(x / y);
|
||||
}
|
||||
template <>
|
||||
half operator()(half x, half y) {
|
||||
return trunc(x / y);
|
||||
}
|
||||
template <>
|
||||
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
|
||||
return trunc(x / y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
|
||||
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
|
||||
operator()(T x, T y) {
|
||||
return x % y;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
|
||||
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
|
||||
operator()(T x, T y) {
|
||||
auto r = x % y;
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
return r;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
|
@ -32,10 +46,11 @@ struct Remainder {
|
|||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
return r;
|
||||
}
|
||||
template <> complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x % y;
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x % y;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -50,7 +65,6 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||
d[index] = Op2()(a[0], b[0]);
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
[[kernel]] void binary_op_ss(
|
||||
device const T* a,
|
||||
|
@ -139,7 +153,8 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
|
||||
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
@ -156,7 +171,8 @@ template <typename T, typename U, typename Op1, typename Op2, int DIM>
|
|||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
|
||||
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
@ -180,99 +196,102 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||
}
|
||||
|
||||
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void binary_op_##bopt<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name(name)]] [[kernel]] void \
|
||||
binary_op_##bopt<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
|
||||
template [[host_name(name "_" #dims)]] \
|
||||
[[kernel]] void binary_op_g_nd<itype, otype, op1, op2, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||
binary_op_g_nd<itype, otype, op1, op2, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
|
||||
template [[host_name(name "_1")]] \
|
||||
[[kernel]] void binary_op_g_nd1<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] \
|
||||
[[kernel]] void binary_op_g_nd2<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] \
|
||||
[[kernel]] void binary_op_g_nd3<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5)
|
||||
|
||||
template [[host_name(name "_1")]] [[kernel]] void \
|
||||
binary_op_g_nd1<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] [[kernel]] void \
|
||||
binary_op_g_nd2<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] [[kernel]] void \
|
||||
binary_op_g_nd3<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
|
||||
|
||||
#define instantiate_binary_g(name, itype, otype, op1, op2) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void binary_op_g<itype, otype, op2, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
template [[host_name(name)]] [[kernel]] void \
|
||||
binary_op_g<itype, otype, op2, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
|
||||
instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
|
||||
instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
|
||||
instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
|
||||
instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2)
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
|
||||
|
||||
#define instantiate_binary_float(name, op1, op2) \
|
||||
instantiate_binary_all(name, float16, half, half, op1, op2) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_float(name, op1, op2) \
|
||||
instantiate_binary_all(name, float16, half, half, op1, op2) \
|
||||
instantiate_binary_all(name, float32, float, float, op1, op2) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2)
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
|
||||
|
||||
#define instantiate_binary_types(name, op1, op2) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_types(name, op1, op2) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
|
||||
instantiate_binary_float(name, op1, op2)
|
||||
|
||||
instantiate_binary_types(divmod, FloorDivide, Remainder)
|
||||
instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on
|
||||
|
|
|
@ -22,7 +22,7 @@ struct complex64_t {
|
|||
float imag;
|
||||
|
||||
// Constructors
|
||||
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
|
||||
constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
|
||||
|
||||
// Conversions to complex64_t
|
||||
template <
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_simdgroup_matrix>
|
||||
#include <metal_stdlib>
|
||||
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
|
@ -23,14 +21,15 @@ template <typename T, int N>
|
|||
device T* out [[buffer(1)]],
|
||||
const constant MLXConvParams<N>* params [[buffer(2)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
|
||||
int filter_size = params->C;
|
||||
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
|
||||
for (short i = 0; i < N; i++)
|
||||
filter_size *= params->wS[i];
|
||||
|
||||
int out_pixels = 1;
|
||||
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
|
||||
for (short i = 0; i < N; i++)
|
||||
out_pixels *= params->oS[i];
|
||||
|
||||
// Set out
|
||||
// Set out
|
||||
out += gid.z * filter_size + gid.y * (params->C);
|
||||
|
||||
// Coordinates in input
|
||||
|
@ -46,11 +45,11 @@ template <typename T, int N>
|
|||
|
||||
bool valid = n < params->N;
|
||||
|
||||
// Unroll dimensions
|
||||
// Unroll dimensions
|
||||
for (int i = N - 1; i >= 0; --i) {
|
||||
int os_ = (oS % params->oS[i]);
|
||||
int ws_ = (wS % params->wS[i]);
|
||||
|
||||
|
||||
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
|
||||
|
||||
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
|
||||
|
@ -64,10 +63,10 @@ template <typename T, int N>
|
|||
wS /= params->wS[i];
|
||||
}
|
||||
|
||||
if(valid) {
|
||||
if (valid) {
|
||||
size_t in_offset = n * params->in_strides[0];
|
||||
|
||||
for(int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
in_offset += is[i] * params->in_strides[i + 1];
|
||||
}
|
||||
|
||||
|
@ -85,12 +84,13 @@ template <typename T, int N>
|
|||
device T* out [[buffer(1)]],
|
||||
const constant MLXConvParams<N>* params [[buffer(2)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
|
||||
int filter_size = params->C;
|
||||
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
|
||||
for (short i = 0; i < N; i++)
|
||||
filter_size *= params->wS[i];
|
||||
|
||||
int out_pixels = 1;
|
||||
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
|
||||
for (short i = 0; i < N; i++)
|
||||
out_pixels *= params->oS[i];
|
||||
|
||||
// Set out
|
||||
out += gid.z * filter_size + gid.x * (filter_size / params->C);
|
||||
|
@ -128,10 +128,10 @@ template <typename T, int N>
|
|||
out += ws_ * params->str[i];
|
||||
}
|
||||
|
||||
if(valid) {
|
||||
if (valid) {
|
||||
size_t in_offset = n * params->in_strides[0];
|
||||
|
||||
for(int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
in_offset += is[i] * params->in_strides[i + 1];
|
||||
}
|
||||
|
||||
|
@ -141,24 +141,24 @@ template <typename T, int N>
|
|||
}
|
||||
}
|
||||
|
||||
#define instantiate_naive_unfold_nd(name, itype, n) \
|
||||
template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
|
||||
[[kernel]] void naive_unfold_Nd( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device itype* out [[buffer(1)]], \
|
||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||
uint3 gid [[thread_position_in_grid]]); \
|
||||
template [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] \
|
||||
[[kernel]] void naive_unfold_transpose_Nd( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device itype* out [[buffer(1)]], \
|
||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||
uint3 gid [[thread_position_in_grid]]);
|
||||
#define instantiate_naive_unfold_nd(name, itype, n) \
|
||||
template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \
|
||||
naive_unfold_Nd( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device itype* out [[buffer(1)]], \
|
||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||
uint3 gid [[thread_position_in_grid]]); \
|
||||
template \
|
||||
[[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
|
||||
naive_unfold_transpose_Nd( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device itype* out [[buffer(1)]], \
|
||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||
uint3 gid [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_naive_unfold_nd_dims(name, itype) \
|
||||
instantiate_naive_unfold_nd(name, itype, 1) \
|
||||
instantiate_naive_unfold_nd(name, itype, 2) \
|
||||
instantiate_naive_unfold_nd(name, itype, 3)
|
||||
#define instantiate_naive_unfold_nd_dims(name, itype) \
|
||||
instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \
|
||||
name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3)
|
||||
|
||||
instantiate_naive_unfold_nd_dims(float32, float);
|
||||
instantiate_naive_unfold_nd_dims(float16, half);
|
||||
|
@ -168,12 +168,13 @@ instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
|
|||
/// Slow and naive conv2d kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const int BC = 16>
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const int BC = 16>
|
||||
[[kernel]] void naive_conv_2d(
|
||||
const device T* in [[buffer(0)]],
|
||||
const device T* wt [[buffer(1)]],
|
||||
|
@ -183,7 +184,6 @@ template <typename T,
|
|||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
(void)simd_gid;
|
||||
(void)simd_lid;
|
||||
|
||||
|
@ -192,80 +192,82 @@ template <typename T,
|
|||
|
||||
int out_o = tid.y * BN * TN + lid.y * TN;
|
||||
int out_hw = tid.x * BM * TM + lid.x * TM;
|
||||
|
||||
|
||||
int out_h[TM];
|
||||
int out_w[TN];
|
||||
|
||||
for(int m = 0; m < TM; ++m) {
|
||||
for (int m = 0; m < TM; ++m) {
|
||||
int mm = (out_hw + m);
|
||||
out_h[m] = mm / params.oS[1];
|
||||
out_w[m] = mm % params.oS[1];
|
||||
}
|
||||
|
||||
|
||||
T in_local[TM];
|
||||
T wt_local[TN];
|
||||
T out_local[TM * TN] = {T(0)};
|
||||
|
||||
for(int h = 0; h < params.wS[0]; ++h) {
|
||||
for(int w = 0; w < params.wS[1]; ++w) {
|
||||
for(int c = 0; c < params.C; ++c) {
|
||||
|
||||
for (int h = 0; h < params.wS[0]; ++h) {
|
||||
for (int w = 0; w < params.wS[1]; ++w) {
|
||||
for (int c = 0; c < params.C; ++c) {
|
||||
// Local in
|
||||
for(int m = 0; m < TM; m++) {
|
||||
for (int m = 0; m < TM; m++) {
|
||||
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
|
||||
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
|
||||
|
||||
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
|
||||
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
|
||||
in_local[m] = valid
|
||||
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
|
||||
: T(0);
|
||||
}
|
||||
|
||||
// Load weight
|
||||
for (int n = 0; n < TN; ++n) {
|
||||
int o = out_o + n;
|
||||
wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
|
||||
h * params.wt_strides[1] +
|
||||
w * params.wt_strides[2] + c] : T(0);
|
||||
wt_local[n] = o < params.O
|
||||
? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
|
||||
w * params.wt_strides[2] + c]
|
||||
: T(0);
|
||||
}
|
||||
|
||||
// Accumulate
|
||||
for(int m = 0; m < TM; ++m) {
|
||||
for(int n = 0; n < TN; ++n) {
|
||||
for (int m = 0; m < TM; ++m) {
|
||||
for (int n = 0; n < TN; ++n) {
|
||||
out_local[m * TN + n] += in_local[m] * wt_local[n];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int m = 0; m < TM; ++m) {
|
||||
for(int n = 0; n < TN; ++n) {
|
||||
if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
|
||||
out[out_h[m] * params.out_strides[1] +
|
||||
out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
|
||||
for (int m = 0; m < TM; ++m) {
|
||||
for (int n = 0; n < TN; ++n) {
|
||||
if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
|
||||
(out_o + n) < params.O)
|
||||
out[out_h[m] * params.out_strides[1] +
|
||||
out_w[m] * params.out_strides[2] + out_o + n] =
|
||||
out_local[m * TN + n];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Instantiations
|
||||
|
||||
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||
[[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
const device itype* wt [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
naive_conv_2d<itype, bm, bn, tm, tn>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
const device itype* wt [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_naive_conv_2d_blocks(name, itype) \
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
|
||||
|
||||
instantiate_naive_conv_2d_blocks(float32, float);
|
||||
instantiate_naive_conv_2d_blocks(float16, half);
|
||||
|
@ -276,9 +278,7 @@ instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
|
|||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int M, int R, int S>
|
||||
struct WinogradTransforms {
|
||||
|
||||
};
|
||||
struct WinogradTransforms {};
|
||||
|
||||
template <>
|
||||
struct WinogradTransforms<6, 3, 8> {
|
||||
|
@ -287,36 +287,36 @@ struct WinogradTransforms<6, 3, 8> {
|
|||
MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
|
||||
MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
|
||||
MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||
{ 0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
|
||||
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
|
||||
{ 0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
|
||||
{ 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
|
||||
{ 0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
|
||||
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
|
||||
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||
{0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
|
||||
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
|
||||
{0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
|
||||
{5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
|
||||
{0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
|
||||
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
|
||||
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||
};
|
||||
|
||||
MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||
{ 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
|
||||
{ 1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
|
||||
{ 1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
|
||||
{ 1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
|
||||
{ 1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
|
||||
{ 1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
|
||||
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||
{1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
|
||||
{1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
|
||||
{1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
|
||||
{1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
|
||||
{1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
|
||||
{1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
|
||||
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||
};
|
||||
|
||||
MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||
{ 1.00, 0.00, 0.00},
|
||||
{ -2.0/9.00, -2.0/9.00, -2.0/9.00},
|
||||
{ -2.0/9.00, 2.0/9.00, -2.0/9.00},
|
||||
{ 1.0/90.0, 1.0/45.0, 2.0/45.0},
|
||||
{ 1.0/90.0, -1.0/45.0, 2.0/45.0},
|
||||
{ 32.0/45.0, 16.0/45.0, 8.0/45.0},
|
||||
{ 32.0/45.0, -16.0/45.0, 8.0/45.0},
|
||||
{ 0.00, 0.00, 1.00},
|
||||
{1.00, 0.00, 0.00},
|
||||
{-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00},
|
||||
{-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00},
|
||||
{1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0},
|
||||
{1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0},
|
||||
{32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0},
|
||||
{32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0},
|
||||
{0.00, 0.00, 1.00},
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -324,12 +324,9 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
|
|||
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
|
||||
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
|
||||
|
||||
template <typename T,
|
||||
int BC = 32,
|
||||
int BO = 4,
|
||||
int M = 6,
|
||||
int R = 3>
|
||||
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
|
||||
template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
|
||||
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
|
||||
winograd_conv_2d_weight_transform(
|
||||
const device T* wt_in [[buffer(0)]],
|
||||
device T* wt_out [[buffer(1)]],
|
||||
const constant int& C [[buffer(2)]],
|
||||
|
@ -337,7 +334,6 @@ template <typename T,
|
|||
uint tid [[threadgroup_position_in_grid]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using WGT = WinogradTransforms<M, R, 8>;
|
||||
|
||||
// Get lane position in simdgroup
|
||||
|
@ -357,35 +353,37 @@ template <typename T,
|
|||
|
||||
// Move to the correct output filter
|
||||
size_t ko = BO * tid + simd_group_id;
|
||||
wt_in += ko * R * R * C;
|
||||
wt_in += ko * R * R * C;
|
||||
|
||||
// wt_out is stored transposed (A x A x C x O)
|
||||