
Fix np bfloat16 misinterpreted as complex #3146

Open
kellen-sun wants to merge 11 commits into ml-explore:main from kellen-sun:fix-bfloat16-misinterpreted

Conversation

@kellen-sun (Contributor) commented Feb 19, 2026

Proposed changes

Fixes #1075
Bug: The bug occurs when converting np.array(1., dtype=ml_dtypes.bfloat16) or np.array([1.], dtype=ml_dtypes.bfloat16) to mx.array(x). In the scalar case, the value is silently matched as a std::complex alternative of ArrayInitType and converted as such (see the linked issue for code). In the array case, it is interpreted as an ArrayLike, the conversion to mx.array() fails, and a ValueError is raised.
The Fix: We need to catch this case before it is filtered through ArrayInitType. I made array.__init__ accept a more generic argument so this case can be intercepted, checked the dtype against bfloat16, and constructed the array manually. Otherwise, we fall back to the original ArrayInitType path.
Note: bfloat16 is currently the only ml_dtypes type that MLX supports.
Verification: Verified locally (macOS 26.2, MLX 0.30.7, Apple M2) with the additional test case I provided. When run from main, it exhibits the bug described above:

    @unittest.skipIf(not has_ml_dtypes, "requires ml_dtypes")
    def test_conversion_ml_dtypes(self):
        x_scalar = np.array(1.5, dtype=ml_dtypes.bfloat16)
        a_scalar = mx.array(x_scalar)
>       self.assertEqual(a_scalar.dtype, mx.bfloat16)
E       AssertionError: mlx.core.complex64 != mlx.core.bfloat16
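For context (not part of the PR), the reason a raw reinterpret of the buffer works in the fix below is that bfloat16 is simply the top 16 bits of an IEEE float32. A minimal NumPy-only sketch of that bit-level relationship (the helper name is made up for illustration):

```python
import numpy as np

def bfloat16_bits_to_float32(u16):
    # bfloat16 keeps the sign bit, the full 8-bit exponent, and the top
    # 7 mantissa bits of a float32, so shifting the raw 16 bits left by
    # 16 and reinterpreting them as float32 recovers the value exactly.
    u32 = u16.astype(np.uint32) << 16
    return u32.view(np.float32)

bits = np.array([0x3FC0], dtype=np.uint16)  # bit pattern of 1.5 in bfloat16
print(bfloat16_bits_to_float32(bits))  # [1.5]
```

This is why copying the ml_dtypes buffer as mx::bfloat16_t values, as the fix does, preserves the data without any floating-point conversion.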

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Comment on lines 300 to 326
if (nb::hasattr(v, "dtype")) {
  if (nb::str(v.attr("dtype")).equal(nb::str("bfloat16"))) {
    auto type_mod = nb::str(v.attr("__class__").attr("__module__"));
    if (type_mod.equal(nb::str("numpy")) ||
        type_mod.equal(nb::str("ml_dtypes"))) {
      auto np = nb::module_::import_("numpy");
      auto contig_obj = np.attr("ascontiguousarray")(v);
      mx::Shape shape;
      nb::tuple shape_tuple = nb::cast<nb::tuple>(v.attr("shape"));
      size_t ndim = shape_tuple.size();
      for (size_t i = 0; i < ndim; ++i) {
        shape.push_back(nb::cast<int>(shape_tuple[i]));
      }
      uint64_t ptr_int =
          nb::cast<uint64_t>(contig_obj.attr("ctypes").attr("data"));
      const mx::bfloat16_t* typed_ptr =
          reinterpret_cast<const mx::bfloat16_t*>(ptr_int);
      auto res = (ndim == 0)
          ? mx::array(*typed_ptr, mx::bfloat16)
          : mx::array(typed_ptr, shape, mx::bfloat16);
      if (t.has_value())
        res = mx::astype(res, *t);
      new (aptr) mx::array(res);
      return;
    }
  }
}
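The guard at the top of this snippet can be paraphrased in Python to make the control flow easier to follow (this is an illustrative sketch, not code from the PR; the helper name is made up):

```python
import numpy as np

def looks_like_bfloat16(v):
    # Mirror of the C++ check above: accept only objects that expose a
    # dtype whose string form is "bfloat16" and whose class comes from
    # the numpy or ml_dtypes modules.
    if not hasattr(v, "dtype"):
        return False
    if str(v.dtype) != "bfloat16":
        return False
    return type(v).__module__ in ("numpy", "ml_dtypes")

print(looks_like_bfloat16(np.array(1.0, dtype=np.float32)))  # False
```

Matching on the dtype's string name avoids importing ml_dtypes eagerly; only objects that pass this check take the manual-construction path.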
Member
This should really be done in the create_array function. That way other conversions from numpy bfloat16 will work.

@kellen-sun (Contributor, Author) commented Feb 19, 2026
We can probably move the code to create_array, but the fix requires capturing a general nb::object from __init__ because of how bfloat16 falls into the complex case, so we'd also need to change the signature of create_array and how it is called. I've done that, and added an additional test case (the other valid use of create_array with bfloat16).

Here's the output of that case if run from main:

>      a_asarray = mx.asarray(x_vector)
                    ^^^^^^^^^^^^^^^^^^^^
E       ValueError: Invalid type ndarray received in array initialization.

python/tests/test_bf16.py:221: ValueError


Development

Successfully merging this pull request may close these issues.

[BUG] np.ndarray of bfloat16 using ml_dtypes is being interpreted as complex64
