
MLX backend POC #1365


Open: wants to merge 68 commits into main

Conversation

@williambdean (Contributor) commented Apr 11, 2025

Description

Getting the ball rolling, starting with #1350

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1365.org.readthedocs.build/en/1365/

@williambdean williambdean marked this pull request as draft April 11, 2025 15:43
@williambdean williambdean marked this pull request as ready for review April 11, 2025 17:54
codecov bot commented Apr 11, 2025

Codecov Report

Attention: Patch coverage is 80.23256% with 17 lines in your changes missing coverage. Please review.

Project coverage is 82.04%. Comparing base (4e59f21) to head (d057453).
Report is 13 commits behind head on main.

Files with missing lines              Patch %   Lines
pytensor/link/mlx/linker.py           77.50%    6 Missing and 3 partials ⚠️
pytensor/link/mlx/dispatch/basic.py   80.64%    6 Missing ⚠️
pytensor/compile/mode.py              50.00%    1 Missing and 1 partial ⚠️

❌ Your patch check has failed because the patch coverage (80.23%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1365      +/-   ##
==========================================
+ Coverage   82.02%   82.04%   +0.01%     
==========================================
  Files         203      208       +5     
  Lines       48845    48949     +104     
  Branches     8691     8701      +10     
==========================================
+ Hits        40067    40162      +95     
- Misses       6627     6632       +5     
- Partials     2151     2155       +4     
Files with missing lines                  Coverage Δ
pytensor/link/mlx/__init__.py             100.00% <100.00%> (ø)
pytensor/link/mlx/dispatch/__init__.py    100.00% <100.00%> (ø)
pytensor/link/mlx/dispatch/math.py        100.00% <100.00%> (ø)
pytensor/link/pytorch/linker.py           100.00% <ø> (ø)
pytensor/compile/mode.py                  84.09% <50.00%> (-0.64%) ⬇️
pytensor/link/mlx/dispatch/basic.py       80.64% <80.64%> (ø)
pytensor/link/mlx/linker.py               77.50% <77.50%> (ø)

... and 8 files with indirect coverage changes


@ricardoV94 (Member)

I suggest basing this on the numba linker; torch has a lot of hacks we hopefully don't need here.

@williambdean (Contributor, Author) commented Apr 12, 2025

Thanks for the pointer. I simplified the one method. Do you think that gen_functors can be removed as well? The only commonality with pytorch then is that no input can be a NumPy array.

@ricardoV94 (Member) commented Apr 13, 2025

Yeah, you shouldn't need that; you just need a call to typify on the runtime inputs as well.



@mlx_typify.register(np.ndarray)
@mlx_typify.register(mx.array)
Member

mx.array should be registered in mlx_typify_no_conversion_needed


Done!
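For reference, a minimal sketch of the registration layout being discussed; functools.singledispatch is assumed, as in the other PyTensor backends:

from functools import singledispatch

import mlx.core as mx
import numpy as np


@singledispatch
def mlx_typify(data, **kwargs):
    raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}")


@mlx_typify.register(np.ndarray)
def mlx_typify_ndarray(data, **kwargs):
    # Runtime NumPy inputs get converted to MLX arrays once, up front
    return mx.array(data)


@mlx_typify.register(mx.array)
def mlx_typify_no_conversion_needed(data, **kwargs):
    # Already an MLX array: pass through unchanged
    return data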

Comment on lines +67 to +78
@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, **kwargs):
    warnings.warn(
        f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        return x

    return assert_fn
Member

Is this true, or just copy/pasta from JAX?


I need to check more here!

from pytensor.tensor.signal.conv import Conv1d


def blockwise_conv1d(op, node, **kwargs):
Member

Not needed anymore since they fixed it upstream, right?


I think we still need it. Where do you see that it's fixed? We are using this blockwise conv1d.

Member

Sure, but Blockwise will call vmap on the core op, so we only need to dispatch the core Conv1d to MLX; then the blockwise variant will work automatically.
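A sketch of what dispatching only the core op could look like (illustrative: it assumes Conv1d exposes a mode attribute with np.convolve semantics, and that the module paths match this PR's layout):

import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.signal.conv import Conv1d


@mlx_funcify.register(Conv1d)
def mlx_funcify_Conv1d(op, node=None, **kwargs):
    mode = op.mode  # assumption: "full" or "valid", as in np.convolve

    def conv1d(data, kernel):
        # Core 1d case only; Blockwise vmaps this over any batch dimensions
        return mx.convolve(data, kernel, mode=mode)

    return conv1d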

Comment on lines 37 to 39
# ------------------------------------------------------------------
# Join
# ------------------------------------------------------------------
Member

Don't introduce comments like this; the code is readable and won't get stale if we move things around.


Done!

# ------------------------------------------------------------------
# Join
# ------------------------------------------------------------------
@mlx_funcify.register(Join) # MLX
Member

The multiple #MLX comments are useless


Done!

Comment on lines +21 to +25
# Convert scalar to array if needed
if isinstance(x, int | float) or (
    isinstance(x, np.number) and not isinstance(x, np.ndarray)
):
    x = mx.array(x)
Member

Should not be needed

Comment on lines 35 to 62
@mlx_funcify.register(CAReduce)
def mlx_funcify_CAReduce(op, **kwargs):
    if isinstance(op.scalar_op, Add):

        def sum(x):
            return mx.sum(x, axis=op.axis)

        return sum
    elif isinstance(op.scalar_op, Mul):

        def prod(x):
            return mx.prod(x, axis=op.axis)

        return prod
    elif isinstance(op.scalar_op, AND):

        def all(x):
            return x.all(axis=op.axis)

        return all
    elif isinstance(op.scalar_op, OR):

        def any(x):
            return mx.any(x, axis=op.axis)

        return any
    else:
        raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
Member

This should do a second-level dispatch on the scalar op. Something like mlx_funcify_CAReduce that is called on op.scalar_op.
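A sketch of that pattern (names are illustrative, not the final implementation, and the mlx_funcify import path is assumed from this PR):

from functools import singledispatch

import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.scalar.basic import Add, Mul
from pytensor.tensor.elemwise import CAReduce


@singledispatch
def mlx_funcify_CAReduce_scalar_op(scalar_op, axis):
    raise NotImplementedError(f"MLX does not support CAReduce with {scalar_op}")


@mlx_funcify_CAReduce_scalar_op.register(Add)
def _(scalar_op, axis):
    return lambda x: mx.sum(x, axis=axis)


@mlx_funcify_CAReduce_scalar_op.register(Mul)
def _(scalar_op, axis):
    return lambda x: mx.prod(x, axis=axis)


@mlx_funcify.register(CAReduce)
def mlx_funcify_CAReduce(op, **kwargs):
    # Second-level dispatch on the reduction's scalar op
    return mlx_funcify_CAReduce_scalar_op(op.scalar_op, op.axis)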


Done!



@mlx_funcify.register(Elemwise)
def mlx_funcify_Elemwise(op, **kwargs):
Member

Like CAReduce, it should have a second-level dispatch. Also, we need to enforce the runtime broadcast checks (same in Alloc). And we should have a default implementation for that second-level dispatch that tries to use getattr(mx, "func_name"), similar to how JAX does it already.
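A sketch of the suggested structure (illustrative; the getattr fallback mirrors the JAX backend's default, relying on nfunc_spec using NumPy names, which mlx.core largely mirrors):

import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.scalar.basic import ScalarOp
from pytensor.tensor.elemwise import Elemwise


@mlx_funcify.register(ScalarOp)
def mlx_funcify_ScalarOp(op, **kwargs):
    # Default second-level dispatch: look the function up on mlx.core by name
    nfunc_spec = getattr(op, "nfunc_spec", None)
    if nfunc_spec is None or not hasattr(mx, nfunc_spec[0]):
        raise NotImplementedError(f"MLX does not support Elemwise {op}")
    return getattr(mx, nfunc_spec[0])


@mlx_funcify.register(Elemwise)
def mlx_funcify_Elemwise(op, node=None, **kwargs):
    # Second-level dispatch on the scalar op
    scalar_func = mlx_funcify(op.scalar_op, node=node, **kwargs)

    def elemwise(*inputs):
        # Runtime broadcast check (as in the JAX/PyTorch linkers) goes here
        return scalar_func(*inputs)

    return elemwise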


Done!

Comment on lines +81 to +84
if not op.inplace:
    x = deepcopy(x)
x[indices] = y
return x
Member

Need tests for all these, including inplace variants.


Done!
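A sketch of the kind of test requested above (compare_mlx_and_py is the helper the MLX test suite already imports; the test name and values are illustrative):

import numpy as np

import pytensor.tensor as pt
from tests.link.mlx.test_basic import compare_mlx_and_py


def test_set_subtensor():
    x = pt.vector("x")
    # Non-inplace variant; the inplace Op is introduced later by the rewriter
    out = pt.set_subtensor(x[:2], 0.0)
    compare_mlx_and_py([x], [out], [np.arange(4.0)])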

Comment on lines +61 to +69
for n in self.fgraph.inputs:
    sinput = storage_map[n]
    # Handle random number generators specially
    if isinstance(sinput[0], RandomState | Generator):
        new_value = mlx_typify(
            sinput[0], dtype=getattr(sinput[0], "dtype", None)
        )
        sinput[0] = new_value
    thunk_inputs.append(sinput)
Member

Since we don't have Random stuff yet, we shouldn't include the code.


Okay, added some!

Comment on lines +78 to +80
# 1) If it's a Conv1d Blockwise, use the custom implementation
if isinstance(op.core_op, Conv1d):
    return blockwise_conv1d(op, node, **kwargs)
Member

Here, we don't need this special casing anymore


Are you sure? Without it the test fails.

def mlx_funcify_Join(op, **kwargs):
    def join(axis, *tensors):
        view = op.view
        if (view != -1) and all(
Member

view was removed from PyTensor, ignore


Done!

Comment on lines 43 to 46
@mlx_typify.register(int)
@mlx_typify.register(float)
def mlx_typify_python_scalar(data, **kwargs):
    return mx.array(data)
Member

Put this close to the other mlx_typify registrations.


Done!

def elemwise(*inputs):
    # Enforce runtime broadcast checks (same as JAX and PyTorch implementations)
    if node is not None:
        # Convert inputs to MLX arrays for broadcast checking
Member

node is never None; the conversion to mx.array() shouldn't be needed either.


Done!

from tests.link.mlx.test_basic import compare_mlx_and_py


@pytest.mark.xfail(reason="Shape Op is not supported yet")
Member

Seems like it is?


Done!

return min_reduce


@mlx_funcify.register(CAReduce)
Member

Put this before the specific scalar_op dispatches?


Done!

return softmax_grad


@mlx_funcify.register(Softplus)
Member

Delete this? You have one in elemwise already


Done!

@williambdean (Contributor, Author)

Were we going to split this PR up into core functionality and op implementations? What are the next steps here?

@ricardoV94 (Member) commented Jun 3, 2025

This PR seems to be in an okay state; we can merge it as is once the comments are addressed.

@cetagostini

@ricardoV94 I think it's all applied.

@@ -177,7 +177,7 @@ def tensor_from_scalar(x):
@mlx_funcify.register(ScalarFromTensor)
def mlx_funcify_ScalarFromTensor(op, **kwargs):
    def scalar_from_tensor(x):
-       return x.reshape(-1)[0]
+       return mx.array(x).reshape(-1)[0]
Member

This seems convoluted. Does MLX have something like x.item()?


I need to check this again. One minute.


It has, but we have issues with the PyTensor-MLX way of compiling, and this makes it generic enough.


Optimized it to have both options, hope you like that more!
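For reference, a sketch of that "both options" approach (mlx.core arrays do provide .item()):

import mlx.core as mx


def scalar_from_tensor(x):
    if isinstance(x, mx.array):
        # 0-d MLX array: extract the scalar directly
        return x.item()
    # Anything else (e.g. a stray NumPy value) is converted first
    return mx.array(x).item()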

result = scalar_from_tensor_func(mlx_array)
assert result == 42

# Test with Python int (this used to fail)
Member

Why are python ints being passed? That suggests a bug elsewhere


Indeed, not a bug, but I hadn't replaced a few things. Done!

scalar_result = pt.scalar_from_tensor(x)

# Create function and test
f = pytensor.function([], scalar_result, mode="MLX")
Member

This won't test MLX at all, because it's a constant function and you are running full optimization, so it will just be constant folded. Instead, make x a symbolic variable, like x = pytensor.scalar.int64("x"), that is an input to the function.
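A sketch of the suggested test with a symbolic input (the exact constructor spelling is an assumption; a 0-d int64 tensor input behaves equivalently here):

import pytensor
import pytensor.tensor as pt


def test_scalar_from_tensor_symbolic():
    # A symbolic input prevents the graph from being constant-folded away
    x = pt.scalar("x", dtype="int64")
    scalar_result = pt.scalar_from_tensor(x)
    f = pytensor.function([x], scalar_result, mode="MLX")
    assert f(42) == 42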

@cetagostini commented Jun 9, 2025

Totally right, missed that 🙌🏻

@@ -11,3 +12,39 @@ def test_input(op) -> None:
    x_test = mx.array([1.0, 2.0, 3.0])

    compare_mlx_and_py([x], out, [x_test])


def test_new_elemwise_operations() -> None:
Member

bad test name, they won't be new by the time this PR is merged ;)


Done!

def test_mlx_Reshape_constant():
    a = vector("a")
    x = reshape(a, (2, 2))
    compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])


@pytest.mark.xfail(reason="Reshape Op is not supported yet")
def test_mlx_Reshape_various_shapes():
Member

Isn't this similar to test_mlx_Reshape_concrete below? Combine with that?


I mean, it doesn't seem appropriate in this case, as they're testing different scenarios of the reshape operation.

  1. test_mlx_Reshape_various_shapes focuses on testing different dimensional transformations with static/constant shapes.
  2. test_mlx_Reshape_concrete_shape focuses on testing computed/dynamic shapes, where the shape is derived from the input tensor's properties.

Maybe they can be renamed? But I feel they're two different things!

@cetagostini commented Jun 9, 2025

The current implementation allows sampling a simple pymc-marketing model on both GPU and CPU with the MLX backend. Nevertheless, complex models still have issues.



@cetagostini

Adding a small implementation to test MLX speed against JAX.

[Screenshot: benchmark results]
