mirror of
https://github.com/ml-explore/mlx.git
synced 2024-09-15 10:04:00 +02:00
parent
934683088e
commit
f20e97b092
2
.github/workflows/pull_request.yml
vendored
2
.github/workflows/pull_request.yml
vendored
|
@ -17,4 +17,4 @@ jobs:
|
|||
pip install pre-commit black isort clang-format
|
||||
- name: Run lint
|
||||
run: |
|
||||
pre-commit run --all-files
|
||||
pre-commit run --all-files
|
||||
|
|
|
@ -206,7 +206,7 @@ void array::ArrayDesc::init() {
|
|||
strides[i] = size;
|
||||
size *= shape[i];
|
||||
}
|
||||
for (auto& in : inputs) {
|
||||
for (const auto& in : inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
}
|
||||
}
|
||||
|
@ -231,7 +231,7 @@ array::ArrayDesc::ArrayDesc(
|
|||
|
||||
array::ArrayDesc::~ArrayDesc() {
|
||||
// When an array description is destroyed it will delete a bunch of arrays
|
||||
// that may also destory their corresponding descriptions and so on and so
|
||||
// that may also destroy their corresponding descriptions and so on and so
|
||||
// forth.
|
||||
//
|
||||
// This calls recursively the destructor and can result in stack overflow, we
|
||||
|
|
55
mlx/array.h
55
mlx/array.h
|
@ -73,32 +73,32 @@ class array {
|
|||
this->array_desc_ = other.array_desc_;
|
||||
}
|
||||
return *this;
|
||||
};
|
||||
}
|
||||
|
||||
/** The size of the array's datatype in bytes. */
|
||||
size_t itemsize() const {
|
||||
return size_of(dtype());
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of elements in the array. */
|
||||
size_t size() const {
|
||||
return array_desc_->size;
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of bytes in the array. */
|
||||
size_t nbytes() const {
|
||||
return size() * itemsize();
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of dimensions of the array. */
|
||||
size_t ndim() const {
|
||||
return array_desc_->shape.size();
|
||||
};
|
||||
}
|
||||
|
||||
/** The shape of the array as a vector of integers. */
|
||||
const std::vector<int>& shape() const {
|
||||
return array_desc_->shape;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the size of the corresponding dimension.
|
||||
|
@ -107,12 +107,12 @@ class array {
|
|||
* bounds checking. */
|
||||
int shape(int dim) const {
|
||||
return shape().at(dim < 0 ? dim + ndim() : dim);
|
||||
};
|
||||
}
|
||||
|
||||
/** The strides of the array. */
|
||||
const std::vector<size_t>& strides() const {
|
||||
return array_desc_->strides;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the stride of the corresponding dimension.
|
||||
|
@ -121,12 +121,12 @@ class array {
|
|||
* bounds checking. */
|
||||
size_t strides(int dim) const {
|
||||
return strides().at(dim < 0 ? dim + ndim() : dim);
|
||||
};
|
||||
}
|
||||
|
||||
/** Get the arrays data type. */
|
||||
Dtype dtype() const {
|
||||
return array_desc_->dtype;
|
||||
};
|
||||
}
|
||||
|
||||
/** Evaluate the array. */
|
||||
void eval();
|
||||
|
@ -160,10 +160,10 @@ class array {
|
|||
|
||||
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
|
||||
return a.arr.id() == b.arr.id() && a.idx == b.idx;
|
||||
};
|
||||
}
|
||||
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
|
||||
return !(a == b);
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
const array& arr;
|
||||
|
@ -209,7 +209,7 @@ class array {
|
|||
allocator::Buffer buffer;
|
||||
deleter_t d;
|
||||
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
||||
: buffer(buffer), d(d) {};
|
||||
: buffer(buffer), d(d) {}
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
|
@ -230,22 +230,22 @@ class array {
|
|||
/** The array's primitive. */
|
||||
Primitive& primitive() const {
|
||||
return *(array_desc_->primitive);
|
||||
};
|
||||
}
|
||||
|
||||
/** A shared pointer to the array's primitive. */
|
||||
std::shared_ptr<Primitive>& primitive_ptr() const {
|
||||
return array_desc_->primitive;
|
||||
};
|
||||
}
|
||||
|
||||
/** Check if the array has an attached primitive or is a leaf node. */
|
||||
bool has_primitive() const {
|
||||
return array_desc_->primitive != nullptr;
|
||||
};
|
||||
}
|
||||
|
||||
/** The array's inputs. */
|
||||
const std::vector<array>& inputs() const {
|
||||
return array_desc_->inputs;
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array>& inputs() {
|
||||
return array_desc_->inputs;
|
||||
|
@ -259,12 +259,12 @@ class array {
|
|||
/** The array's siblings. */
|
||||
const std::vector<array>& siblings() const {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
}
|
||||
|
||||
/** The array's siblings. */
|
||||
std::vector<array>& siblings() {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
}
|
||||
|
||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||
array_desc_->siblings = std::move(siblings);
|
||||
|
@ -281,7 +281,7 @@ class array {
|
|||
outputs.push_back(*this);
|
||||
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
|
||||
return outputs;
|
||||
};
|
||||
}
|
||||
|
||||
/** Detach the array from the graph. */
|
||||
void detach();
|
||||
|
@ -289,19 +289,19 @@ class array {
|
|||
/** Get the Flags bit-field. */
|
||||
const Flags& flags() const {
|
||||
return array_desc_->flags;
|
||||
};
|
||||
}
|
||||
|
||||
/** The size (in elements) of the underlying buffer the array points to. */
|
||||
size_t data_size() const {
|
||||
return array_desc_->data_size;
|
||||
};
|
||||
}
|
||||
|
||||
allocator::Buffer& buffer() {
|
||||
return array_desc_->data->buffer;
|
||||
};
|
||||
}
|
||||
const allocator::Buffer& buffer() const {
|
||||
return array_desc_->data->buffer;
|
||||
};
|
||||
}
|
||||
|
||||
// Return a copy of the shared pointer
|
||||
// to the array::Data struct
|
||||
|
@ -312,19 +312,20 @@ class array {
|
|||
template <typename T>
|
||||
T* data() {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* data() const {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
};
|
||||
}
|
||||
|
||||
enum Status { unscheduled, scheduled, available };
|
||||
|
||||
bool is_available() const {
|
||||
return status() == Status::available;
|
||||
}
|
||||
const Status status() const {
|
||||
|
||||
Status status() const {
|
||||
return array_desc_->status;
|
||||
}
|
||||
|
||||
|
|
|
@ -123,7 +123,7 @@ struct AccelerateSimdOps {
|
|||
|
||||
VT max(VT a, VT b) {
|
||||
return simd_max(a, b);
|
||||
};
|
||||
}
|
||||
|
||||
VT exp(VT x) {
|
||||
return simd_fast_exp(x);
|
||||
|
@ -170,7 +170,7 @@ struct NeonFp16SimdOps {
|
|||
|
||||
VT max(VT a, VT b) {
|
||||
return vmaxq_f16(a, b);
|
||||
};
|
||||
}
|
||||
|
||||
VT exp(VT x) {
|
||||
return neon_fast_exp(x);
|
||||
|
|
|
@ -108,105 +108,105 @@ struct Abs {
|
|||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::abs(x);
|
||||
};
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acos(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acosh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asin(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asinh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atan(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return std::atan2(y, x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atanh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::ceil(x);
|
||||
};
|
||||
}
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
|
@ -219,35 +219,35 @@ struct Cos {
|
|||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cos(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cosh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erf(static_cast<float>(x)));
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return fast_exp(x);
|
||||
};
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return std::exp(x);
|
||||
|
@ -258,83 +258,83 @@ struct Expm1 {
|
|||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return expm1(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::floor(x);
|
||||
};
|
||||
}
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log2(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log10(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return log1p(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return !x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return -x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
|
@ -379,49 +379,49 @@ struct Sin {
|
|||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sin(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sinh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x * x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sqrt(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tan(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tanh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Add {
|
||||
|
@ -554,7 +554,7 @@ struct LogAddExp {
|
|||
? maxval
|
||||
: static_cast<decltype(x)>(
|
||||
maxval + std::log1p(fast_exp(minval - maxval)));
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
|
@ -602,14 +602,14 @@ struct LogicalAnd {
|
|||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x && y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x || y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Select {
|
||||
|
@ -623,35 +623,35 @@ struct BitwiseAnd {
|
|||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x & y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x | y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x << y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
|
|
|
@ -23,7 +23,7 @@ template <typename U = bool>
|
|||
struct And {
|
||||
bool simd_reduce(bool val) {
|
||||
return simd_all(val);
|
||||
};
|
||||
}
|
||||
|
||||
static constexpr constant bool init = true;
|
||||
|
||||
|
@ -61,7 +61,7 @@ template <typename U = bool>
|
|||
struct Or {
|
||||
bool simd_reduce(bool val) {
|
||||
return simd_any(val);
|
||||
};
|
||||
}
|
||||
|
||||
static constexpr constant bool init = false;
|
||||
|
||||
|
@ -100,7 +100,7 @@ struct Sum {
|
|||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
return simd_sum(val);
|
||||
};
|
||||
}
|
||||
|
||||
static constexpr constant U init = U(0);
|
||||
|
||||
|
@ -120,7 +120,7 @@ struct Prod {
|
|||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
return simd_product(val);
|
||||
};
|
||||
}
|
||||
|
||||
static constexpr constant U init = U(1);
|
||||
|
||||
|
@ -140,7 +140,7 @@ struct Min {
|
|||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
return simd_min(val);
|
||||
};
|
||||
}
|
||||
|
||||
static constexpr constant U init = Limits<U>::max;
|
||||
|
||||
|
@ -160,7 +160,7 @@ struct Max {
|
|||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
return simd_max(val);
|
||||
};
|
||||
}
|
||||
|
||||
static constexpr constant U init = Limits<U>::min;
|
||||
|
||||
|
|
|
@ -181,7 +181,7 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) {
|
|||
}
|
||||
// Remove the source from the map to avoid fusing with it again
|
||||
parents_map.erase(src_parents);
|
||||
};
|
||||
}
|
||||
|
||||
// Helper that merges two arrays in the graph by setting the parents of the
|
||||
// source to point to the destination. The arrays are assumed to be coming from
|
||||
|
@ -194,7 +194,7 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
|
|||
for (int i = 0; i < sources.size(); ++i) {
|
||||
merge_one(dests[i], sources[i], parents_map);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T, typename... U>
|
||||
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
||||
|
@ -260,7 +260,7 @@ class CompilerCache {
|
|||
// Otherwise append a new cache entry
|
||||
entries.push_back(CacheEntry{});
|
||||
return entries.back();
|
||||
};
|
||||
}
|
||||
|
||||
void erase(std::uintptr_t fun_id) {
|
||||
cache_.erase(fun_id);
|
||||
|
|
|
@ -13,7 +13,7 @@ struct Device {
|
|||
static constexpr DeviceType cpu = DeviceType::cpu;
|
||||
static constexpr DeviceType gpu = DeviceType::gpu;
|
||||
|
||||
Device(DeviceType type, int index = 0) : type(type), index(index) {};
|
||||
Device(DeviceType type, int index = 0) : type(type), index(index) {}
|
||||
|
||||
DeviceType type;
|
||||
int index;
|
||||
|
|
|
@ -51,10 +51,10 @@ struct Dtype {
|
|||
|
||||
Val val;
|
||||
const uint8_t size;
|
||||
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {};
|
||||
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {}
|
||||
constexpr operator Val() const {
|
||||
return val;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
|
||||
|
|
24
mlx/event.h
24
mlx/event.h
|
@ -10,46 +10,46 @@ namespace mlx::core {
|
|||
|
||||
class Event {
|
||||
public:
|
||||
Event() {};
|
||||
Event() = default;
|
||||
|
||||
Event(const Stream& steam);
|
||||
|
||||
// Wait for the event to be signaled at its curent value
|
||||
// Wait for the event to be signaled at its current value
|
||||
void wait();
|
||||
|
||||
// Signal the event at its current value
|
||||
void signal();
|
||||
|
||||
// Check if the event is valid
|
||||
bool valid() {
|
||||
bool valid() const {
|
||||
return event_ != nullptr;
|
||||
};
|
||||
}
|
||||
|
||||
uint64_t value() {
|
||||
uint64_t value() const {
|
||||
return value_;
|
||||
};
|
||||
}
|
||||
|
||||
void set_value(uint64_t v) {
|
||||
value_ = v;
|
||||
};
|
||||
}
|
||||
|
||||
const Stream& stream() {
|
||||
const Stream& stream() const {
|
||||
if (!valid()) {
|
||||
throw std::runtime_error(
|
||||
"[Event::stream] Cannot access stream on invalid event.");
|
||||
}
|
||||
return stream_;
|
||||
};
|
||||
}
|
||||
|
||||
const std::shared_ptr<void>& raw_event() {
|
||||
const std::shared_ptr<void>& raw_event() const {
|
||||
return event_;
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
// Default constructed stream should never be used
|
||||
// since the event is not yet valid
|
||||
Stream stream_{0, Device::cpu};
|
||||
std::shared_ptr<void> event_{nullptr};
|
||||
std::shared_ptr<void> event_;
|
||||
uint64_t value_{0};
|
||||
};
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ class Custom : public Primitive {
|
|||
explicit Custom(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback)
|
||||
: Primitive(stream), fallback_(fallback) {};
|
||||
: Primitive(stream), fallback_(fallback) {}
|
||||
|
||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
|
@ -39,12 +39,12 @@ class RMSNorm : public Custom {
|
|||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps) {};
|
||||
: Custom(stream, fallback), eps_(eps) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
}
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
|
@ -68,12 +68,12 @@ class RMSNormVJP : public Custom {
|
|||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps) {};
|
||||
: Custom(stream, fallback), eps_(eps) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
}
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
|
@ -91,12 +91,12 @@ class LayerNorm : public Custom {
|
|||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps) {};
|
||||
: Custom(stream, fallback), eps_(eps) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
}
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
|
@ -120,12 +120,12 @@ class LayerNormVJP : public Custom {
|
|||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps) {};
|
||||
: Custom(stream, fallback), eps_(eps) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
}
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
|
@ -154,12 +154,12 @@ class RoPE : public Custom {
|
|||
base_(base),
|
||||
scale_(scale),
|
||||
offset_(offset),
|
||||
forward_(forward) {};
|
||||
forward_(forward) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
}
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
|
@ -189,17 +189,17 @@ class ScaledDotProductAttention : public Custom {
|
|||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
const float scale,
|
||||
const bool needs_mask)
|
||||
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {};
|
||||
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
}
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
eval_gpu(inputs, outputs[0]);
|
||||
};
|
||||
}
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out);
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
|
|
@ -116,7 +116,7 @@ std::vector<array> Primitive::jvp(
|
|||
print(msg);
|
||||
msg << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array> Primitive::vjp(
|
||||
const std::vector<array>&,
|
||||
|
@ -128,7 +128,7 @@ std::vector<array> Primitive::vjp(
|
|||
print(msg);
|
||||
msg << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
||||
const std::vector<array>&,
|
||||
|
@ -138,7 +138,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
|||
print(msg);
|
||||
msg << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> Primitive::output_shapes(
|
||||
const std::vector<array>&) {
|
||||
|
@ -147,7 +147,7 @@ std::vector<std::vector<int>> Primitive::output_shapes(
|
|||
this->print(msg);
|
||||
msg << " cannot infer output shapes.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array> Abs::vjp(
|
||||
const std::vector<array>& primals,
|
||||
|
@ -3430,7 +3430,7 @@ std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
|
|||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
return {{stop_gradient(inputs[0], stream())}, axes};
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array> Subtract::vjp(
|
||||
const std::vector<array>& primals,
|
||||
|
|
184
mlx/primitives.h
184
mlx/primitives.h
|
@ -40,7 +40,7 @@
|
|||
std::vector<std::vector<int>> output_shapes( \
|
||||
const std::vector<array>& inputs) override { \
|
||||
return {inputs[0].shape()}; \
|
||||
};
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
|
@ -154,7 +154,7 @@ class UnaryPrimitive : public Primitive {
|
|||
|
||||
class Abs : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Abs(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit Abs(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -171,7 +171,7 @@ class Abs : public UnaryPrimitive {
|
|||
|
||||
class Add : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Add(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit Add(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -189,7 +189,7 @@ class Add : public UnaryPrimitive {
|
|||
class AddMM : public UnaryPrimitive {
|
||||
public:
|
||||
explicit AddMM(Stream stream, float alpha, float beta)
|
||||
: UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
: UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -213,7 +213,7 @@ class AddMM : public UnaryPrimitive {
|
|||
class Arange : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Arange(Stream stream, double start, double stop, double step)
|
||||
: UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {};
|
||||
: UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -231,7 +231,7 @@ class Arange : public UnaryPrimitive {
|
|||
|
||||
class ArcCos : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcCos(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit ArcCos(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -248,7 +248,7 @@ class ArcCos : public UnaryPrimitive {
|
|||
|
||||
class ArcCosh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcCosh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit ArcCosh(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -265,7 +265,7 @@ class ArcCosh : public UnaryPrimitive {
|
|||
|
||||
class ArcSin : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcSin(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit ArcSin(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -282,7 +282,7 @@ class ArcSin : public UnaryPrimitive {
|
|||
|
||||
class ArcSinh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcSinh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit ArcSinh(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -299,7 +299,7 @@ class ArcSinh : public UnaryPrimitive {
|
|||
|
||||
class ArcTan : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcTan(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit ArcTan(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -316,7 +316,7 @@ class ArcTan : public UnaryPrimitive {
|
|||
|
||||
class ArcTan2 : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -333,7 +333,7 @@ class ArcTan2 : public UnaryPrimitive {
|
|||
|
||||
class ArcTanh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -351,7 +351,7 @@ class ArcTanh : public UnaryPrimitive {
|
|||
class ArgPartition : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArgPartition(Stream stream, int kth, int axis)
|
||||
: UnaryPrimitive(stream), kth_(kth), axis_(axis) {};
|
||||
: UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -376,7 +376,7 @@ class ArgReduce : public UnaryPrimitive {
|
|||
};
|
||||
|
||||
explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
|
||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {};
|
||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -397,7 +397,7 @@ class ArgReduce : public UnaryPrimitive {
|
|||
class ArgSort : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArgSort(Stream stream, int axis)
|
||||
: UnaryPrimitive(stream), axis_(axis) {};
|
||||
: UnaryPrimitive(stream), axis_(axis) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -416,7 +416,7 @@ class ArgSort : public UnaryPrimitive {
|
|||
class AsType : public UnaryPrimitive {
|
||||
public:
|
||||
explicit AsType(Stream stream, Dtype dtype)
|
||||
: UnaryPrimitive(stream), dtype_(dtype) {};
|
||||
: UnaryPrimitive(stream), dtype_(dtype) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -443,7 +443,7 @@ class AsStrided : public UnaryPrimitive {
|
|||
: UnaryPrimitive(stream),
|
||||
shape_(std::move(shape)),
|
||||
strides_(std::move(strides)),
|
||||
offset_(offset) {};
|
||||
offset_(offset) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -465,7 +465,7 @@ class BitwiseBinary : public UnaryPrimitive {
|
|||
enum Op { And, Or, Xor, LeftShift, RightShift };
|
||||
|
||||
explicit BitwiseBinary(Stream stream, Op op)
|
||||
: UnaryPrimitive(stream), op_(op) {};
|
||||
: UnaryPrimitive(stream), op_(op) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -482,7 +482,7 @@ class BitwiseBinary : public UnaryPrimitive {
|
|||
class BlockMaskedMM : public UnaryPrimitive {
|
||||
public:
|
||||
explicit BlockMaskedMM(Stream stream, int block_size)
|
||||
: UnaryPrimitive(stream), block_size_(block_size) {};
|
||||
: UnaryPrimitive(stream), block_size_(block_size) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -504,7 +504,7 @@ class BlockMaskedMM : public UnaryPrimitive {
|
|||
|
||||
class GatherMM : public UnaryPrimitive {
|
||||
public:
|
||||
explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -525,7 +525,7 @@ class GatherMM : public UnaryPrimitive {
|
|||
class Broadcast : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Broadcast(Stream stream, const std::vector<int>& shape)
|
||||
: UnaryPrimitive(stream), shape_(shape) {};
|
||||
: UnaryPrimitive(stream), shape_(shape) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -543,7 +543,7 @@ class Broadcast : public UnaryPrimitive {
|
|||
|
||||
class Ceil : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Ceil(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit Ceil(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -604,7 +604,7 @@ class Compiled : public Primitive {
|
|||
class Concatenate : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Concatenate(Stream stream, int axis)
|
||||
: UnaryPrimitive(stream), axis_(axis) {};
|
||||
: UnaryPrimitive(stream), axis_(axis) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -622,7 +622,7 @@ class Concatenate : public UnaryPrimitive {
|
|||
|
||||
class Conjugate : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -652,7 +652,7 @@ class Convolution : public UnaryPrimitive {
|
|||
kernel_dilation_(kernel_dilation),
|
||||
input_dilation_(input_dilation),
|
||||
groups_(groups),
|
||||
flip_(flip) {};
|
||||
flip_(flip) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -679,7 +679,7 @@ class Convolution : public UnaryPrimitive {
|
|||
|
||||
class Copy : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Copy(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit Copy(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -696,7 +696,7 @@ class Copy : public UnaryPrimitive {
|
|||
|
||||
class Cos : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Cos(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit Cos(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -713,7 +713,7 @@ class Cos : public UnaryPrimitive {
|
|||
|
||||
class Cosh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Cosh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit Cosh(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -784,7 +784,7 @@ class Depends : public Primitive {
|
|||
|
||||
class Divide : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Divide(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit Divide(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -801,7 +801,7 @@ class Divide : public UnaryPrimitive {
|
|||
|
||||
class DivMod : public Primitive {
|
||||
public:
|
||||
explicit DivMod(Stream stream) : Primitive(stream) {};
|
||||
explicit DivMod(Stream stream) : Primitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
@ -815,7 +815,7 @@ class DivMod : public Primitive {
|
|||
std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs) override {
|
||||
return std::vector{inputs[0].shape(), inputs[0].shape()};
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
|
@ -823,7 +823,7 @@ class DivMod : public Primitive {
|
|||
|
||||
class Select : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Select(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit Select(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
@ -840,7 +840,7 @@ class Select : public UnaryPrimitive {
|
|||
|
||||
class Remainder : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Remainder(Stream stream) : UnaryPrimitive(stream) {};
|
||||
explicit Remainder(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::ve |