MLX backend POC #1365
Conversation
Codecov Report
❌ Attention: your patch check has failed because the patch coverage (80.23%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
@@           Coverage Diff           @@
## main #1365 +/- ##
==========================================
+ Coverage 82.02% 82.04% +0.01%
==========================================
Files 203 208 +5
Lines 48845 48949 +104
Branches 8691 8701 +10
==========================================
+ Hits 40067 40162 +95
- Misses 6627 6632 +5
- Partials 2151 2155 +4
I suggest basing yourself on the numba linker; torch has a lot of hacks we hopefully don't need here.
Thanks for the pointer. I simplified the one method. Do you think that …
Yeah, you shouldn't need that; you just need a call to typify on the runtime inputs as well.
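A rough sketch of what that could look like in the linker, assuming the mlx_typify dispatcher from this PR; the wrapper itself and its name are illustrative, not the actual linker code:

def typify_runtime_inputs(compiled_fn, mlx_typify):
    # Convert raw runtime inputs (numpy arrays, Python scalars, ...) with
    # mlx_typify before calling the compiled function, instead of
    # special-casing input types elsewhere.
    def wrapper(*raw_inputs):
        converted = [mlx_typify(inp) for inp in raw_inputs]
        return compiled_fn(*converted)

    return wrapper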
pytensor/link/mlx/dispatch/basic.py
Outdated
@mlx_typify.register(np.ndarray)
@mlx_typify.register(mx.array)
mx.array should be registered in mlx_typify_no_conversion_needed
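For context, a minimal sketch of the suggested registration, using the mlx_typify and mlx_typify_no_conversion_needed names from this thread (not necessarily the exact code that landed in the PR):

from functools import singledispatch

import mlx.core as mx
import numpy as np


@singledispatch
def mlx_typify(data, **kwargs):
    raise NotImplementedError(f"mlx_typify not implemented for {type(data)}")


@mlx_typify.register(np.ndarray)
def mlx_typify_ndarray(data, **kwargs):
    # numpy arrays need an actual conversion to mx.array
    return mx.array(data)


@mlx_typify.register(mx.array)
def mlx_typify_no_conversion_needed(data, **kwargs):
    # mx.array inputs can be passed through untouched
    return data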
Done!
@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, **kwargs):
    warnings.warn(
        f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        return x

    return assert_fn
Is this true, or just copy/pasta from JAX?
I need to check more here!
from pytensor.tensor.signal.conv import Conv1d


def blockwise_conv1d(op, node, **kwargs):
Not needed anymore since they fixed upstream right?
I think we still need it. Where do you see that it's fixed? We are using this blockwise conv1d.
Sure but blockwise will call vmap on the core op, so we only need to dispatch core Conv1D to MLX Conv1D, then the blockwise variant will work automatically
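A rough sketch of that idea, assuming mlx_funcify lives in pytensor.link.mlx.dispatch.basic as in this PR and that a single mx.vmap over the leading batch axis is enough; this is illustrative, not the merged implementation:

import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.blockwise import Blockwise


@mlx_funcify.register(Blockwise)
def mlx_funcify_Blockwise(op, node=None, **kwargs):
    # Dispatch only the core op (e.g. Conv1d) to its MLX implementation ...
    core_fn = mlx_funcify(op.core_op, **kwargs)

    def blockwise_fn(*inputs):
        # ... and let vmap handle the batch dimension that Blockwise adds
        batched_fn = mx.vmap(core_fn, in_axes=0, out_axes=0)
        return batched_fn(*inputs)

    return blockwise_fn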
pytensor/link/mlx/dispatch/core.py
Outdated
# ------------------------------------------------------------------
# Join
# ------------------------------------------------------------------
Don't introduce comments like this, the code is readable and won't get stale if we move things around
Done!
pytensor/link/mlx/dispatch/core.py
Outdated
# ------------------------------------------------------------------
# Join
# ------------------------------------------------------------------
@mlx_funcify.register(Join)  # MLX
The multiple #MLX comments are useless
Done!
# Convert scalar to array if needed
if isinstance(x, int | float) or (
    isinstance(x, np.number) and not isinstance(x, np.ndarray)
):
    x = mx.array(x)
Should not be needed
@mlx_funcify.register(CAReduce)
def mlx_funcify_CAReduce(op, **kwargs):
    if isinstance(op.scalar_op, Add):

        def sum(x):
            return mx.sum(x, axis=op.axis)

        return sum
    elif isinstance(op.scalar_op, Mul):

        def prod(x):
            return mx.prod(x, axis=op.axis)

        return prod
    elif isinstance(op.scalar_op, AND):

        def all(x):
            return x.all(axis=op.axis)

        return all
    elif isinstance(op.scalar_op, OR):

        def any(x):
            return mx.any(x, axis=op.axis)

        return any
    else:
        raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
This should do a second-level dispatch on the scalar op, something like mlx_funcify_CAReduce that is called on op.scalar_op.
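A minimal sketch of that second-level dispatch; mlx_funcify_CAReduce_scalar is an illustrative name, only Add and Mul are shown, and the imports assume the dispatcher module from this PR:

from functools import singledispatch

import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.scalar.basic import Add, Mul
from pytensor.tensor.elemwise import CAReduce


@singledispatch
def mlx_funcify_CAReduce_scalar(scalar_op, axis):
    raise NotImplementedError(f"MLX does not support CAReduce with {scalar_op}")


@mlx_funcify_CAReduce_scalar.register(Add)
def _(scalar_op, axis):
    return lambda x: mx.sum(x, axis=axis)


@mlx_funcify_CAReduce_scalar.register(Mul)
def _(scalar_op, axis):
    return lambda x: mx.prod(x, axis=axis)


@mlx_funcify.register(CAReduce)
def mlx_funcify_CAReduce(op, **kwargs):
    # Second-level dispatch keyed on the scalar op instead of an if/elif chain
    return mlx_funcify_CAReduce_scalar(op.scalar_op, op.axis)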
Done!
pytensor/link/mlx/dispatch/math.py
Outdated
@mlx_funcify.register(Elemwise)
def mlx_funcify_Elemwise(op, **kwargs):
Like CAReduce, it should have a second-level dispatch. Also we need to enforce the runtime broadcastable checks (same in Alloc). And we should have a default implementation for that second-level dispatch that tries to use getattr(mlx, "func_name"), similar to how JAX does it already.
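A hedged sketch of such a default: a scalar-op-level dispatcher (the mlx_funcify_scalar_op name is hypothetical) whose fallback looks the function up on mlx.core by name, similar in spirit to the JAX backend:

from functools import singledispatch

import mlx.core as mx

from pytensor.scalar.basic import ScalarOp


@singledispatch
def mlx_funcify_scalar_op(scalar_op: ScalarOp):
    # Default: map the scalar op to an mlx.core function with the same name,
    # e.g. pytensor's Exp scalar op -> mx.exp.
    nfunc_spec = getattr(scalar_op, "nfunc_spec", None)
    name = nfunc_spec[0] if nfunc_spec else type(scalar_op).__name__.lower()
    mlx_func = getattr(mx, name, None)
    if mlx_func is None:
        raise NotImplementedError(f"No MLX equivalent found for {scalar_op}")
    return mlx_func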
Done!
if not op.inplace:
    x = deepcopy(x)
x[indices] = y
return x
Need tests for all these including inplace variants
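For example, a sketch of what one such (non-inplace) test could look like, reusing the compare_mlx_and_py helper from tests/link/mlx/test_basic.py; the test name and values are illustrative:

import numpy as np
import pytensor.tensor as pt

from tests.link.mlx.test_basic import compare_mlx_and_py


def test_set_subtensor():
    x = pt.vector("x")
    # Set the first two entries to zero and compare MLX against the Python backend
    out = pt.set_subtensor(x[:2], 0.0)
    compare_mlx_and_py([x], [out], [np.arange(4.0)])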
Done!
for n in self.fgraph.inputs:
    sinput = storage_map[n]
    # Handle random number generators specially
    if isinstance(sinput[0], RandomState | Generator):
        new_value = mlx_typify(
            sinput[0], dtype=getattr(sinput[0], "dtype", None)
        )
        sinput[0] = new_value
    thunk_inputs.append(sinput)
Since we don't have Random stuff yet we shouldn't include the code
Okay added some!
# 1) If it's a Conv1d Blockwise, use the custom implementation
if isinstance(op.core_op, Conv1d):
    return blockwise_conv1d(op, node, **kwargs)
Here, we don't need this special casing anymore
Are you sure? Without it the test fails.
pytensor/link/mlx/dispatch/core.py
Outdated
def mlx_funcify_Join(op, **kwargs):
    def join(axis, *tensors):
        view = op.view
        if (view != -1) and all(
view was removed from PyTensor, ignore
Done!
pytensor/link/mlx/dispatch/math.py
Outdated
@mlx_typify.register(int)
@mlx_typify.register(float)
def mlx_typify_python_scalar(data, **kwargs):
    return mx.array(data)
Put this close to the other mlx_typify registrations.
Done!
pytensor/link/mlx/dispatch/math.py
Outdated
def elemwise(*inputs):
    # Enforce runtime broadcast checks (same as JAX and PyTorch implementations)
    if node is not None:
        # Convert inputs to MLX arrays for broadcast checking
node is never None, and the conversion to mx.array() shouldn't be needed either.
Done!
tests/link/mlx/test_shape.py
Outdated
from tests.link.mlx.test_basic import compare_mlx_and_py


@pytest.mark.xfail(reason="Shape Op is not supported yet")
Seems like it is?
Done!
    return min_reduce


@mlx_funcify.register(CAReduce)
Put this before the specific scalar_op dispatches?
Done!
    return softmax_grad


@mlx_funcify.register(Softplus)
Delete this? You have one in elemwise already
Done!
Were we going to split this PR up into core functionality and op implementations? What are the next steps here?
This PR seems to be in an okay state; we can merge it as is once the comments are addressed.
@ricardoV94 I think it's all applied.
pytensor/link/mlx/dispatch/core.py
Outdated
@@ -177,7 +177,7 @@ def tensor_from_scalar(x):
 @mlx_funcify.register(ScalarFromTensor)
 def mlx_funcify_ScalarFromTensor(op, **kwargs):
     def scalar_from_tensor(x):
-        return x.reshape(-1)[0]
+        return mx.array(x).reshape(-1)[0]
This seems convoluted. Does MLX have something like x.item()?
I need to check this again. One minute.
It has, but we have issues with the way pytensor-mlx compiles, and this makes it generic enough.
Optimized it to have both options, hope you like that more!
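If "both options" means preferring .item() and falling back to the generic reshape-and-index path, a sketch could look like this (an assumption about the intent, not necessarily the committed code):

import mlx.core as mx


def scalar_from_tensor(x):
    # mx.array exposes .item(); use it when we actually get an array,
    # otherwise fall back to converting and indexing.
    if isinstance(x, mx.array):
        return x.item()
    return mx.array(x).reshape(-1)[0]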
tests/link/mlx/test_basic.py
Outdated
result = scalar_from_tensor_func(mlx_array)
assert result == 42

# Test with Python int (this used to fail)
Why are python ints being passed? That suggests a bug elsewhere
Indeed, not a bug; I just hadn't replaced a few things. Done!
tests/link/mlx/test_basic.py
Outdated
scalar_result = pt.scalar_from_tensor(x)

# Create function and test
f = pytensor.function([], scalar_result, mode="MLX")
This won't test MLX at all because it's a constant function and you are running full optimization, so it will just be constant folded. Instead make x a symbolic variable like x = pytensor.scalar.int64(x) that is an input to the function.
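A sketch of the suggested test with a symbolic input, so nothing is constant-folded away (compare_mlx_and_py is the helper from tests/link/mlx/test_basic.py; the exact test is illustrative):

import numpy as np
import pytensor.tensor as pt

from tests.link.mlx.test_basic import compare_mlx_and_py


def test_scalar_from_tensor_symbolic():
    # x is a function input, so ScalarFromTensor survives graph optimization
    x = pt.vector("x", dtype="int64")
    out = pt.scalar_from_tensor(x[0])
    compare_mlx_and_py([x], [out], [np.arange(3, dtype="int64")])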
Totally right, missed that 🙌🏻
tests/link/mlx/test_elemwise.py
Outdated
@@ -11,3 +12,39 @@ def test_input(op) -> None:
     x_test = mx.array([1.0, 2.0, 3.0])

     compare_mlx_and_py([x], out, [x_test])


 def test_new_elemwise_operations() -> None:
bad test name, they won't be new by the time this PR is merged ;)
Done!
def test_mlx_Reshape_constant():
    a = vector("a")
    x = reshape(a, (2, 2))
    compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])


@pytest.mark.xfail(reason="Reshape Op is not supported yet")
def test_mlx_Reshape_various_shapes():
Isn't this similar to test_mlx_Reshape_concrete below? Combine with that?
I mean, it doesn't seem appropriate in this case, as they're testing different scenarios of the reshape operation.
test_mlx_Reshape_various_shapes focuses on different dimensional transformations with static/constant shapes. test_mlx_Reshape_concrete_shape focuses on computed/dynamic shapes, where the shape is derived from the input tensor's properties.
Maybe they can be renamed? But I feel they're two different things!
Description
Getting the ball rolling on #1350.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1365.org.readthedocs.build/en/1365/