mirror of https://github.com/ml-explore/mlx.git
synced 2024-09-15 10:04:00 +02:00

commit 8ca7f9e8e9 ("awni's commit files")
parent e411fcae68
87  .clang-format  Normal file
@@ -0,0 +1,87 @@
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
  AfterClass: false
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  BeforeCatch: false
  BeforeElse: false
  IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
IncludeCategories:
  - Regex: '^<.*\.h(pp)?>'
    Priority: 1
  - Regex: '^<.*'
    Priority: 2
  - Regex: '.*'
    Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
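
For illustration only (not part of the commit, and the names are made up): under these settings a call that overflows ColumnLimit: 80 breaks right after the open parenthesis (AlignAfterOpenBracket: AlwaysBreak), puts one argument per line (BinPackArguments: false), and indents by ContinuationIndentWidth: 4:

    auto out = some_function_with_a_rather_long_name(
        first_argument,
        second_argument,
        third_argument);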

3  MANIFEST.in  Normal file
@@ -0,0 +1,3 @@
include CMakeLists.txt
recursive-include mlx/ *
include python/src/*

198  benchmarks/cpp/irregular_strides.cpp  Normal file
@@ -0,0 +1,198 @@
#include <cstring>
#include <iostream>
#include <sstream>

#include "mlx/mlx.h"
#include "time_utils.h"
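
// TIME/TIMEM are provided by time_utils.h, which is not part of this diff;
// from their use below they appear to time a callable (TIME) and to prefix
// the report with a message (TIMEM). Inputs are typically materialized with
// eval() before timing because MLX builds arrays lazily.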

using namespace mlx::core;

void time_irregular_binary_ops_1D() {
  auto device = default_device();
  int size = 1000000;
  int step = 2;
  auto a = random::uniform({size});
  auto b = random::uniform({size});
  eval(a, b);
  a = slice(a, {0}, {size}, {step});
  b = slice(b, {0}, {size}, {step});
  TIMEM("1D strided", add, a, b, device);
}

void time_irregular_binary_ops_2D() {
  auto device = default_device();
  int size = 2048;
  auto a = random::uniform({size, size});
  auto b = random::uniform({size, size});
  eval(a, b);
  TIMEM("2D regular", add, a, b, device);

  b = transpose(b);
  eval(b);
  TIMEM("2D transpose", add, a, b, device);

  b = random::uniform({size});
  eval(b);
  TIMEM("2D broadcast dim 0", add, a, b, device);

  b = reshape(b, {size, 1});
  eval(b);
  TIMEM("2D broadcast dim 1", add, a, b, device);
}

void time_irregular_binary_ops_3D() {
  auto device = default_device();
  int d0 = 32;
  int d1 = 512;
  int d2 = 512;
  auto a = random::uniform({d0, d1, d2});
  auto b = random::uniform({d0, d1, d2});
  TIMEM("3D regular", add, a, b, device);

  b = transpose(b, {0, 2, 1});
  TIMEM("3D transpose", add, a, b, device);

  b = random::uniform({d1, d2});
  TIMEM("3D broadcast dim 0", add, a, b, device);

  b = random::uniform({d0, 1, d2});
  TIMEM("3D broadcast dim 1", add, a, b, device);

  b = random::uniform({d0, d1, 1});
  TIMEM("3D broadcast dim 2", add, a, b, device);

  b = random::uniform({d2});
  TIMEM("3D broadcast dims 0, 1", add, a, b, device);

  b = random::uniform({d1, 1});
  TIMEM("3D broadcast dims 0, 2", add, a, b, device);

  b = random::uniform({d0, 1, 1});
  TIMEM("3D broadcast dims 1, 2", add, a, b, device);
}

void time_irregular_binary_ops_4D() {
  auto device = default_device();
  std::vector<int> shape = {8, 8, 512, 512};
  auto a = random::uniform(shape);
  auto b = random::uniform(shape);

  TIMEM("4D regular", add, a, b, device);

  b = transpose(b, {0, 1, 3, 2});
  TIMEM("4D transpose", add, a, b, device);

  std::string om = "4D broadcast dims ";
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = 1;
    b = random::uniform(shape);
    std::ostringstream msg;
    msg << om << i;
    TIMEM(msg.str(), add, a, b, device);

    for (int j = i + 1; j < shape.size(); ++j) {
      shape[j] = 1;
      std::ostringstream msg;
      msg << om << i << ", " << j;
      b = random::uniform(shape);
      TIMEM(msg.str(), add, a, b, device);
      shape[j] = a.shape(j);

      for (int k = j + 1; k < shape.size(); ++k) {
        shape[k] = 1;
        std::ostringstream msg;
        msg << om << i << ", " << j << ", " << k;
        b = random::uniform(shape);
        TIMEM(msg.str(), add, a, b, device);
        shape[k] = a.shape(k);
      }
    }
    shape[i] = a.shape(i);
  }
}

void time_irregular_reshape() {
  auto device = default_device();
  std::vector<int> shape;
  auto reshape_fn = [&shape, device](const array& a) {
    return reshape(a, shape, device);
  };

  int size = 64;
  int d = 2 * size;

  auto a = random::uniform({d, d, d});

  shape = {8 * size, size, size};
  TIMEM("3D contiguous", reshape_fn, a);

  a = transpose(a);
  shape = {8 * size, size, size};
  TIMEM("3D transpose", reshape_fn, a);

  a = transpose(a, {1, 2, 0});
  shape = {8 * size, size, size};
  TIMEM("3D transpose dims 1 2", reshape_fn, a);

  a = broadcast_to(random::uniform({d, d}), {d, d, d});
  TIMEM("3D broadcast dim 0", reshape_fn, a);

  a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
  TIMEM("3D broadcast dim 1", reshape_fn, a);

  a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
  TIMEM("3D broadcast dim 2", reshape_fn, a);

  a = broadcast_to(random::uniform({d}), {d, d, d});
  TIMEM("3D broadcast dims 0, 1", reshape_fn, a);

  a = broadcast_to(random::uniform({d, 1}), {d, d, d});
  TIMEM("3D broadcast dims 0, 2", reshape_fn, a);

  a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
  TIMEM("3D broadcast dims 1, 2", reshape_fn, a);

  a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
  TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}

void time_irregular_astype_1D() {
  auto device = default_device();
  int size = 1000000;
  int step = 2;
  auto a = random::uniform({size});
  a = slice(a, {0}, {size}, {step});
  TIMEM("1D strided", astype, a, int32, device);
}

void time_irregular_astype_2D() {
  auto device = default_device();
  int size = 2048;
  std::vector<int> shape = {size, size};

  auto a = random::uniform(shape);
  TIMEM("2D regular", astype, a, int32, device);

  a = transpose(a);
  TIMEM("2D transpose", astype, a, int32, device);

  a = broadcast_to(random::uniform({size}), shape);
  TIMEM("2D broadcast dim 0", astype, a, int32, device);

  a = broadcast_to(random::uniform({size, 1}), shape);
  TIMEM("2D broadcast dim 1", astype, a, int32, device);
}

int main(int argc, char** argv) {
  if (argc > 1) {
    bool use_gpu = !strcmp(argv[1], "gpu");
    set_default_device(use_gpu ? Device::gpu : Device::cpu);
  }
  std::cout << "Benchmarks for " << default_device() << std::endl;
  time_irregular_binary_ops_1D();
  time_irregular_binary_ops_2D();
  time_irregular_binary_ops_3D();
  time_irregular_binary_ops_4D();
  time_irregular_reshape();
  time_irregular_astype_1D();
  time_irregular_astype_2D();
}

247  benchmarks/cpp/single_ops.cpp  Normal file
@@ -0,0 +1,247 @@
#include <iostream>

#include "mlx/mlx.h"
#include "time_utils.h"

using namespace mlx::core;

void time_creation_ops() {
  int M = 2000;
  int N = 500;
  auto shape = {M, N};
  auto full_fp32 = [&]() { return full(shape, 3.3f); };
  TIME(full_fp32);
  auto zeros_fp32 = [&]() { return zeros(shape, float32); };
  TIME(zeros_fp32);
  auto ones_fp32 = [&]() { return ones(shape, float32); };
  TIME(ones_fp32);

  auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
  TIME(arange_fp32);
}

void time_type_conversions() {
  int M = 2000;
  int N = 500;
  auto shape = {M, N};
  auto device = default_device();

  auto a = zeros(shape, float32);
  eval(a);
  TIMEM("float32 to int32", astype, a, int32, device);
  TIMEM("float32 to uint32", astype, a, uint32, device);

  a = zeros(shape, int32);
  eval(a);
  TIMEM("int32 to float32", astype, a, float32, device);

  a = zeros(shape, bool_);
  eval(a);
  TIMEM("bool to float32", astype, a, float32, device);
  TIMEM("bool to int32", astype, a, int32, device);
  TIMEM("bool to uint32", astype, a, uint32, device);
}

void time_random_generation() {
  int M = 2000;
  int N = 500;

  auto uniform = [&]() { return random::uniform({M, N}, float32); };
  TIME(uniform);
  auto normal = [&]() { return random::normal({M, N}, float32); };
  TIME(normal);
}

void time_unary_ops() {
  int M = 2000;
  int N = 500;
  auto device = default_device();

  auto a = random::normal({M, N});
  eval(a);
  TIME(mlx::core::abs, a, device);
  TIME(negative, a, device);
  TIME(sign, a, device);
  TIME(square, a, device);
  TIME(mlx::core::sqrt, a, device);
  TIME(rsqrt, a, device);
  TIME(mlx::core::exp, a, device);

  a = random::uniform({M, N});
  TIME(mlx::core::log, a, device);
}

void time_binary_ops() {
  int M = 1000, N = 100, K = 10;
  auto a = random::uniform({M, N, K});
  auto b = random::uniform({M, N, K});
  auto device = default_device();
  eval(a, b);

  TIME(add, a, b, device);
  TIME(subtract, a, b, device);
  TIME(multiply, a, b, device);
  TIME(divide, a, b, device);
  TIME(maximum, a, b, device);
  TIME(minimum, a, b, device);

  b = random::uniform({1});
  eval(b);
  TIMEM("scalar", add, a, b, device);
  TIMEM("vector-scalar", subtract, a, b, device);
  TIMEM("scalar-vector", subtract, b, a, device);
  TIMEM("scalar", multiply, a, b, device);
  TIMEM("vector-scalar", divide, a, b, device);
  TIMEM("scalar-vector", divide, b, a, device);

  a = broadcast_to(random::uniform({1}), {1000, 100});
  b = broadcast_to(random::uniform({1}), {1000, 100});
  eval(a, b);
  TIMEM("scalar-scalar broadcast", add, a, b, device);
  TIMEM("scalar-scalar broadcast", subtract, a, b, device);
  TIMEM("scalar-scalar broadcast", multiply, a, b, device);
  TIMEM("scalar-scalar broadcast", divide, a, b, device);
}

void time_strided_ops() {
  int M = 50, N = 50, O = 50, P = 50;
  auto a = random::uniform({M, N, O, P});
  auto b = random::uniform({M, N, O, P});
  auto device = default_device();
  eval(a, b);
  TIMEM("non-strided", add, a, b, device);
  a = transpose(a, {1, 0, 2, 3});
  b = transpose(b, {3, 2, 0, 1});
  eval(a, b);
  TIMEM("strided", add, a, b, device);
}

void time_comparisons() {
  int M = 1000, N = 100, K = 10;
  auto a = random::uniform({M, N, K});
  auto b = random::uniform({M, N, K});
  auto device = default_device();
  eval(a, b);
  TIME(equal, a, b, device);
  TIME(greater, a, b, device);
  TIME(greater_equal, a, b, device);
  TIME(less, a, b, device);
  TIME(less_equal, a, b, device);
}

void time_matvec() {
  int M = 2000, N = 200;
  auto a = random::uniform({M, N});
  auto b = random::uniform({N});
  auto c = random::uniform({M});
  eval(a, b, c);
  auto matvec = [&]() { return matmul(a, b); };
  TIME(matvec);

  auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
  TIME(matvec_transpose);
}

void time_matmul() {
  int M = 1000, N = 1000, K = 1000;
  auto a = random::uniform({M, K});
  auto b = random::uniform({K, N});
  auto device = default_device();
  eval(a, b);
  TIME(matmul, a, b, device);

  auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
  TIME(transpose_matmul);
}

void time_reductions() {
  auto a = random::normal({10000, 1000});
  eval(a);
  auto sum_all = [&a]() { return sum(a, false); };
  TIME(sum_all);

  auto sum_along_0 = [&a]() { return sum(a, 0, false); };
  TIME(sum_along_0);

  auto sum_along_1 = [&a]() { return sum(a, 1, false); };
  TIME(sum_along_1);

  auto prod_all = [&a]() { return prod(a, false); };
  TIME(prod_all);

  auto all_true = [&a]() { return all(a, false); };
  TIME(all_true);

  auto all_along_0 = [&a]() { return all(a, 0, false); };
  TIME(all_along_0);

  auto all_along_1 = [&a]() { return all(a, 1, false); };
  TIME(all_along_1);

  auto any_true = [&a]() { return any(a, false); };
  TIME(any_true);

  auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
  TIME(argmin_along_0);

  auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
  TIME(argmin_along_1);
}

void time_gather_scatter() {
  auto a = random::normal({1000, 768});
  eval(a);
  auto indices = random::randint(0, 1000, {256});
  eval(indices);

  auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
  TIME(embedding_lookup);

  indices = random::randint(0, 768 * 1000, {256 * 768});
  eval(indices);

  auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
  TIME(single_element_lookup);

  indices = random::randint(0, 1000, {256});
  auto updates = random::normal({256, 1, 768});
  eval(indices, updates);

  auto embedding_update = [&a, &indices, &updates]() {
    return scatter(a, indices, updates, 0);
  };
  TIME(embedding_update);

  auto embedding_add = [&a, &indices, &updates]() {
    return scatter_add(a, indices, updates, 0);
  };
  TIME(embedding_add);

  a = reshape(a, {-1});
  indices = random::randint(0, 768 * 1000, {768 * 256});
  updates = random::normal({256 * 768, 1});
  eval(a, indices, updates);

  auto single_element_update = [&a, &indices, &updates]() {
    return scatter(a, indices, updates, 0);
  };
  TIME(single_element_update);

  auto single_element_add = [&a, &indices, &updates]() {
    return scatter_add(a, indices, updates, 0);
  };
  TIME(single_element_add);
}

int main() {
  std::cout << "Benchmarks for " << default_device() << std::endl;
  time_creation_ops();
  time_type_conversions();
  time_unary_ops();
  time_binary_ops();
  time_strided_ops();
  time_random_generation();
  time_comparisons();
  time_matvec();
  time_matmul();
  time_reductions();
  time_gather_scatter();
}

15  benchmarks/python/comparative/README.md  Normal file
@@ -0,0 +1,15 @@
Microbenchmarks comparing MLX to PyTorch
========================================

Implement the same microbenchmarks in MLX and PyTorch to compare them and keep
a list of the biggest possible performance improvements and/or regressions.

For instance, run `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2
--cpu` to measure the time it takes to sum over the third axis of an
8x1024x128 tensor on the CPU.

`compare.py` runs several benchmarks and reports the speed-up (or slow-down)
of MLX relative to PyTorch.

Each bench script can be run with `--print-pid` to print its PID and wait for
a key press, which makes it easier to attach a debugger.
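
For example, using the flags that `compare.py` defines (see below), you can
run only a subset of the comparisons or compare MLX against itself in two
dtypes:

    # run only the sum_axis comparisons, skipping the CPU variants
    python compare.py -f sum_axis -n cpu

    # compare MLX in float32 vs float16 instead of against PyTorch
    python compare.py -f matmul -d float32 float16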

313  benchmarks/python/comparative/bench_mlx.py  Normal file
@@ -0,0 +1,313 @@
import argparse
import math
import os
import time

import mlx.core as mx


def int_or_list(x):
    try:
        return int(x)
    except ValueError:
        return [int(xi) for xi in x.split(",")]


def none_or_list(x):
    if x == "":
        return None
    else:
        return [int(xi) for xi in x.split(",")]


def bench(f, *args):
    for i in range(10):
        f(*args)

    s = time.time()
    for i in range(100):
        f(*args)
    e = time.time()
    return e - s
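

# bench() warms up with 10 untimed calls, then returns the total wall-clock
# time of 100 timed calls. MLX evaluates lazily, so each benchmark below calls
# mx.eval() on its outputs to force the work into the timed region.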


def matmul_square(x):
    y = x
    for i in range(10):
        y = y @ x
    mx.eval(y)
    return y


def matmul(x, y):
    ys = []
    for i in range(10):
        ys.append(x @ y)
    mx.eval(ys)


def conv1d(x, y):
    ys = []
    for i in range(10):
        ys.append(mx.conv1d(x, y))
    mx.eval(ys)


def conv2d(x, y):
    ys = []
    for i in range(10):
        ys.append(mx.conv2d(x, y))
    mx.eval(ys)


def binary(op, x, y):
    for i in range(100):
        y = getattr(mx, op)(x, y)
    mx.eval(y)


def reduction(op, axis, x):
    ys = []
    for i in range(100):
        ys.append(getattr(mx, op)(x, axis=axis))
    mx.eval(ys)


def softmax(axis, x):
    ys = []
    for i in range(100):
        ex = mx.exp(x - mx.max(x, axis=axis, keepdims=True))
        y = ex / mx.sum(ex, axis=axis, keepdims=True)
        ys.append(y)
    mx.eval(ys)


def softmax_fused(axis, x):
    ys = []
    for i in range(100):
        y = mx.softmax(x, axis=axis)
        ys.append(y)
    mx.eval(ys)


def relu(x):
    y = x
    for i in range(100):
        y = mx.maximum(y, 0)
    mx.eval(y)


def scalar_mult(x):
    y = x
    for i in range(100):
        y = y * (1.0 / (1 + i))
    mx.eval(y)


def cross_entropy(targets, x):
    ys = []
    for i in range(100):
        y = mx.logsumexp(x, axis=-1, keepdims=True) - mx.take_along_axis(
            x, mx.reshape(targets, (-1, 1)), axis=-1
        )
        ys.append(mx.mean(y))
    mx.eval(ys)


def logsumexp(axis, x):
    ys = []
    for i in range(100):
        ys.append(mx.logsumexp(x, axis=axis))
    mx.eval(ys)


def linear(w, b, x):
    ys = []
    for i in range(10):
        ys.append(x @ mx.transpose(w, (1, 0)) + b)
    mx.eval(ys)


def rope(x):
    *_, N, D = x.shape
    ys = []
    for i in range(10):
        x = mx.reshape(x, (-1, N, D))
        positions = mx.arange(N)
        # geometric frequencies matching 10000 ** linspace(0, 1, D // 2)
        # in bench_torch.py
        freqs = mx.exp(mx.arange(0.0, D // 2) * (math.log(10000) / (D // 2 - 1)))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        costheta = mx.cos(theta)
        sintheta = mx.sin(theta)
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        y = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
        y = mx.reshape(y, (-1, N, D))
        ys.append(y)
    mx.eval(ys)
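

# rope() above computes rotary position embeddings from scratch: even/odd
# feature pairs (x1, x2) are rotated by an angle theta = position * frequency,
# with the per-dimension frequencies spaced geometrically.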


def concatenate(axis, x, y):
    ys = []
    for i in range(10):
        ys.append(mx.concatenate([x, y], axis=axis))
    mx.eval(ys)


def cumsum(axis, x):
    ys = []
    for i in range(10):
        ys.append(mx.cumsum(x, axis))
    mx.eval(ys)


def sort(axis, x):
    ys = []
    for i in range(10):
        ys.append(mx.sort(x, axis))
    mx.eval(ys)


def topk(axis, x):
    k = x.shape[axis] // 3
    ys = []
    for i in range(10):
        ys.append(mx.topk(x, k, axis))
    mx.eval(ys)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
    parser.add_argument(
        "--size",
        default=[(1024, 1024)],
        type=lambda x: list(map(int, x.split("x"))),
        help="Set the matrix size",
        action="append",
    )
    parser.add_argument(
        "--axis",
        default=[1],
        type=int_or_list,
        help="Set a reduction axis",
        action="append",
    )
    parser.add_argument(
        "--transpose",
        type=none_or_list,
        default=[],
        help="Permute the matrix",
        action="append",
    )
    parser.add_argument(
        "--print-pid", action="store_true", help="Print the PID and pause"
    )
    parser.add_argument("--cpu", action="store_true", help="Use the CPU")
    parser.add_argument(
        "--fused", action="store_true", help="Use fused functions where possible"
    )
    parser.add_argument(
        "--dtype", choices=["float32", "float16", "bfloat16"], default="float32"
    )

    args = parser.parse_args()
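
    # action="append" keeps the default as the first list element, so drop the
    # default whenever --size or --axis was passed explicitly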
    if len(args.size) > 1:
        args.size.pop(0)
    if len(args.axis) > 1:
        args.axis.pop(0)

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    if args.cpu:
        mx.set_default_device(mx.cpu)
    else:
        mx.set_default_device(mx.gpu)
    dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[
        args.dtype
    ]
    xs = []
    for size in args.size:
        xs.append(mx.random.normal(size).astype(dtype))
    for i, t in enumerate(args.transpose):
        if t is None:
            continue
        xs[i] = mx.transpose(xs[i], t)
    mx.eval(xs)
    x = xs[0]
    axis = args.axis[0]

    if args.benchmark == "matmul_square":
        print(bench(matmul_square, x))

    elif args.benchmark == "matmul":
        print(bench(matmul, *xs))

    elif args.benchmark == "linear":
        print(bench(linear, *xs))

    elif args.benchmark == "sum_axis":
        print(bench(reduction, "sum", axis, x))

    elif args.benchmark == "sum_all":
        print(bench(reduction, "sum", None, x))

    elif args.benchmark == "argmax":
        print(bench(reduction, "argmax", axis, x))

    elif args.benchmark == "add":
        print(bench(binary, "add", *xs))

    elif args.benchmark == "mul":
        print(bench(binary, "multiply", *xs))

    elif args.benchmark == "softmax":
        if args.fused:
            print(bench(softmax_fused, axis, x))
        else:
            print(bench(softmax, axis, x))

    elif args.benchmark == "relu":
        print(bench(relu, x))

    elif args.benchmark == "scalar_mul":
        print(bench(scalar_mult, x))

    elif args.benchmark == "cross_entropy":
        if len(size) != 2:
            raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")

        targets = mx.zeros((len(x),), dtype=mx.uint32)
        print(bench(cross_entropy, targets, x))

    elif args.benchmark == "logsumexp":
        print(bench(logsumexp, axis, x))

    elif args.benchmark == "rope":
        print(bench(rope, x))

    elif args.benchmark == "concatenate":
        print(bench(concatenate, axis, *xs))

    elif args.benchmark == "cumsum":
        print(bench(cumsum, axis, *xs))

    elif args.benchmark == "conv1d":
        print(bench(conv1d, *xs))

    elif args.benchmark == "conv2d":
        print(bench(conv2d, *xs))

    elif args.benchmark == "sort":
        print(bench(sort, axis, x))

    elif args.benchmark == "topk":
        print(bench(topk, axis, x))

    else:
        raise ValueError("Unknown benchmark")

338  benchmarks/python/comparative/bench_torch.py  Normal file
@@ -0,0 +1,338 @@
import argparse
import os
import time

import torch
import torch.mps


def int_or_list(x):
    try:
        return int(x)
    except ValueError:
        return [int(xi) for xi in x.split(",")]


def none_or_list(x):
    if x == "":
        return None
    else:
        return [int(xi) for xi in x.split(",")]


def bench(f, *args):
    for i in range(10):
        f(*args)

    s = time.time()
    for i in range(100):
        f(*args)
    e = time.time()
    return e - s


def sync_if_needed(x):
    if x.device != torch.device("cpu"):
        torch.mps.synchronize()
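

# MPS kernels launch asynchronously, so each benchmark calls sync_if_needed()
# at the end to make sure the timed work has actually finished; it plays the
# same role as mx.eval() in bench_mlx.py.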


@torch.no_grad()
def matmul_square(x):
    y = x
    for i in range(10):
        y = y @ x
    sync_if_needed(x)


@torch.no_grad()
def matmul(x, y):
    ys = []
    for i in range(10):
        ys.append(x @ y)
    sync_if_needed(x)


@torch.no_grad()
def conv1d(x, y):
    x = torch.transpose(x, -1, -2)
    y = torch.transpose(y, -1, -2)
    ys = []
    for i in range(10):
        ys.append(torch.nn.functional.conv1d(x, y))
    sync_if_needed(x)


@torch.no_grad()
def conv2d(x, y):
    x = torch.permute(x, (0, 3, 1, 2))
    y = torch.permute(y, (0, 3, 1, 2))
    ys = []
    for i in range(10):
        ys.append(torch.nn.functional.conv2d(x, y))
    sync_if_needed(x)
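

# The conv benchmarks permute to channels-first because torch.nn.functional
# conv1d/conv2d expect NCL/NCHW input, while MLX convolutions (and hence the
# shared benchmark sizes) use channels-last tensors.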


@torch.no_grad()
def binary(op, x, y):
    for i in range(100):
        y = getattr(torch, op)(x, y)
    sync_if_needed(x)


@torch.no_grad()
def reduction(op, axis, x):
    ys = []
    for i in range(100):
        ys.append(getattr(x, op)(axis))
    sync_if_needed(x)


@torch.no_grad()
def softmax(axis, x):
    ys = []
    for i in range(100):
        ex = torch.exp(x - torch.max(x, dim=axis, keepdims=True).values)
        y = ex / torch.sum(ex, dim=axis, keepdims=True)
        ys.append(y)
    sync_if_needed(x)


@torch.no_grad()
def softmax_fused(axis, x):
    ys = []
    for i in range(100):
        ys.append(torch.nn.functional.softmax(x, dim=axis))
    sync_if_needed(x)


@torch.no_grad()
def relu(x):
    y = x
    for i in range(100):
        y = torch.nn.functional.relu(y)
    sync_if_needed(x)


@torch.no_grad()
def scalar_mult(x):
    y = x
    for i in range(100):
        y = y * (1.0 / (1 + i))
    sync_if_needed(x)


@torch.no_grad()
def cross_entropy(targets, x):
    ys = []
    for i in range(100):
        ys.append(torch.nn.functional.cross_entropy(x, targets))
    sync_if_needed(x)


@torch.no_grad()
def logsumexp(axis, x):
    ys = []
    for i in range(100):
        ys.append(torch.logsumexp(x, dim=axis))
    sync_if_needed(x)


@torch.no_grad()
def linear_fused(w, b, x):
    ys = []
    for i in range(10):
        ys.append(torch.nn.functional.linear(x, w, b))
    sync_if_needed(x)


@torch.no_grad()
def linear(w, b, x):
    ys = []
    for i in range(10):
        ys.append((x @ torch.transpose(w, -2, -1)) + b)
    sync_if_needed(x)


@torch.no_grad()
def rope(x):
    *_, N, D = x.shape
    ys = []
    for i in range(10):
        x = x.view(-1, N, D)
        positions = torch.arange(N, device=x.device)
        freqs = 10000 ** torch.linspace(0, 1, D // 2, device=x.device)
        theta = positions[:, None] * freqs[None]
        costheta = torch.cos(theta)
        sintheta = torch.sin(theta)
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        y = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
        y = y.reshape(-1, N, D)
        ys.append(y)
    sync_if_needed(x)


@torch.no_grad()
def concatenate(axis, x, y):
    ys = []
    for i in range(10):
        ys.append(torch.cat([x, y], dim=axis))
    sync_if_needed(x)


@torch.no_grad()
def cumsum(axis, x):
    ys = []
    for i in range(10):
        ys.append(x.cumsum(axis))
    sync_if_needed(x)


@torch.no_grad()
def sort(axis, x):
    ys = []
    for i in range(10):
        ys.append(torch.sort(x, dim=axis)[0])
    sync_if_needed(x)


@torch.no_grad()
def topk(axis, x):
    k = x.shape[axis] // 3
    ys = []
    for i in range(10):
        ys.append(torch.topk(x, k, dim=axis)[0])
    sync_if_needed(x)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
    parser.add_argument(
        "--size",
        default=[(1024, 1024)],
        type=lambda x: list(map(int, x.split("x"))),
        help="Set the matrix size",
        action="append",
    )
    parser.add_argument(
        "--axis",
        default=[1],
        type=int_or_list,
        help="Set a reduction axis",
        action="append",
    )
    parser.add_argument(
        "--transpose",
        type=none_or_list,
        default=[],
        help="Permute the matrix",
        action="append",
    )
    parser.add_argument(
        "--print-pid", action="store_true", help="Print the PID and pause"
    )
    parser.add_argument("--cpu", action="store_true", help="Use the CPU")
    parser.add_argument(
        "--fused", action="store_true", help="Use fused functions where possible"
    )
    parser.add_argument("--dtype", choices=["float32", "float16"], default="float32")

    args = parser.parse_args()

    if len(args.size) > 1:
        args.size.pop(0)
    if len(args.axis) > 1:
        args.axis.pop(0)

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    torch.set_num_threads(1)
    device = "cpu" if args.cpu else "mps"
    dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype]
    xs = []
    for size in args.size:
        xs.append(torch.randn(*size).to(device).to(dtype))
    for i, t in enumerate(args.transpose):
        if t is None:
            continue
        xs[i] = xs[i].permute(*t)
    x = xs[0]
    axis = args.axis[0]

    if args.benchmark == "matmul_square":
        print(bench(matmul_square, x))

    elif args.benchmark == "matmul":
        print(bench(matmul, *xs))

    elif args.benchmark == "linear":
        if args.fused:
            print(bench(linear_fused, *xs))
        else:
            print(bench(linear, *xs))

    elif args.benchmark == "sum_axis":
        print(bench(reduction, "sum", axis, x))

    elif args.benchmark == "sum_all":
        print(bench(reduction, "sum", None, x))

    elif args.benchmark == "argmax":
        print(bench(reduction, "argmax", axis, x))

    elif args.benchmark == "add":
        print(bench(binary, "add", *xs))

    elif args.benchmark == "mul":
        print(bench(binary, "mul", *xs))

    elif args.benchmark == "softmax":
        if args.fused:
            print(bench(softmax_fused, axis, x))
        else:
            print(bench(softmax, axis, x))

    elif args.benchmark == "relu":
        print(bench(relu, x))

    elif args.benchmark == "scalar_mul":
        print(bench(scalar_mult, x))

    elif args.benchmark == "cross_entropy":
        if len(size) != 2:
            raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")

        targets = torch.zeros(len(x), dtype=torch.long).to(x.device)
        print(bench(cross_entropy, targets, x))

    elif args.benchmark == "logsumexp":
        print(bench(logsumexp, axis, x))

    elif args.benchmark == "rope":
        print(bench(rope, x))

    elif args.benchmark == "concatenate":
        print(bench(concatenate, axis, *xs))

    elif args.benchmark == "cumsum":
        print(bench(cumsum, axis, *xs))

    elif args.benchmark == "conv1d":
        print(bench(conv1d, *xs))

    elif args.benchmark == "conv2d":
        print(bench(conv2d, *xs))

    elif args.benchmark == "sort":
        print(bench(sort, axis, x))

    elif args.benchmark == "topk":
        print(bench(topk, axis, x))

    else:
        raise ValueError("Unknown benchmark")

253  benchmarks/python/comparative/compare.py  Normal file
@@ -0,0 +1,253 @@
#!/usr/bin/env python

import argparse
import re
from pathlib import Path
from subprocess import run

BENCH_MLX = Path(__file__).parent / "bench_mlx.py"
BENCH_TORCH = Path(__file__).parent / "bench_torch.py"


def run_or_raise(*args, **kwargs):
    try:
        result = run(*args, capture_output=True, **kwargs)
        return float(result.stdout)
    except ValueError:
        raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")


def compare(args):
    t_mlx = run_or_raise(["python", BENCH_MLX] + args)
    t_torch = run_or_raise(["python", BENCH_TORCH] + args)

    print((t_torch - t_mlx) / t_torch, " ".join(args), sep="\t")
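

# The reported number is (t_torch - t_mlx) / t_torch: the fraction of
# PyTorch's runtime that MLX saves, so positive values mean MLX is faster and
# negative values mean it is slower.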


def compare_mlx_dtypes(args, dt1, dt2):
    t_mlx_dt1 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt1])
    t_mlx_dt2 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt2])

    print((t_mlx_dt2 - t_mlx_dt1) / t_mlx_dt2, " ".join(args), sep="\t")


def make_regex_search(regexes):
    compiled_regexes = list(map(re.compile, regexes))

    def search(x):
        return (c.search(x) is not None for c in compiled_regexes)

    return search


def make_predicate(positive_filter, negative_filter):
    if positive_filter is not None:
        positive_filter_search = make_regex_search(positive_filter)
        positive_filter = lambda x: all(positive_filter_search(x))
    else:
        positive_filter = lambda x: True

    if negative_filter is not None:
        negative_filter_search = make_regex_search(negative_filter)
        negative_filter = lambda x: not any(negative_filter_search(x))
    else:
        negative_filter = lambda x: True

    def predicate(x):
        return positive_filter(x) and negative_filter(x)

    return predicate


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
    parser.add_argument(
        "--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
    )
    parser.add_argument(
        "--negative_filter", "-n", help="Regex filter to remove benchmarks", nargs="+"
    )
    parser.add_argument(
        "--mlx_dtypes",
        "-d",
        help="Compare mlx benchmarks between the 2 provided data types",
        nargs=2,
    )
    args, rest = parser.parse_known_args()

    _filter = make_predicate(args.filter, args.negative_filter)

    if args.mlx_dtypes:
        compare_filtered = (
            lambda x: compare_mlx_dtypes(
                x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
            )
            if _filter(x)
            else None
        )
    else:
        compare_filtered = lambda x: compare(x.split() + rest) if _filter(x) else None

    # Binary ops
    compare_filtered("add --size 10x1024x128 --size 1x1024x128 --cpu")
    compare_filtered("add --size 10x1024x128 --size 1x1024x128")
    compare_filtered("add --size 1024x128 --size 1x128 --cpu")
    compare_filtered("add --size 1024x128 --size 1x128")
    compare_filtered("add --size 1024x4096 --size 1x4096 --cpu")
    compare_filtered("add --size 1024x4096 --size 1x4096")
    compare_filtered("add --size 1024x4096 --size 1x1024 --transpose 1,0 --cpu")
    compare_filtered("add --size 1024x4096 --size 1x1024 --transpose 1,0")
    compare_filtered("add --size 1024x1024 --size 1024x1024 --cpu")
    compare_filtered("add --size 1024x1024 --size 1024x1024")
    compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0 --cpu")
    compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0")
    compare_filtered(
        "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0 --cpu"
    )
    compare_filtered(
        "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0"
    )

    # Reduction ops
    compare_filtered("sum_all --size 10x1024x128 --cpu")
    compare_filtered("sum_all --size 10x1024x128")
    compare_filtered("sum_axis --size 16x1024x128 --axis 2 --cpu")
    compare_filtered("sum_axis --size 16x1024x128 --axis 2")
    compare_filtered("sum_axis --size 16x128x1024 --axis 2 --cpu")
    compare_filtered("sum_axis --size 16x128x1024 --axis 2")
    compare_filtered("sum_axis --size 1024x1024 --axis 1 --cpu")
    compare_filtered("sum_axis --size 1024x1024 --axis 1")
    compare_filtered("sum_axis --size 1024x1024 --axis 0 --cpu")
    compare_filtered("sum_axis --size 1024x1024 --axis 0")
    compare_filtered("sum_axis --size 16x128x1024 --axis 1 --cpu")
    compare_filtered("sum_axis --size 16x128x1024 --axis 1")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0")
    compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
    compare_filtered("argmax --size 10x1024x128 --axis 1")
    compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")
    compare_filtered("argmax --size 10x1024x128 --axis 2")
    compare_filtered("argmax --size 1024x1024 --axis 1 --cpu")
    compare_filtered("argmax --size 1024x1024 --axis 1")

    # Matmul ops
    compare_filtered("matmul_square --size 1024x1024")
    compare_filtered("matmul_square --size 1024x1024 --cpu")
    compare_filtered("matmul_square --size 16x1024x1024")
    compare_filtered("matmul_square --size 16x1024x1024 --cpu")
    compare_filtered(
        "matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1"
    )
    compare_filtered(
        "matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1 --cpu"
    )
    compare_filtered(
        "matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1"
    )
    compare_filtered(
        "matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1 --cpu"
    )
    compare_filtered("matmul --size 512x8192 --size 8192x512")
    compare_filtered("matmul --size 512x8192 --size 8192x512 --cpu")
    # compare_filtered("matmul --size 512x131072 --size 131072x512")
    # compare_filtered("matmul --size 512x131072 --size 131072x512 --cpu")
    compare_filtered("matmul --size 8192x512 --size 512x8192")
    compare_filtered("matmul --size 8192x512 --size 512x8192 --cpu")
    # compare_filtered("matmul --size 131072x512 --size 512x512")
    # compare_filtered("matmul --size 131072x512 --size 512x512 --cpu")
    compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024")
    compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --cpu")
    compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --fused")
    compare_filtered(
        "linear --size 1024x1024 --size 1024 --size 128x1024 --fused --cpu"
    )

    # Matvec ops
    compare_filtered("matmul --size 1x1x4096 --size 4096x4096 --cpu")
    compare_filtered("matmul --size 1x1x4096 --size 4096x4096")
    compare_filtered(
        "matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0 --cpu"
    )
    compare_filtered(
        "matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0"
    )
    compare_filtered("matmul --size 32x1x1000 --size 32x1000x128 --cpu")
    compare_filtered("matmul --size 32x1x1000 --size 32x1000x128")
    compare_filtered(
        "matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1 --cpu"
    )
    compare_filtered(
        "matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1"
    )

    # Various ops
    compare_filtered("softmax --size 32x16x1024 --axis 2")
    compare_filtered("softmax --size 32x16x1024 --axis 2 --cpu")
    compare_filtered("softmax --size 32x16x1024 --axis 2 --fused")
    compare_filtered("softmax --size 32x16x1024 --axis 2 --fused --cpu")
    compare_filtered("softmax --size 2x1024x1024 --axis 1")
    compare_filtered("softmax --size 2x1024x1024 --axis 1 --cpu")
    compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused")
    compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
    compare_filtered("relu --size 32x16x1024")
    compare_filtered("relu --size 32x16x1024 --cpu")
    compare_filtered("scalar_mul --size 32x16x1024")
    compare_filtered("scalar_mul --size 32x16x1024 --cpu")
    compare_filtered("cross_entropy --size 256x1024")
    compare_filtered("cross_entropy --size 256x1024 --cpu")
    compare_filtered("logsumexp --size 1024x1024 --axis 1")
    compare_filtered("logsumexp --size 1024x1024 --axis 1 --cpu")
    compare_filtered("logsumexp --size 1024x1024 --axis 0")
    compare_filtered("logsumexp --size 1024x1024 --axis 0 --cpu")
    compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2")
    compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2 --cpu")
    compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1")
    compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1 --cpu")
    compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0")
    compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0 --cpu")
    compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1")
    compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1 --cpu")
    compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1")
    compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1 --cpu")
    compare_filtered("concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2")