Allow per-variable optimizer, add DispatchOptimizer. #21196

Open · cantonios wants to merge 1 commit into master

Conversation

cantonios
Contributor

  • Adds a property variable.optimizer that defaults to None
  • Adds a DispatchOptimizer that scans the list of trainable variables during build, collects all unique per-variable optimizers, then dispatches the apply/stateless_apply function to the correct optimizer if applicable.
  • Modifies the trainer so that, during the optimizer build stage, it checks whether any variables have a custom optimizer attached and, if so, inserts a DispatchOptimizer to handle them. This allows the mechanism to be hidden from the user.

Context: for large embedding tables, we need special optimizers to be used so that the tables can be updated in-place, rather than returning large gradients. The layer will handle setting of the custom optimizers, but we need the trainer to be aware of them and dispatch the embedding tables to different optimizers appropriately.
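
A rough sketch of the dispatch mechanism described above (illustrative only: the `default_optimizer` argument and the id-based grouping are simplifications, not the exact code in this PR, which also handles `stateless_apply`, variable tracking, and serialization):

```python
import keras


class DispatchOptimizer(keras.optimizers.Optimizer):
    """Routes each variable to its attached per-variable optimizer, if any."""

    def __init__(self, default_optimizer, **kwargs):
        super().__init__(learning_rate=0.0, **kwargs)
        self.default_optimizer = default_optimizer
        self._groups = {}  # id(optimizer) -> (optimizer, [variables])

    def build(self, variables):
        # Group variables by their (optional) per-variable optimizer.
        for v in variables:
            opt = getattr(v, "optimizer", None) or self.default_optimizer
            self._groups.setdefault(id(opt), (opt, []))[1].append(v)
        for opt, group in self._groups.values():
            opt.build(group)
        super().build(variables)

    def apply(self, grads, trainable_variables=None):
        # Bucket (gradient, variable) pairs by the optimizer that owns the
        # variable, then forward each bucket to that optimizer.
        var_to_opt = {
            id(v): opt for opt, group in self._groups.values() for v in group
        }
        buckets = {}
        for g, v in zip(grads, trainable_variables):
            opt = var_to_opt[id(v)]
            bucket = buckets.setdefault(id(opt), (opt, [], []))
            bucket[1].append(g)
            bucket[2].append(v)
        for opt, gs, vs in buckets.values():
            opt.apply(gs, vs)
```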

@codecov-commenter

codecov-commenter commented Apr 21, 2025

Codecov Report

Attention: Patch coverage is 75.26316% with 47 lines in your changes missing coverage. Please review.

Project coverage is 61.64%. Comparing base (8a6e83b) to head (c289f80).
Report is 1 commit behind head on master.

Files with missing lines                     Patch %   Lines
keras/src/optimizers/dispatch_optimizer.py   84.13%    17 Missing and 6 partials ⚠️
keras/src/backend/tensorflow/optimizer.py     0.00%    13 Missing ⚠️
keras/src/trainers/trainer.py                40.00%     7 Missing and 2 partials ⚠️
keras/src/optimizers/base_optimizer.py       50.00%     1 Missing and 1 partial ⚠️

❗ There is a different number of reports uploaded between BASE (8a6e83b) and HEAD (c289f80): HEAD has 6 fewer uploads than BASE.

Flag               BASE (8a6e83b)   HEAD (c289f80)
keras              5                2
keras-tensorflow   1                0
keras-torch        1                0
keras-jax          1                0
Additional details and impacted files
@@             Coverage Diff             @@
##           master   #21196       +/-   ##
===========================================
- Coverage   82.59%   61.64%   -20.95%     
===========================================
  Files         564      565        +1     
  Lines       54408    54592      +184     
  Branches     8449     8487       +38     
===========================================
- Hits        44937    33652    -11285     
- Misses       7396    18860    +11464     
- Partials     2075     2080        +5     
Flag               Coverage Δ
keras              61.63% <75.26%> (-20.77%) ⬇️
keras-jax          ?
keras-numpy        58.99% <75.26%> (+0.06%) ⬆️
keras-openvino     32.90% <20.00%> (-0.05%) ⬇️
keras-tensorflow   ?
keras-torch        ?

@fchollet
Collaborator

Thanks for the PR. Could we avoid the changes in trainer.py and instead limit the impact to the optimizer? The optimizer can check whether any of its variables have an overridden optimizer, for instance.

We could also require the user to pass a custom optimizer to compile() for distributed embeddings to work. Maybe that's better?

@cantonios
Contributor Author

Could we avoid the changes in trainer.py and instead limit the impact to the optimizer? The optimizer can check whether any of its variables have an overridden optimizer, for instance.

The optimizer can't replace itself, so we can't "insert" a DispatchOptimizer when it's needed from within the optimizer itself.

I was originally going to modify BaseOptimizer to do the dispatching. The main problem is that optimizers create their own weights per variable, like SGD creating momentum variables here. We don't want these to be created if the optimizer itself doesn't handle the variable - in the case of large embedding tables, these variables don't even fit on a single device. We can change this behavior, but it's a larger refactor that requires updating all existing optimizers, and might break any custom optimizers out in the wild. Is it worth pursuing this direction?
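
For reference, the pattern in question looks roughly like this (simplified from the SGD momentum logic referenced above; exact helper names may differ from the current Keras code):

```python
class SGDLikeSketch(keras.optimizers.Optimizer):
    # Simplified sketch of how existing optimizers allocate per-variable
    # slot state in build(). For a sharded embedding table, this extra
    # variable may not even fit on a single device.
    def build(self, variables):
        if self.built:
            return
        super().build(variables)
        self.momentums = [
            self.add_variable_from_reference(
                reference_variable=variable, name="momentum"
            )
            for variable in variables
        ]
```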

We could also require the user to pass a custom optimizer to compile() for distributed embeddings to work. Maybe that's better?

Yes, we could. It requires the user to know about it. The base optimizer could throw an error if it encounters a variable with an optimizer attached, telling the user to use a DispatchOptimizer instead. The changes to trainer were to try to do this automatically.

@fchollet
Collaborator

Understood. How about:

  • Add DistributedEmbeddingOptimizer which takes a base optimizer but is aware of distributed embedding variables?
  • In the base optimizer, add a line somewhere to error out if a distributed embedding variable is used with an optimizer that is not a DistributedEmbeddingOptimizer

If possible at all we should also consider just automating the above -- when a model gets compiled, we check if it has distributed embeddings, and if so we replace the optimizer with a DistributedEmbeddingOptimizer. What do you think?
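
One possible shape for that wrapper (purely hypothetical: `_is_distributed_embedding` and `apply_embedding_update` are made-up names used only to illustrate the split, and `build` and serialization are omitted):

```python
class DistributedEmbeddingOptimizer(keras.optimizers.Optimizer):
    """Hypothetical wrapper: ordinary variables go to `base_optimizer`,
    distributed-embedding variables take their own in-place update path."""

    def __init__(self, base_optimizer, **kwargs):
        super().__init__(learning_rate=0.0, **kwargs)
        self.base_optimizer = base_optimizer

    def apply(self, grads, trainable_variables=None):
        dense, special = [], []
        for g, v in zip(grads, trainable_variables):
            if getattr(v, "_is_distributed_embedding", False):
                special.append((g, v))
            else:
                dense.append((g, v))
        if dense:
            gs, vs = zip(*dense)
            self.base_optimizer.apply(list(gs), list(vs))
        for g, v in special:
            # The "gradient" here is really auxiliary update data that the
            # embedding machinery consumes in place.
            v.apply_embedding_update(g)  # hypothetical hook
```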

@cantonios
Contributor Author

cantonios commented Apr 23, 2025

How about:

  • Add DistributedEmbeddingOptimizer which takes a base optimizer but is aware of distributed embedding variables?
  • In the base optimizer, add a line somewhere to error out if a distributed embedding variable is used with an optimizer that is not a DistributedEmbeddingOptimizer

It's possible, but I wanted to keep it a bit more general than that since there are other contexts in which we need to treat specific variables differently. For example, overwrite_with_gradient looks like it was added for similar reasons: special handling of gradients/updates for specific kinds of variables. This could instead be handled with a per-variable "optimizer" rather than a special property.
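
For example, something with the effect of overwrite_with_gradient could in principle be written as a tiny per-variable optimizer instead of a dedicated flag (a minimal sketch that ignores loss scaling and gradient accumulation):

```python
class OverwriteWithGradient(keras.optimizers.Optimizer):
    """Sketch: treat the incoming "gradient" as the new value of the
    variable, and allocate no per-variable slot state."""

    def __init__(self, **kwargs):
        super().__init__(learning_rate=0.0, **kwargs)

    def update_step(self, gradient, variable, learning_rate):
        variable.assign(gradient)
```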

We also don't have a concept of "distributed embedding" in core Keras, only Keras RS.

If possible at all we should also consider just automating the above -- when a model gets compiled, we check if it has distributed embeddings, and if so we replace the optimizer with a DistributedEmbeddingOptimizer. What do you think?

This is exactly what I tried in this PR - except we can't detect "distributed embeddings" in compile, because we don't yet have access to the full set of trainable variables - we need to do it when we're symbolically building the model. We don't have a DistributedEmbeddingLayer in core Keras, so we can't detect based on the model's layers either - unless we add an attribute to the base layer to mark it as needing special handling.


Thinking out loud:

Maybe "optimizer" is the wrong word - what we really need is a method to treat "gradients" as auxiliary data rather than true gradients, and allow optimizers to dispatch the variable update to a custom "updater" that takes in that auxiliary data. We also don't want that auxiliary data modified in any way - i.e. avoid scaling by loss-scaling, learning rates. And we don't want optimizers to create their own extra optimizer-variables for these either.

Edit: though we may need the "iteration" (or step count) from the optimizer, since the "updater" may need it to update its internal state - unless we track that ourselves for every update call.
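
Roughly the kind of hook being described here (names are illustrative, not a proposed API):

```python
class VariableUpdater:
    """Hypothetical per-variable "updater": receives the raw auxiliary data
    produced in place of a gradient and performs the update itself."""

    def build(self, variable):
        # May allocate its own state; crucially, the main optimizer
        # allocates nothing for this variable.
        pass

    def update(self, variable, auxiliary_data, iteration):
        # `auxiliary_data` must arrive unmodified: no loss scaling, no
        # learning-rate scaling, no gradient accumulation.
        # `iteration` is passed through from the optimizer because the
        # updater may need the step count for its internal state.
        raise NotImplementedError
```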

@fchollet
Collaborator

This could instead be handled with a per-variable "optimizer" rather than a special property.

I like the generality; we can definitely add a generic variable attribute for this. Is "optimizer" the right abstraction, though?

Two issues I see:

  • It creates circular references (optimizers own variables and variables may own an optimizer)
  • It feels like a layering violation (optimizers are very high level, variables are low level)

@cantonios
Contributor Author

Two issues I see:

  • It creates circular references (optimizers own variables and variables may own an optimizer)
  • It feels like a layering violation (optimizers are very high level, variables are low level)

I agree. It's not really that the variable owns the optimizer - this is just the most convenient way I could think of to attach the information. There seemed to be precedent with the overwrite_with_gradient attribute.

It's the layer that knows that certain variables that it owns need special handling. In this case, the layer knows that these large embedding tables can't use a traditional optimizer. The layer needs some way to tell the model how to handle them, and specifically that the table variables need special optimizers. Adding the optimizer as an attribute on the variable was the communication mechanism for this.

The other way this could be done is with a map variable.path --> optimizer. We would still need a way for the layer to tell the model that we have these special requirements. I suppose there could be a

layer.variable_path_to_optimizer_map: dict[str, keras.optimizers.Optimizer]

that the model could query in order to build up a set of optimizers. If empty or None, it would use the traditional optimizer, otherwise it would use a DispatchOptimizer built from the combined map from all layers.
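
A sketch of how the model side could consume such a map (variable_path_to_optimizer_map is the hypothetical attribute proposed above; nested layers are ignored for brevity):

```python
def collect_variable_optimizers(model):
    # Merge the per-layer maps into one {variable.path: optimizer} dict.
    path_to_optimizer = {}
    for layer in model.layers:
        layer_map = getattr(layer, "variable_path_to_optimizer_map", None)
        if layer_map:
            path_to_optimizer.update(layer_map)
    return path_to_optimizer
```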

@mattdangerw
Member

It's the layer that knows that certain variables that it owns need special handling. In this case, the layer knows that these large embedding tables can't use a traditional optimizer. The layer needs some way to tell the model how to handle them, and specifically that the table variables need special optimizers.

Could we consider some design where this is handled at the custom-layer level? Go for a similar vibe to add_loss https://keras.io/api/losses/#the-addloss-api, where an optimizer is attached during build or init. Or give a custom layer the ability to define its own custom apply step somehow.

No idea if these are good ideas :). Just drive by thoughts.

@cantonios
Contributor Author

Could we consider some design where this is handled at the custom-layer level? Go for a similar vibe to add_loss https://keras.io/api/losses/#the-addloss-api, where an optimizer is attached during build or init. Or give a custom layer the ability to define its own custom apply step somehow.

We need to mark the variable as trainable so its gradient is computed, and we need that gradient for performing the update, but we want to skip everything else that the optimizers do, including allocating things like momentum or gradient accumulator variables.

There apparently used to be an add_update method that has since been deprecated. If we could make the weight non-trainable (so it is excluded by the trainer and hence the main optimizer), yet still mark it as needing a gradient computation and attach a custom variable updater to it, that would also work.

It's all a lot more complicated though than simply allowing us to specify a custom optimizer for specific variables, which essentially accomplishes the same thing.

@fchollet
Collaborator

IMO...

  • No changes to base trainer, and if possible base variable
  • A custom optimizer can handle the differentiated update based on the kind of variable we have (could be based on some attribute on the variable instance)
  • We can find some way to enforce that the custom optimizer is always used in conjunction with the distributed embedding layer (so that there can be no user error) -- it might be fine to special-case the base optimizer to check for a certain variable attribute (a sketch of such a check follows below)
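
A sketch of what that enforcement could look like inside the base optimizer's build step (`_requires_custom_optimizer` is a made-up marker attribute that the owning layer would set):

```python
# Inside BaseOptimizer.build(variables) -- sketch only:
for variable in variables:
    if getattr(variable, "_requires_custom_optimizer", False):
        raise ValueError(
            f"Variable {variable.path} requires a custom optimizer "
            "(e.g. one that is aware of distributed embeddings), but it "
            f"was passed to {self.__class__.__name__}. Pass an "
            "appropriate optimizer to `model.compile()`."
        )
```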
