feat: metal formatting and pre-commit bump (#1038)

* feat: metal formatting and pre-commit bump

* add guards

* update

* more guards

* more guards

* smakk fix

* Refactor instantiation of ternary types in ternary.metal

* fix scan.metal
This commit is contained in:
Nripesh Niketan 2024-04-30 18:18:09 +04:00 committed by GitHub
parent 8db7161c94
commit a30e7ed2da
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
45 changed files with 3822 additions and 3337 deletions

View file

@ -1,11 +1,11 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.3
rev: v18.1.4
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.3.0
rev: 24.4.2
hooks:
- id: black
- repo: https://github.com/pycqa/isort

View file

@ -33,7 +33,7 @@ array axpby(
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta){};
: Primitive(stream), alpha_(alpha), beta_(beta) {};
/**
* A primitive must know how to evaluate itself on the CPU/GPU

View file

@ -19,7 +19,7 @@ template <typename T>
uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
out[index] =
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
@ -31,30 +31,30 @@ template <typename T>
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
out[index] =
out[index] =
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
}
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \
[[kernel]] void axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] \
[[kernel]] void axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);

View file

@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
Buffer(void* ptr) : ptr_(ptr){};
Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer
void* raw_ptr();

View file

@ -209,7 +209,7 @@ class array {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d){};
: buffer(buffer), d(d) {};
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;

View file

@ -38,7 +38,7 @@ using MTLFCList =
struct CommandEncoder {
CommandEncoder(MTL::ComputeCommandEncoder* enc)
: enc(enc), concurrent(false){};
: enc(enc), concurrent(false) {};
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;

View file

@ -11,22 +11,22 @@ template <typename T>
out[index] = start + index * step;
}
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] \
[[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
instantiate_arange(uint8, uint8_t)
// clang-format off
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint64, uint64_t)
instantiate_arange(int8, int8_t)
instantiate_arange(int8, int8_t)
instantiate_arange(int16, int16_t)
instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t)
instantiate_arange(bfloat16, bfloat16_t) // clang-format on

View file

@ -18,7 +18,8 @@ struct ArgMin {
static constexpr constant U init = Limits<U>::max;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val > current.val || (best.val == current.val && best.index > current.index)) {
if (best.val > current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
@ -26,11 +27,12 @@ struct ArgMin {
}
template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i=0; i<N; i++) {
IndexValPair<U>
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
best.index = offset+i;
best.index = offset + i;
}
}
return best;
@ -42,7 +44,8 @@ struct ArgMax {
static constexpr constant U init = Limits<U>::min;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val < current.val || (best.val == current.val && best.index > current.index)) {
if (best.val < current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
@ -50,11 +53,12 @@ struct ArgMax {
}
template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i=0; i<N; i++) {
IndexValPair<U>
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
best.index = offset+i;
best.index = offset + i;
}
}
return best;
@ -64,19 +68,16 @@ struct ArgMax {
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>{
simd_shuffle_down(data.index, delta),
simd_shuffle_down(data.val, delta)
};
simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
}
template <typename T, typename Op, int N_READS>
[[kernel]] void arg_reduce_general(
const device T *in [[buffer(0)]],
device uint32_t *out [[buffer(1)]],
const device int *shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
const device T* in [[buffer(0)]],
device uint32_t* out [[buffer(1)]],
const device int* shape [[buffer(2)]],
const device size_t* in_strides [[buffer(3)]],
const device size_t* out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]],
const device size_t& axis_stride [[buffer(6)]],
const device size_t& axis_size [[buffer(7)]],
@ -86,7 +87,6 @@ template <typename T, typename Op, int N_READS>
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Shapes and strides *do not* contain the reduction axis. The reduction size
// and stride are provided in axis_stride and axis_size.
//
@ -113,13 +113,13 @@ template <typename T, typename Op, int N_READS>
threadgroup IndexValPair<T> local_data[32];
// Loop over the reduction axis in lsize*N_READS buckets
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
// Read the current value
uint32_t current_index = r*lsize*N_READS + lid*N_READS;
uint32_t current_index = r * lsize * N_READS + lid * N_READS;
uint32_t offset = current_index;
const device T * current_in = in + in_idx + current_index * axis_stride;
const device T* current_in = in + in_idx + current_index * axis_stride;
T vals[N_READS];
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
current_index++;
current_in += axis_stride;
@ -130,7 +130,7 @@ template <typename T, typename Op, int N_READS>
// need to reduce across the thread group.
// First per simd reduction.
for (uint offset=simd_size/2; offset>0; offset/=2) {
for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor);
}
@ -149,7 +149,7 @@ template <typename T, typename Op, int N_READS>
if (simd_lane_id < simd_groups) {
best = local_data[simd_lane_id];
}
for (uint offset=simd_size/2; offset>0; offset/=2) {
for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor);
}
@ -161,24 +161,25 @@ template <typename T, typename Op, int N_READS>
}
#define instantiate_arg_reduce_helper(name, itype, op) \
template [[host_name(name)]] \
[[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
const device itype *in [[buffer(0)]], \
device uint32_t * out [[buffer(1)]], \
const device int *shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
template [[host_name(name)]] [[kernel]] void \
arg_reduce_general<itype, op<itype>, 4>( \
const device itype* in [[buffer(0)]], \
device uint32_t* out [[buffer(1)]], \
const device int* shape [[buffer(2)]], \
const device size_t* in_strides [[buffer(3)]], \
const device size_t* out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_arg_reduce(name, itype) \
// clang-format off
#define instantiate_arg_reduce(name, itype) \
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
@ -193,4 +194,4 @@ instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t)
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on

View file

@ -77,7 +77,8 @@ template <typename T, typename U, typename Op>
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
@ -92,7 +93,8 @@ template <typename T, typename U, typename Op, int DIM>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
@ -112,114 +114,118 @@ template <typename T, typename U, typename Op>
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op, bopt) \
template [[host_name(name)]] \
[[kernel]] void binary_op_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary(name, itype, otype, op, bopt) \
template \
[[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
template [[host_name(name "_" #dims)]] [[kernel]] void \
binary_op_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name(name "_1")]] \
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void binary_op_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void binary_op_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
template [[host_name(name "_1")]] [[kernel]] void \
binary_op_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void \
binary_op_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void \
binary_op_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name(name)]] \
[[kernel]] void binary_op_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
#define instantiate_binary_integer(name, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
// clang-format off
#define instantiate_binary_integer(name, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
// clang-format off
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_integer(name, op) \
// clang-format off
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_integer(name, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
instantiate_binary_float(name, op) // clang-format on
#define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
instantiate_binary_all(name, int8, int8_t, bool, op) \
instantiate_binary_all(name, int16, int16_t, bool, op) \
instantiate_binary_all(name, int32, int32_t, bool, op) \
instantiate_binary_all(name, int64, int64_t, bool, op) \
instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
// clang-format off
#define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
instantiate_binary_all(name, int8, int8_t, bool, op) \
instantiate_binary_all(name, int16, int16_t, bool, op) \
instantiate_binary_all(name, int32, int32_t, bool, op) \
instantiate_binary_all(name, int64, int64_t, bool, op) \
instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
instantiate_binary_all(name, complex64, complex64_t, bool, op)
instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
// clang-format off
instantiate_binary_types(add, Add)
instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal)
@ -253,4 +259,4 @@ instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
instantiate_binary_integer(bitwise_xor, BitwiseXor)
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
instantiate_binary_integer(left_shift, LeftShift)
instantiate_binary_integer(right_shift, RightShift)
instantiate_binary_integer(right_shift, RightShift) // clang-format on

View file

@ -3,28 +3,42 @@
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
struct FloorDivide {
template <typename T> T operator()(T x, T y) { return x / y; }
template <> float operator()(float x, float y) { return trunc(x / y); }
template <> half operator()(half x, half y) { return trunc(x / y); }
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); }
template <typename T>
T operator()(T x, T y) {
return x / y;
}
template <>
float operator()(float x, float y) {
return trunc(x / y);
}
template <>
half operator()(half x, half y) {
return trunc(x / y);
}
template <>
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
return trunc(x / y);
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
@ -32,10 +46,11 @@ struct Remainder {
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
return r;
}
template <> complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
@ -50,7 +65,6 @@ template <typename T, typename U, typename Op1, typename Op2>
d[index] = Op2()(a[0], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_ss(
device const T* a,
@ -139,7 +153,8 @@ template <typename T, typename U, typename Op1, typename Op2>
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}
@ -156,7 +171,8 @@ template <typename T, typename U, typename Op1, typename Op2, int DIM>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
@ -180,99 +196,102 @@ template <typename T, typename U, typename Op1, typename Op2>
}
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
template [[host_name(name)]] \
[[kernel]] void binary_op_##bopt<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
uint index [[thread_position_in_grid]]);
template [[host_name(name)]] [[kernel]] void \
binary_op_##bopt<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void binary_op_g_nd<itype, otype, op1, op2, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
template [[host_name(name "_" #dims)]] [[kernel]] void \
binary_op_g_nd<itype, otype, op1, op2, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
template [[host_name(name "_1")]] \
[[kernel]] void binary_op_g_nd1<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void binary_op_g_nd2<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void binary_op_g_nd3<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5)
template [[host_name(name "_1")]] [[kernel]] void \
binary_op_g_nd1<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void \
binary_op_g_nd2<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void \
binary_op_g_nd3<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
#define instantiate_binary_g(name, itype, otype, op1, op2) \
template [[host_name(name)]] \
[[kernel]] void binary_op_g<itype, otype, op2, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
template [[host_name(name)]] [[kernel]] void \
binary_op_g<itype, otype, op2, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2)
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
#define instantiate_binary_float(name, op1, op2) \
instantiate_binary_all(name, float16, half, half, op1, op2) \
// clang-format off
#define instantiate_binary_float(name, op1, op2) \
instantiate_binary_all(name, float16, half, half, op1, op2) \
instantiate_binary_all(name, float32, float, float, op1, op2) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2)
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
#define instantiate_binary_types(name, op1, op2) \
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
// clang-format off
#define instantiate_binary_types(name, op1, op2) \
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
instantiate_binary_float(name, op1, op2)
instantiate_binary_types(divmod, FloorDivide, Remainder)
instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on

View file

@ -22,7 +22,7 @@ struct complex64_t {
float imag;
// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
// Conversions to complex64_t
template <

View file

@ -1,13 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#define MLX_MTL_CONST static constant constexpr const
@ -23,14 +21,15 @@ template <typename T, int N>
device T* out [[buffer(1)]],
const constant MLXConvParams<N>* params [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
for (short i = 0; i < N; i++)
filter_size *= params->wS[i];
int out_pixels = 1;
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
for (short i = 0; i < N; i++)
out_pixels *= params->oS[i];
// Set out
// Set out
out += gid.z * filter_size + gid.y * (params->C);
// Coordinates in input
@ -46,11 +45,11 @@ template <typename T, int N>
bool valid = n < params->N;
// Unroll dimensions
// Unroll dimensions
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
@ -64,10 +63,10 @@ template <typename T, int N>
wS /= params->wS[i];
}
if(valid) {
if (valid) {
size_t in_offset = n * params->in_strides[0];
for(int i = 0; i < N; ++i) {
for (int i = 0; i < N; ++i) {
in_offset += is[i] * params->in_strides[i + 1];
}
@ -85,12 +84,13 @@ template <typename T, int N>
device T* out [[buffer(1)]],
const constant MLXConvParams<N>* params [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
for (short i = 0; i < N; i++)
filter_size *= params->wS[i];
int out_pixels = 1;
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
for (short i = 0; i < N; i++)
out_pixels *= params->oS[i];
// Set out
out += gid.z * filter_size + gid.x * (filter_size / params->C);
@ -128,10 +128,10 @@ template <typename T, int N>
out += ws_ * params->str[i];
}
if(valid) {
if (valid) {
size_t in_offset = n * params->in_strides[0];
for(int i = 0; i < N; ++i) {
for (int i = 0; i < N; ++i) {
in_offset += is[i] * params->in_strides[i + 1];
}
@ -141,24 +141,24 @@ template <typename T, int N>
}
}
#define instantiate_naive_unfold_nd(name, itype, n) \
template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
[[kernel]] void naive_unfold_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]); \
template [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] \
[[kernel]] void naive_unfold_transpose_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]);
#define instantiate_naive_unfold_nd(name, itype, n) \
template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \
naive_unfold_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]); \
template \
[[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
naive_unfold_transpose_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]);
#define instantiate_naive_unfold_nd_dims(name, itype) \
instantiate_naive_unfold_nd(name, itype, 1) \
instantiate_naive_unfold_nd(name, itype, 2) \
instantiate_naive_unfold_nd(name, itype, 3)
#define instantiate_naive_unfold_nd_dims(name, itype) \
instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \
name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3)
instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
@ -168,12 +168,13 @@ instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
/// Slow and naive conv2d kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const int BC = 16>
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const int BC = 16>
[[kernel]] void naive_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
@ -183,7 +184,6 @@ template <typename T,
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)simd_gid;
(void)simd_lid;
@ -192,80 +192,82 @@ template <typename T,
int out_o = tid.y * BN * TN + lid.y * TN;
int out_hw = tid.x * BM * TM + lid.x * TM;
int out_h[TM];
int out_w[TN];
for(int m = 0; m < TM; ++m) {
for (int m = 0; m < TM; ++m) {
int mm = (out_hw + m);
out_h[m] = mm / params.oS[1];
out_w[m] = mm % params.oS[1];
}
T in_local[TM];
T wt_local[TN];
T out_local[TM * TN] = {T(0)};
for(int h = 0; h < params.wS[0]; ++h) {
for(int w = 0; w < params.wS[1]; ++w) {
for(int c = 0; c < params.C; ++c) {
for (int h = 0; h < params.wS[0]; ++h) {
for (int w = 0; w < params.wS[1]; ++w) {
for (int c = 0; c < params.C; ++c) {
// Local in
for(int m = 0; m < TM; m++) {
for (int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
in_local[m] = valid
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
: T(0);
}
// Load weight
for (int n = 0; n < TN; ++n) {
int o = out_o + n;
wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
h * params.wt_strides[1] +
w * params.wt_strides[2] + c] : T(0);
wt_local[n] = o < params.O
? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
w * params.wt_strides[2] + c]
: T(0);
}
// Accumulate
for(int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) {
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
out_local[m * TN + n] += in_local[m] * wt_local[n];
}
}
}
}
}
for(int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) {
if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
(out_o + n) < params.O)
out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] =
out_local[m * TN + n];
}
}
}
// Instantiations
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
naive_conv_2d<itype, bm, bn, tm, tn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_naive_conv_2d_blocks(name, itype) \
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
@ -276,9 +278,7 @@ instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
template <int M, int R, int S>
struct WinogradTransforms {
};
struct WinogradTransforms {};
template <>
struct WinogradTransforms<6, 3, 8> {
@ -287,36 +287,36 @@ struct WinogradTransforms<6, 3, 8> {
MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{ 0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
{ 0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
{ 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
{ 0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
{0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
{5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
{0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
};
MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{ 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
{ 1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
{ 1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
{ 1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
{ 1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
{ 1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
{1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
{1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
{1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
{1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
{1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
};
MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00, 0.00, 0.00},
{ -2.0/9.00, -2.0/9.00, -2.0/9.00},
{ -2.0/9.00, 2.0/9.00, -2.0/9.00},
{ 1.0/90.0, 1.0/45.0, 2.0/45.0},
{ 1.0/90.0, -1.0/45.0, 2.0/45.0},
{ 32.0/45.0, 16.0/45.0, 8.0/45.0},
{ 32.0/45.0, -16.0/45.0, 8.0/45.0},
{ 0.00, 0.00, 1.00},
{1.00, 0.00, 0.00},
{-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00},
{-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00},
{1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0},
{1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0},
{32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0},
{32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0},
{0.00, 0.00, 1.00},
};
};
@ -324,12 +324,9 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
template <typename T,
int BC = 32,
int BO = 4,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
winograd_conv_2d_weight_transform(
const device T* wt_in [[buffer(0)]],
device T* wt_out [[buffer(1)]],
const constant int& C [[buffer(2)]],
@ -337,7 +334,6 @@ template <typename T,
uint tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
using WGT = WinogradTransforms<M, R, 8>;
// Get lane position in simdgroup
@ -357,35 +353,37 @@ template <typename T,
// Move to the correct output filter
size_t ko = BO * tid + simd_group_id;
wt_in += ko * R * R * C;
wt_in += ko * R * R * C;
// wt_out is stored transposed (A x A x C x O)