fix creating array from bf16 tensors in jax / torch (#1305)

This commit is contained in:
Awni Hannun 2024-08-01 16:20:51 -07:00 committed by GitHub
parent 6c8dd307eb
commit 10b5835501
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 10 deletions

View file

@@ -24,15 +24,6 @@ struct ndarray_traits<float16_t> {
static constexpr bool is_signed = true;
};
// Type-trait specialization telling nanobind's ndarray machinery how to
// classify bfloat16_t: a signed floating-point scalar (not complex, bool,
// or integer). NOTE(review): this diff removes these definitions from the
// project — presumably nanobind now ships its own bfloat16 support
// (see `nb::bfloat16` used below) — confirm against the nanobind version
// pinned by the build.
template <>
struct ndarray_traits<bfloat16_t> {
  static constexpr bool is_complex = false;
  static constexpr bool is_float = true;
  static constexpr bool is_bool = false;
  static constexpr bool is_int = false;
  static constexpr bool is_signed = true;
};
// DLPack dtype descriptor for bfloat16: code 4 (kDLBfloat), 16 bits, 1 lane.
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind
@@ -88,7 +79,7 @@ array nd_array_to_mlx(
} else if (type == nb::dtype<float16_t>()) {
return nd_array_to_mlx_contiguous<float16_t>(
nd_array, shape, dtype.value_or(float16));
} else if (type == nb::dtype<bfloat16_t>()) {
} else if (type == nb::bfloat16) {
return nd_array_to_mlx_contiguous<bfloat16_t>(
nd_array, shape, dtype.value_or(bfloat16));
} else if (type == nb::dtype<float>()) {

View file

@@ -183,6 +183,14 @@ class TestBF16(mlx_tests.MLXTestCase):
]:
test_blas(shape_x, shape_y)
@unittest.skipIf(not has_torch, "requires PyTorch")
def test_conversion(self):
    """Creating an mx.array from a bfloat16 torch tensor preserves
    both the dtype and the element values (regression test for the
    bf16 conversion fix in this commit)."""
    a_torch = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)
    # Conversion goes through the buffer/DLPack protocol handled in
    # nd_array_to_mlx; previously bf16 inputs were not recognized.
    a_mx = mx.array(a_torch)
    expected = mx.array([1.0, 2.0, 3.0], mx.bfloat16)
    self.assertEqual(a_mx.dtype, mx.bfloat16)
    self.assertTrue(mx.array_equal(a_mx, expected))
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()