Mirror of https://github.com/ml-explore/mlx.git (synced 2024-09-15 10:04:00 +02:00)
Einsum (#1269)
* einsum initial
* fix comma break
* sum axis was wrong
* small cleanups
* python binding
* changed bindings to resemble numpy
* remove todo comment
* comment changes
* add count of operands/inputs
* fail fast if operands list is empty
* ignore comma if no output
* einsum path matching numpy
* getting somewhere with path
* remove print
* it passes the first test
* moved einsum tests to separate file
* separated einsum path
* moved einsum naive
* remove space from equation
* fast fail if no operands passed
* update tests and remove printf
* small cleanup
* some more cleanups
* removed python helper file
* ack
* utilize std for finding min in vector
* duplicate def
* remove the tuple as it was unreadable
* moved einsum_naive back to ops
* remaining isn't needed
* avoid creating another set
* cleanup
* greedy path, start of naive einsum
* more einsum
* fix some bugs
* some more fixes, tests pass
* benchmark
* some simplify
* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs
* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent 7f914365fd
commit baf9fa5f42
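
For orientation, a minimal sketch of the Python API this commit adds. The names and signatures come from the bindings in this diff; the shapes are illustrative:

    import mlx.core as mx

    a = mx.ones((8, 4))
    b = mx.ones((4, 16))

    # Explicit output subscripts
    c = mx.einsum("ij,jk->ik", a, b)

    # Implicit mode: the repeated index j is summed and the remaining
    # indices are ordered alphabetically, so the output is also "ik"
    d = mx.einsum("ij,jk", a, b)

    # Inspect the chosen contraction order and a cost summary
    path, info = mx.einsum_path("ij,jk->ik", a, b)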
@@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
 - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
 - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
 - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
 - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
benchmarks/python/einsum_bench.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright © 2024 Apple Inc.

import time

import mlx.core as mx
import numpy as np


def timeit(fn, its=100, args=[]):
    for _ in range(5):
        fn(*args)
    tic = time.perf_counter()
    for _ in range(its):
        fn(*args)
    toc = time.perf_counter()
    return 1e3 * (toc - tic) / its


def time_little_einsum_path():
    subscripts = "ik,kj->ij"
    x = mx.ones((32, 32))
    y = mx.ones((32, 32))
    mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))

    x = np.array(x)
    y = np.array(y)
    np_time = timeit(np.einsum_path, args=(subscripts, x, y))
    print("Timing little einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_big_einsum_path():
    chars = list("abcdefgh")
    char_to_dim = {c: v for v, c in enumerate(chars)}

    num_inputs = 10
    inputs = []
    subscripts = []
    for _ in range(num_inputs):
        subscript = np.random.choice(chars, size=5, replace=False).tolist()
        subscripts.append("".join(subscript))
        inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
    subscripts = ",".join(subscripts)

    np_time = timeit(np.einsum_path, args=(subscripts, *inputs))

    inputs = [mx.array(x) for x in inputs]
    mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
    print("Timing big einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_attention():
    def regular_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
        scores = mx.softmax(scores, axis=-1)
        output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
        mx.eval(output)

    def einsum_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
        scores = mx.softmax(scores, axis=-1)
        output = mx.einsum("ijtu,iujk->itjk", scores, values)
        mx.eval(output)

    x = mx.random.uniform(shape=(8, 512, 32, 128))

    regular_time = timeit(regular_attention, args=(x,))
    ein_time = timeit(einsum_attention, args=(x,))
    print("Timing einsum attention...")
    print(f"Regular ... {regular_time:.3f} ms")
    print(f"Einsum ... {ein_time:.3f} ms")


if __name__ == "__main__":
    time_little_einsum_path()
    time_big_einsum_path()
    time_attention()
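
einsum_path returns a list of position tuples (one per contraction step) plus a printable summary, mirroring np.einsum_path. A hedged sketch of reading it, with illustrative shapes:

    import mlx.core as mx

    x = mx.ones((4, 8))
    y = mx.ones((8, 2))
    z = mx.ones((2, 16))
    path, info = mx.einsum_path("ab,bc,cd->ad", x, y, z)
    # Each tuple names the operand positions contracted at that step,
    # e.g. [(1, 2), (0, 1)]; the exact order depends on the greedy search.
    print(info)  # naive vs. optimized scaling and FLOP counts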
@@ -57,6 +57,8 @@ Operations
    diagonal
    divide
    divmod
+   einsum
+   einsum_path
    equal
    erf
    erfinv
@ -6,6 +6,7 @@ target_sources(
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
|
|
mlx/einsum.cpp (new file, 859 lines)
@@ -0,0 +1,859 @@
// Copyright © 2024 Apple Inc.
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <unordered_set>

#include "mlx/einsum.h"
#include "mlx/ops.h"

namespace mlx::core {

namespace {

// The MLX einsum implementation is based on NumPy (which is based on
// opt_einsum):
// https://github.com/numpy/numpy/blob/1d49c7f7ff527c696fc26ab2278ad51632a66660/numpy/_core/einsumfunc.py#L743
// https://github.com/dgasmith/opt_einsum

using CharSet = std::unordered_set<char>;

// A helper struct to hold the string and set
// representation of a subscript to avoid needing
// to recompute the set
struct Subscript {
  Subscript(std::string str, CharSet set)
      : str(std::move(str)), set(std::move(set)) {};
  std::string str;
  CharSet set;
};

struct PathInfo {
  size_t naive_cost;
  size_t naive_scaling;
  size_t optimized_cost;
  size_t optimized_scaling;
  size_t largest_term;
};

struct PathNode {
  PathNode(
      std::vector<Subscript> inputs,
      Subscript output,
      std::vector<int> positions)
      : inputs(std::move(inputs)),
        output(std::move(output)),
        positions(std::move(positions)) {};

  std::vector<Subscript> inputs;
  Subscript output;

  std::vector<int> positions;
};

// Parse the comma separated subscripts into a vector of strings. If the
// output subscripts are missing they are inferred.
//
// For example:
//   "ij,jk -> ik" becomes {{"ij", "jk"}, "ik"}
//   "ij,jk" becomes {{"ij", "jk"}, "ik"}
std::pair<std::vector<std::string>, std::string> parse(std::string subscripts) {
  std::string lhs, rhs;

  // Start by removing all white space
  subscripts.erase(
      std::remove(subscripts.begin(), subscripts.end(), ' '), subscripts.end());

  if (auto pos = subscripts.find("->"); pos != std::string::npos) {
    // Explicit mode
    lhs = subscripts.substr(0, pos);
    rhs = subscripts.substr(pos + 2);
  } else {
    // Implicit mode:
    // - repeats are summed
    // - remaining output axes are ordered alphabetically
    lhs = subscripts;
    std::unordered_map<char, int> temp;
    for (auto& c : subscripts) {
      if (c == ',') {
        continue;
      }
      auto inserted = temp.insert({c, 0});
      inserted.first->second++;
    }
    for (auto& k : temp) {
      if (k.second == 1) {
        rhs += k.first;
      }
    }
    std::sort(rhs.begin(), rhs.end());
  }
  std::vector<std::string> input_list;
  std::stringstream ss(lhs);
  std::string token;
  while (getline(ss, token, ',')) {
    input_list.push_back(token);
  }
  return {input_list, rhs};
}

// Check if two sets are disjoint
bool disjoint(const CharSet& x, const CharSet& y) {
  for (auto& c : x) {
    if (y.find(c) != y.end()) {
      return false;
    }
  }
  return true;
}

template <typename T>
size_t term_size(const T& term, std::unordered_map<char, int> dict) {
  size_t size = 1;
  for (auto c : term) {
    size *= dict[c];
  }
  return size;
}

size_t flop_count(
    const CharSet& term,
    bool inner,
    int num_terms,
    std::unordered_map<char, int> dict) {
  size_t size = term_size(term, dict);
  auto op_factor = 1;
  if ((num_terms - 1) > op_factor) {
    op_factor = num_terms - 1;
  }
  if (inner) {
    op_factor += 1;
  }
  return size * op_factor;
}
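
// Note: the estimate above is (iteration-space size) * (op factor), with
// num_terms - 1 multiplies per element plus one extra pass when an index
// is summed out. For "ij,jk->ik" with all dimensions equal to n this gives
// 2 * n^3: one multiply and one add per element of the i-j-k space.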

std::pair<size_t, int> compute_cost_and_scaling(
    const std::vector<Subscript>& inputs,
    const Subscript& output,
    std::unordered_map<char, int> dim_map) {
  CharSet contractions;
  for (auto& in : inputs) {
    contractions.insert(in.set.begin(), in.set.end());
  }

  bool inner = false;
  for (auto c : contractions) {
    if (output.set.find(c) == output.set.end()) {
      inner = true;
      break;
    }
  }
  auto cost = flop_count(contractions, inner, inputs.size(), dim_map);
  return {cost, contractions.size()};
}

std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
    std::vector<Subscript> inputs,
    const Subscript& output,
    std::unordered_map<char, int> dim_map,
    size_t cost_limit,
    size_t memory_limit) {
  // Helper struct for building the greedy path
  struct Contraction {
    Contraction(
        size_t size,
        size_t cost,
        CharSet output,
        int dims,
        int x,
        int y)
        : size(size),
          cost(cost),
          output(std::move(output)),
          dims(dims),
          x(x),
          y(y) {};

    int64_t size; // Size difference, can be negative
    size_t cost;
    CharSet output;
    int dims; // Number of dimensions in the contraction
    int x;
    int y;
  };

  // Start by iterating over all possible combinations
  std::vector<std::pair<int, int>> pos_pairs;
  for (int i = 0; i < inputs.size(); ++i) {
    for (int j = i + 1; j < inputs.size(); ++j) {
      pos_pairs.emplace_back(i, j);
    }
  }

  std::vector<PathNode> path;
  std::vector<Contraction> possible_contractions;
  size_t path_cost = 0;
  int path_scaling = 0;
  auto num_in = inputs.size();
  for (int i = 0; i < num_in - 1; ++i) {
    auto add_contraction = [&](int p1, int p2) {
      CharSet new_term;
      CharSet contractions(inputs[p1].set.begin(), inputs[p1].set.end());
      contractions.insert(inputs[p2].set.begin(), inputs[p2].set.end());
      for (int i = 0; i < inputs.size(); i++) {
        if (i == p1 || i == p2) {
          continue;
        }
        auto& in = inputs[i].set;
        for (auto c : in) {
          if (contractions.find(c) != contractions.end()) {
            new_term.insert(c);
          }
        }
      }
      for (auto c : output.set) {
        if (contractions.find(c) != contractions.end()) {
          new_term.insert(c);
        }
      }

      // Ignore if:
      // - The size of the new result is greater than the memory limit
      // - The cost is larger than the naive cost
      auto new_size = term_size(new_term, dim_map);
      if (new_size > memory_limit) {
        return;
      }
      int64_t removed_size = term_size(inputs[p1].set, dim_map) +
          term_size(inputs[p2].set, dim_map) - new_size;

      bool inner = contractions.size() > new_term.size();
      auto cost = flop_count(contractions, inner, 2, dim_map);
      if (path_cost + cost > cost_limit) {
        return;
      }
      possible_contractions.emplace_back(
          removed_size, cost, std::move(new_term), contractions.size(), p1, p2);
    };

    for (auto& [p1, p2] : pos_pairs) {
      // Ignore outer products
      if (!disjoint(inputs[p1].set, inputs[p2].set)) {
        add_contraction(p1, p2);
      }
    }

    // If there's nothing in the contraction list,
    // go over the pairs again without ignoring outer products
    if (possible_contractions.empty()) {
      for (auto& [p1, p2] : pos_pairs) {
        add_contraction(p1, p2);
      }
    }

    if (possible_contractions.empty()) {
      // Default to naive einsum for the remaining inputs
      std::vector<int> positions(inputs.size());
      std::iota(positions.begin(), positions.end(), 0);
      auto [cost, scale] = compute_cost_and_scaling(inputs, output, dim_map);
      path.emplace_back(std::move(inputs), output, std::move(positions));

      path_cost += cost;
      path_scaling = std::max(scale, path_scaling);
      break;
    }

    // Find the best contraction
    auto& best = *std::min_element(
        possible_contractions.begin(),
        possible_contractions.end(),
        [](const auto& x, const auto& y) {
          return x.size > y.size || (x.size == y.size && x.cost < y.cost);
        });
    path_scaling = std::max(best.dims, path_scaling);

    // Construct the output subscripts
    std::string out_str(best.output.begin(), best.output.end());
    // TODO, sorting by dimension size seems suboptimal?
    std::sort(out_str.begin(), out_str.end(), [&dim_map](auto x, auto y) {
      return dim_map[x] < dim_map[y];
    });
    Subscript new_output(std::move(out_str), std::move(best.output));

    // Add the chosen contraction to the path
    {
      std::vector<Subscript> in_terms;
      in_terms.push_back(std::move(inputs[best.x]));
      in_terms.push_back(std::move(inputs[best.y]));
      path.emplace_back(
          std::move(in_terms), new_output, std::vector<int>{best.x, best.y});
    }
    // Remove used terms
    inputs.erase(inputs.begin() + best.y);
    inputs.erase(inputs.begin() + best.x);

    // Add the new result
    inputs.push_back(std::move(new_output));

    // Update the existing contractions based on the selected one
    std::vector<Contraction> updated_contractions;
    for (auto& contraction : possible_contractions) {
      // Drop contractions which contain either selected term
      if (contraction.x == best.x || contraction.x == best.y ||
          contraction.y == best.x || contraction.y == best.y) {
        continue;
      }

      // Update the positions of other contractions
      int x =
          contraction.x - (contraction.x > best.x) - (contraction.x > best.y);
      int y =
          contraction.y - (contraction.y > best.x) - (contraction.y > best.y);
      contraction.x = x;
      contraction.y = y;
      updated_contractions.push_back(std::move(contraction));
    }

    pos_pairs.clear();
    for (int i = 0; i < inputs.size() - 1; ++i) {
      pos_pairs.emplace_back(i, inputs.size() - 1);
    }
    path_cost += best.cost;

    possible_contractions = std::move(updated_contractions);
  }
  return {path, path_cost, path_scaling};
}

// Assumes inputs have already had repeats and single axis sums collapsed
bool can_dot(const std::vector<Subscript>& inputs, const Subscript& output) {
  if (inputs.size() != 2) {
    return false;
  }

  for (auto c : inputs[0].set) {
    // Use batched tensordot if anything is being contracted
    if (output.set.find(c) == output.set.end()) {
      return true;
    }
  }
  return false;
}

array batch_tensordot(
    array a,
    array b,
    std::vector<int> a_contract,
    std::vector<int> a_batch,
    std::vector<int> a_concat,
    std::vector<int> b_contract,
    std::vector<int> b_batch,
    std::vector<int> b_concat,
    StreamOrDevice s) {
  // Broadcast contracting dimensions
  {
    auto a_shape = a.shape();
    auto b_shape = b.shape();
    for (int i = 0; i < a_contract.size(); ++i) {
      auto d = std::max(a.shape(a_contract[i]), b.shape(b_contract[i]));
      a_shape[a_contract[i]] = d;
      b_shape[b_contract[i]] = d;
    }
    a = broadcast_to(a, a_shape, s);
    b = broadcast_to(b, b_shape, s);
  }
  auto transpose_reshape = [&s](
                               const array& x,
                               const std::vector<int>& i,
                               const std::vector<int>& j,
                               const std::vector<int>& k) {
    std::vector<int> reorder(i.begin(), i.end());
    reorder.insert(reorder.end(), j.begin(), j.end());
    reorder.insert(reorder.end(), k.begin(), k.end());

    int size1 = 1;
    for (auto s : j) {
      size1 *= x.shape(s);
    }

    int size2 = 1;
    for (auto s : k) {
      size2 *= x.shape(s);
    }

    std::vector<int> shape;
    for (auto ax : i) {
      shape.push_back(x.shape(ax));
    }
    shape.push_back(size1);
    shape.push_back(size2);

    return reshape(transpose(x, reorder, s), std::move(shape), s);
  };

  std::vector<int> out_shape;
  for (auto ax : a_batch) {
    out_shape.push_back(a.shape(ax));
  }
  for (auto ax : a_concat) {
    out_shape.push_back(a.shape(ax));
  }
  for (auto ax : b_concat) {
    out_shape.push_back(b.shape(ax));
  }

  a = transpose_reshape(a, a_batch, a_concat, a_contract);
  b = transpose_reshape(b, b_batch, b_contract, b_concat);

  return reshape(matmul(a, b, s), std::move(out_shape), s);
}

// Collapse repeated subscripts and return the resulting array. The subscript
// is also updated in place. For example:
// - Given an input with shape (4, 4) and subscript "ii", returns
//   the diagonal of shape (4,) and updates the subscript to "i".
// - Given an input with shape (4, 2, 4, 2) and subscript "ijij",
//   returns an output with shape (4, 2) and updates the subscript
//   to "ij".
array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
  // Build a list of (repeat chars, num repeats)
  auto& str = subscript.str;
  std::vector<std::pair<char, int>> repeats;
  std::string new_str;
  {
    std::string repeat_str;
    std::string no_repeat_str;
    std::unordered_map<char, int> counts;
    for (int i = 0; i < str.size(); ++i) {
      auto [it, _] = counts.insert({str[i], 0});
      it->second++;
    }

    for (auto& v : counts) {
      if (v.second > 1) {
        repeats.emplace_back(v.first, v.second);
        repeat_str += v.first;
      }
    }
    for (auto& c : str) {
      if (counts[c] == 1) {
        no_repeat_str += c;
      }
    }
    new_str = repeat_str + no_repeat_str;
  }

  // Build the inputs for gather
  auto slice_sizes = in.shape();
  std::vector<int> axes;
  std::vector<array> indices;
  int n_expand = repeats.size();
  for (auto [c, v] : repeats) {
    for (int i = 0; i < str.size(); ++i) {
      if (str[i] == c) {
        slice_sizes[i] = 1;
        axes.push_back(i);
      }
    }
    std::vector<int> idx_shape(n_expand--, 1);
    idx_shape[0] = in.shape(axes.back());
    auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
    for (int i = 0; i < v; ++i) {
      indices.push_back(idx);
    }
  }

  in = gather(in, indices, axes, slice_sizes, s);

  // Update subscript string with removed dups
  str = new_str;

  // Squeeze singleton dimensions left over from the gather
  for (auto& ax : axes) {
    ax += indices[0].ndim();
  }

  return squeeze(in, axes, s);
}

// Collapse repeat indices and sum single dimensions.
// For example:
// - "aa" becomes "a"
// - "ij,jk->k" becomes "j,jk->k"
void preprocess_einsum_inputs(
    std::vector<Subscript>& inputs,
    const Subscript& output,
    const std::vector<int>& positions,
    std::vector<array>& operands,
    StreamOrDevice s) {
  // Collapse repeat indices
  for (int i = 0; i < inputs.size(); ++i) {
    auto& in = inputs[i];
    if (in.set.size() < in.str.size()) {
      operands[positions[i]] = collapse_repeats(operands[positions[i]], in, s);
    }
  }

  // Sum indices that are only in a single input
  {
    std::unordered_map<char, int> counts;
    for (auto& in : inputs) {
      for (auto c : in.set) {
        auto inserted = counts.insert({c, 0});
        inserted.first->second++;
      }
    }
    for (auto c : output.set) {
      auto inserted = counts.insert({c, 0});
      inserted.first->second++;
    }
    for (int i = 0; i < inputs.size(); ++i) {
      auto& in = inputs[i];
      std::vector<int> sum_axes;
      for (int ax = 0; ax < in.str.size(); ++ax) {
        if (counts[in.str[ax]] == 1) {
          sum_axes.push_back(ax);
        }
      }
      if (!sum_axes.empty()) {
        operands[positions[i]] =
            sum(operands[positions[i]], sum_axes, false, s);
      }
      for (auto it = sum_axes.rbegin(); it != sum_axes.rend(); ++it) {
        in.set.erase(in.str[*it]);
        in.str.erase(in.str.begin() + *it);
      }
    }
  }
}
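
// (In NumPy terms, collapsing repeats is a generalized diagonal taken up
// front: np.einsum("ii->i", x) equals np.diag(x), and collapse_repeats
// implements the same reduction via gather before any contraction runs.)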

array einsum_naive(
    std::vector<Subscript> inputs,
    const Subscript& output,
    const std::vector<int>& positions,
    std::vector<array> operands,
    StreamOrDevice s) {
  // Map each character to an axis
  std::unordered_map<char, int> char_to_ax;
  for (auto& in : inputs) {
    for (auto c : in.str) {
      char_to_ax.insert({c, char_to_ax.size()});
    }
  }

  // Expand and transpose inputs as needed
  for (int i = 0; i < inputs.size(); ++i) {
    int pos = positions[i];
    auto& op = operands[pos];

    // Add missing dimensions at the end
    if (op.ndim() != char_to_ax.size()) {
      auto shape = op.shape();
      shape.insert(shape.end(), char_to_ax.size() - shape.size(), 1);
      op = reshape(op, std::move(shape), s);
    }

    // Transpose:
    // - Build a vector of (char, ax) pairs for the current input
    // - Sort the vector by the canonical axis in char_to_ax
    // - Extract the sorted axis to get transpose order
    std::vector<std::pair<char, int>> str_ax;
    for (auto c : inputs[i].str) {
      str_ax.emplace_back(c, str_ax.size());
    }
    for (auto [c, ax] : char_to_ax) {
      if (inputs[i].set.find(c) == inputs[i].set.end()) {
        str_ax.emplace_back(c, str_ax.size());
      }
    }
    std::sort(
        str_ax.begin(),
        str_ax.end(),
        [&char_to_ax](const auto& x, const auto& y) {
          return char_to_ax[x.first] < char_to_ax[y.first];
        });

    // Skip the transpose if not needed
    if (std::is_sorted(
            str_ax.begin(), str_ax.end(), [](const auto& x, const auto& y) {
              return x.second < y.second;
            })) {
      continue;
    }

    std::vector<int> reorder;
    for (auto [c, ax] : str_ax) {
      reorder.push_back(ax);
    }
    op = transpose(op, reorder, s);
  }

  // Multiply and sum
  auto out = operands[positions[0]];
  for (int i = 1; i < positions.size(); ++i) {
    out = multiply(out, operands[positions[i]], s);
  }
  std::vector<int> sum_axes;
  for (auto [c, ax] : char_to_ax) {
    if (output.set.find(c) == output.set.end()) {
      sum_axes.push_back(ax);
    }
  }
  if (!sum_axes.empty()) {
    out = sum(out, sum_axes, false, s);
  }

  // Transpose output if needed
  std::vector<int> reorder;
  for (auto c : output.str) {
    reorder.push_back(char_to_ax[c]);
  }
  for (auto& r : reorder) {
    int offset = 0;
    for (auto s : sum_axes) {
      if (r > s) {
        offset++;
      }
    }
    r -= offset;
  }
  return transpose(out, reorder, s);
}
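
// The naive fallback broadcasts every operand over the union of all axes,
// multiplies elementwise, sums the axes absent from the output, and
// transposes to the requested order. It is correct for any equation but
// memory-hungry, which is why einsum() below prefers the batch_tensordot
// route whenever can_dot() allows it.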

std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
    const std::string& subscripts,
    const std::vector<array>& operands,
    const std::string& fn_name) {
  if (operands.size() == 0) {
    std::ostringstream msg;
    msg << "[" << fn_name << "] At least one operand is required.";
    throw std::invalid_argument(msg.str());
  }

  auto [in_subscripts, out_subscript] = parse(subscripts);

  if (operands.size() != in_subscripts.size()) {
    std::ostringstream msg;
    msg << "[" << fn_name << "] Number of operands, " << operands.size()
        << ", does not match number of input subscripts, "
        << in_subscripts.size();
    throw std::invalid_argument(msg.str());
  }

  auto check_letters = [&](const auto& subscript) {
    for (auto c : subscript) {
      if (!isalpha(c)) {
        std::ostringstream msg;
        msg << "[" << fn_name << "] Subscripts must be letters, but got '" << c
            << "'.";
        throw std::invalid_argument(msg.str());
      }
    }
  };
  for (auto& in : in_subscripts) {
    check_letters(in);
  }
  check_letters(out_subscript);

  CharSet out_set(out_subscript.begin(), out_subscript.end());
  if (out_set.size() != out_subscript.size()) {
    std::ostringstream msg;
    msg << "[" << fn_name << "] Repeat indices not allowed in output.";
    throw std::invalid_argument(msg.str());
  }
  Subscript output(out_subscript, std::move(out_set));

  std::unordered_map<char, int> dim_map;
  std::vector<Subscript> inputs;
  for (int i = 0; i < in_subscripts.size(); ++i) {
    auto& in = in_subscripts[i];
    CharSet in_set(in.begin(), in.end());
    inputs.emplace_back(in, in_set);

    if (in.size() != operands[i].ndim()) {
      std::ostringstream msg;
      msg << "[" << fn_name << "] Invalid number of subscripts " << in.size()
          << " for input " << i << " with " << operands[i].ndim()
          << " dimensions.";
      throw std::invalid_argument(msg.str());
    }

    // Check repeat subscripts are valid
    if (in_set.size() < in.size()) {
      std::unordered_map<char, int> local_dims;
      for (int j = 0; j < in.size(); ++j) {
        auto dim = operands[i].shape(j);
        auto inserted = local_dims.insert({in[j], dim});
        if (!inserted.second) {
          if (inserted.first->second != dim) {
            std::ostringstream msg;
            msg << "[" << fn_name << "] Dimensions of repeated subscripts "
                << "do not have the same size (" << inserted.first->second
                << " != " << dim << ").";
            throw std::invalid_argument(msg.str());
          }
        }
      }
    }

    for (int j = 0; j < in.size(); j++) {
      auto c = in[j];
      auto dim = operands[i].shape(j);
      auto inserted = dim_map.insert({c, dim});
      auto& in_dim = inserted.first->second;
      if (dim != 1 && in_dim != 1 && in_dim != dim) {
        std::ostringstream msg;
        msg << "[" << fn_name << "] Cannot broadcast dimension " << j
            << " of input " << i << " with shape " << operands[i].shape()
            << " to size " << in_dim << ".";
        throw std::invalid_argument(msg.str());
      }
      // Ensure the broadcasted size is used
      in_dim = std::max(in_dim, dim);
    }
  }

  size_t max_size = term_size(out_subscript, dim_map);
  for (auto& in : in_subscripts) {
    max_size = std::max(max_size, term_size(in, dim_map));
  }

  PathInfo path_info;

  // Get the full naive cost
  std::tie(path_info.naive_cost, path_info.naive_scaling) =
      compute_cost_and_scaling(inputs, output, dim_map);

  // Calculate the path
  std::vector<PathNode> path;
  if (inputs.size() <= 2) {
    std::vector<int> positions(in_subscripts.size());
    std::iota(positions.begin(), positions.end(), 0);
    path.emplace_back(
        std::move(inputs), std::move(output), std::move(positions));
  } else {
    std::tie(path, path_info.optimized_cost, path_info.optimized_scaling) =
        greedy_path(inputs, output, dim_map, path_info.naive_cost, max_size);
    // Set the final output subscript to the actual output
    path.back().output = std::move(output);
  }
  return {path, path_info};
}

} // namespace
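
// Note: the std::invalid_argument errors above surface in Python as
// ValueError (nanobind translates the exception), e.g. for an operand
// count mismatch, non-letter subscripts, or repeated output subscripts.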

std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
    const std::string& subscripts,
    const std::vector<array>& operands) {
  auto [path, path_info] =
      einsum_path_helper(subscripts, operands, "einsum_path");

  std::vector<std::vector<int>> pos_path;
  for (auto& p : path) {
    pos_path.push_back(p.positions);
  }

  std::ostringstream path_print;
  path_print << "  Complete contraction: " << subscripts << "\n"
             << "         Naive scaling: " << path_info.naive_scaling << "\n"
             << "     Optimized scaling: " << path_info.optimized_scaling << "\n"
             << "      Naive FLOP count: " << path_info.naive_cost << "\n"
             << "  Optimized FLOP count: " << path_info.optimized_cost << "\n";
  // TODO add more info here
  return {pos_path, path_print.str()};
}

array einsum(
    const std::string& subscripts,
    const std::vector<array>& operands,
    StreamOrDevice s /* = {} */) {
  auto [path, path_info] = einsum_path_helper(subscripts, operands, "einsum");
  auto inputs = operands;
  for (auto& node : path) {
    preprocess_einsum_inputs(
        node.inputs, node.output, node.positions, inputs, s);

    if (can_dot(node.inputs, node.output)) {
      auto& in_a = node.inputs[0];
      auto& in_b = node.inputs[1];
      auto& out = node.output;

      std::vector<int> a_contract;
      std::vector<int> a_batch;
      std::vector<int> a_concat;
      for (int i = 0; i < in_a.str.size(); ++i) {
        auto c = in_a.str[i];
        if (out.set.find(c) == out.set.end()) {
          // Not in the output, contraction
          a_contract.push_back(i);
        } else if (in_b.set.find(c) != in_b.set.end()) {
          // Not a contraction but in both inputs, batch dim
          a_batch.push_back(i);
        } else {
          // Not a batch dim or contract dim, so concat dim
          a_concat.push_back(i);
        }
      }

      std::vector<int> b_contract;
      std::vector<int> b_batch;
      std::vector<int> b_concat;
      for (auto a_i : a_contract) {
        b_contract.push_back(in_b.str.find(in_a.str[a_i]));
      }
      for (auto a_i : a_batch) {
        b_batch.push_back(in_b.str.find(in_a.str[a_i]));
      }
      for (int i = 0; i < in_b.str.size(); ++i) {
        auto c = in_b.str[i];
        if (out.set.find(c) != out.set.end() &&
            in_a.set.find(c) == in_a.set.end()) {
          b_concat.push_back(i);
        }
      }

      auto& a = inputs[node.positions[0]];
      auto& b = inputs[node.positions[1]];

      std::unordered_map<char, int> char_map;
      for (auto i : a_batch) {
        char_map.insert({in_a.str[i], char_map.size()});
      }
      for (auto i : a_concat) {
        char_map.insert({in_a.str[i], char_map.size()});
      }
      for (auto i : b_concat) {
        char_map.insert({in_b.str[i], char_map.size()});
      }
      inputs.emplace_back(batch_tensordot(
          a,
          b,
          std::move(a_contract),
          std::move(a_batch),
          std::move(a_concat),
          std::move(b_contract),
          std::move(b_batch),
          std::move(b_concat),
          s));

      std::vector<int> reorder;
      for (auto c : node.output.str) {
        reorder.push_back(char_map[c]);
      }
      inputs.back() = transpose(inputs.back(), reorder, s);

    } else {
      inputs.emplace_back(
          einsum_naive(node.inputs, node.output, node.positions, inputs, s));
    }

    // Positions are always sorted increasing, so start from the back
    for (auto it = node.positions.rbegin(); it != node.positions.rend(); ++it) {
      inputs.erase(inputs.begin() + *it);
    }
  }
  return inputs.front();
}

} // namespace mlx::core
mlx/einsum.h (new file, 22 lines)
@@ -0,0 +1,22 @@
// Copyright © 2024 Apple Inc.
#pragma once

#include <string>
#include <tuple>
#include <vector>

#include "mlx/array.h"
#include "mlx/utils.h"

namespace mlx::core {

std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
    const std::string& subscripts,
    const std::vector<array>& operands);

array einsum(
    const std::string& subscripts,
    const std::vector<array>& operands,
    StreamOrDevice s = {});

} // namespace mlx::core
@@ -8,6 +8,7 @@
 #include "mlx/device.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"
+#include "mlx/einsum.h"
 #include "mlx/fast.h"
 #include "mlx/fft.h"
 #include "mlx/io.h"
@@ -444,8 +444,9 @@ array laplace(
   auto samples = uniform(low, high, shape, dtype, key, stream);
   // Use inverse CDF to generate Laplacian noise
   samples = multiply(
-      sign(samples),
-      log1p(multiply(array(-1.0f, dtype), abs(samples))),
+      sign(samples, stream),
+      log1p(
+          multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream),
       stream);

   if (scale != 1.0) {
@@ -12,6 +12,7 @@
 #include <nanobind/stl/variant.h>
 #include <nanobind/stl/vector.h>

+#include "mlx/einsum.h"
 #include "mlx/ops.h"
 #include "mlx/utils.h"
 #include "python/src/load.h"
@@ -40,15 +41,6 @@ double scalar_to_double(Scalar s) {
 }

 void init_ops(nb::module_& m) {
-  // TODO, remove deprecation errors in a future release
-  m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
-    throw std::invalid_argument(
-        "block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
-  });
-  m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
-    throw std::invalid_argument(
-        "block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
-  });
   m.def(
       "reshape",
       &reshape,
@@ -1238,7 +1230,8 @@ void init_ops(nb::module_& m) {
             a (array): Input array.

         Returns:
-            array: The unchanged input ``a`` but without gradient flowing
+            array:
+               The unchanged input ``a`` but without gradient flowing
                through it.
       )pbdoc");
   m.def(
@@ -2936,6 +2929,9 @@ void init_ops(nb::module_& m) {
             reverse (bool): Perform the cumulative sum in reverse.
+            inclusive (bool): The i-th element of the output includes the i-th
+              element of the input.
+
         Returns:
             array: The output array.
       )pbdoc");
   m.def(
       "cumprod",
@@ -2969,6 +2965,9 @@ void init_ops(nb::module_& m) {
             reverse (bool): Perform the cumulative product in reverse.
+            inclusive (bool): The i-th element of the output includes the i-th
+              element of the input.
+
         Returns:
             array: The output array.
       )pbdoc");
   m.def(
       "cummax",
@@ -3002,6 +3001,9 @@ void init_ops(nb::module_& m) {
             reverse (bool): Perform the cumulative maximum in reverse.
+            inclusive (bool): The i-th element of the output includes the i-th
+              element of the input.
+
         Returns:
             array: The output array.
       )pbdoc");
   m.def(
       "cummin",
@@ -3035,6 +3037,9 @@ void init_ops(nb::module_& m) {
             reverse (bool): Perform the cumulative minimum in reverse.
+            inclusive (bool): The i-th element of the output includes the i-th
+              element of the input.
+
         Returns:
             array: The output array.
       )pbdoc");
   m.def(
       "conj",
@@ -3052,6 +3057,9 @@ void init_ops(nb::module_& m) {

         Args:
             a (array): Input array
+
+        Returns:
+            array: The output array.
       )pbdoc");
   m.def(
       "conjugate",
@@ -3069,6 +3077,9 @@ void init_ops(nb::module_& m) {

         Args:
             a (array): Input array
+
+        Returns:
+            array: The output array.
       )pbdoc");
   m.def(
       "convolve",
@@ -3492,14 +3503,11 @@ void init_ops(nb::module_& m) {
         Args:
             file (file, str): File in which the array is saved.
             format (str, optional): Format of the file. If ``None``, the
-              format
-              is inferred from the file extension. Supported formats:
-              ``npy``,
-              ``npz``, and ``safetensors``. Default: ``None``.
+              format is inferred from the file extension. Supported formats:
+              ``npy``, ``npz``, and ``safetensors``. Default: ``None``.
             return_metadata (bool, optional): Load the metadata for formats
-              which
-              support metadata. The metadata will be returned as an
-              additional dictionary.
+              which support metadata. The metadata will be returned as an
+              additional dictionary. Default: ``False``.
         Returns:
             array or dict:
             A single array if loading from a ``.npy`` file or a dict
@@ -3551,9 +3559,9 @@ void init_ops(nb::module_& m) {
         Args:
             file (file, str): File in which the array is saved.
             arrays (dict(str, array)): The dictionary of names to arrays to
-              be saved. metadata (dict(str, Union[array, str, list(str)])):
-              The dictionary of
-              metadata to be saved. The values can be a scalar or 1D
+              be saved.
+            metadata (dict(str, Union[array, str, list(str)])): The dictionary
+              of metadata to be saved. The values can be a scalar or 1D
             obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`.
       )pbdoc");
   m.def(
@@ -3643,11 +3651,11 @@ void init_ops(nb::module_& m) {
             biases (array): The biases to use per ``group_size`` elements of ``w``
             transpose (bool, optional): Defines whether to multiply with the
               transposed ``w`` or not, namely whether we are performing
-              ``x @ w.T`` or ``x @ w``. (default: ``True``)
+              ``x @ w.T`` or ``x @ w``. Default: ``True``.
             group_size (int, optional): The size of the group in ``w`` that
-              shares a scale and bias. (default: ``64``)
+              shares a scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element in
-              ``w``. (default: ``4``)
+              ``w``. Default: ``4``.

         Returns:
             array: The result of the multiplication of ``x`` with ``w``.
@@ -3700,9 +3708,9 @@ void init_ops(nb::module_& m) {
         Args:
             w (array): Matrix to be quantized
             group_size (int, optional): The size of the group in ``w`` that shares a
-              scale and bias. (default: ``64``)
+              scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element of
-              ``w`` in the returned quantized matrix. (default: ``4``)
+              ``w`` in the returned quantized matrix. Default: ``4``.

         Returns:
             tuple: A tuple containing
@@ -3740,9 +3748,9 @@ void init_ops(nb::module_& m) {
             scales (array): The scales to use per ``group_size`` elements of ``w``
             biases (array): The biases to use per ``group_size`` elements of ``w``
             group_size (int, optional): The size of the group in ``w`` that shares a
-              scale and bias. (default: ``64``)
+              scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element in
-              ``w``. (default: ``4``)
+              ``w``. Default: ``4``.

         Returns:
             array: The dequantized version of ``w``
@@ -3779,15 +3787,15 @@ void init_ops(nb::module_& m) {
             w (array): Quantized matrix packed in unsigned integers
             scales (array): The scales to use per ``group_size`` elements of ``w``
             biases (array): The biases to use per ``group_size`` elements of ``w``
-            lhs_indices (array, optional): Integer indices for ``x`` (default: ``None``)
-            rhs_indices (array, optional): Integer indices for ``w`` (default: ``None``)
+            lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
+            rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
             transpose (bool, optional): Defines whether to multiply with the
               transposed ``w`` or not, namely whether we are performing
-              ``x @ w.T`` or ``x @ w``. (default: ``True``)
+              ``x @ w.T`` or ``x @ w``. Default: ``True``.
             group_size (int, optional): The size of the group in ``w`` that
-              shares a scale and bias. (default: ``64``)
+              shares a scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element in
-              ``w``. (default: ``4``)
+              ``w``. Default: ``4``.

         Returns:
             array: The result of the multiplication of ``x`` with ``w``
@@ -3827,7 +3835,7 @@ void init_ops(nb::module_& m) {
             sum over. If an integer is provided, then sum over the last
             ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
             ``b``. If a list of lists is provided, then sum over the
-            corresponding dimensions of ``a`` and ``b``. (default: 2)
+            corresponding dimensions of ``a`` and ``b``. Default: 2.

         Returns:
             array: The tensor dot product.
@@ -3958,11 +3966,13 @@ void init_ops(nb::module_& m) {
         Args:
             a (array): Input array or scalar.
             b (array): Input array or scalar.
-            block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
-            mask_out (array, optional): Mask for output (default: ``None``)
-            mask_lhs (array, optional): Mask for a (default: ``None``)
-            mask_rhs (array, optional): Mask for b (default: ``None``)
+            block_size (int): Size of blocks to be masked. Must be ``32`` or ``64``. Default: ``64``.
+            mask_out (array, optional): Mask for output. Default: ``None``.
+            mask_lhs (array, optional): Mask for ``a``. Default: ``None``.
+            mask_rhs (array, optional): Mask for ``b``. Default: ``None``.

         Returns:
             array: The output array.
       )pbdoc");
   m.def(
       "gather_mm",
@@ -3996,9 +4006,11 @@ void init_ops(nb::module_& m) {
         Args:
             a (array): Input array.
            b (array): Input array.
-            lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``)
-            rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``)
+            lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
+            rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``

         Returns:
             array: The output array.
       )pbdoc");
   m.def(
       "diagonal",
@@ -4406,4 +4418,57 @@ void init_ops(nb::module_& m) {
         Returns:
             array: The transformed array.
       )pbdoc");
+  m.def(
+      "einsum_path",
+      [](const std::string& equation, const nb::args& operands) {
+        auto arrays_list = nb::cast<std::vector<array>>(operands);
+        auto [path, str] = einsum_path(equation, arrays_list);
+        // Convert to list of tuples
+        std::vector<nb::tuple> tuple_path;
+        for (auto& p : path) {
+          tuple_path.push_back(nb::tuple(nb::cast(p)));
+        }
+        return std::make_pair(tuple_path, str);
+      },
+      "subscripts"_a,
+      "operands"_a,
+      nb::sig("def einsum_path(subscripts: str, *operands)"),
+      R"pbdoc(
+
+        Compute the contraction order for the given Einstein summation.
+
+        Args:
+            subscripts (str): The Einstein summation convention equation.
+            *operands (array): The input arrays.
+
+        Returns:
+            tuple(list(tuple(int, int)), str):
+              The einsum path and a string containing information about the
+              chosen path.
+      )pbdoc");
+  m.def(
+      "einsum",
+      [](const std::string& subscripts,
+         const nb::args& operands,
+         StreamOrDevice s) {
+        auto arrays_list = nb::cast<std::vector<array>>(operands);
+        return einsum(subscripts, arrays_list, s);
+      },
+      "subscripts"_a,
+      "operands"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def einsum(subscripts: str, *operands, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+
+        Perform the Einstein summation convention on the operands.
+
+        Args:
+            subscripts (str): The Einstein summation convention equation.
+            *operands (array): The input arrays.
+
+        Returns:
+            array: The output array.
+      )pbdoc");
 }
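
As the binding above shows, stream is keyword-only in the Python API. A minimal sketch of passing it (mx.cpu is a built-in device):

    import mlx.core as mx

    a = mx.ones((4, 4))
    out = mx.einsum("ij,jk->ik", a, a, stream=mx.cpu)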