
Adding functional CompositeLayer #21099


Open
wants to merge 21 commits into base: master

Conversation


@martin-gorner martin-gorner commented Mar 27, 2025

Introduction

CompositeLayer encapsulates a functional subgraph of layers. It is the equivalent of a Functional or Sequential model, but as a lighter-weight component, without model-specific functionality such as .fit().

Apart from offering a useful modeling component, one of the ultimate goals of this functionality is to allow programmatic edits on pre-trained models, like adding LoRA-like weights. The current implementation hard-codes LoRA deep inside the Dense layer. This design does not allow users to change the implementation, which is already significantly behind SOTA.

☆☆☆ New demo Colab ☆☆☆  (old colab here)

The Colab shows how to build an LLM that is functional all the way down, which means the layer graph is editable. It demonstrates two custom parameter-efficient fine-tuning techniques (LoRA and SVF). KV caches, text generation and training all work, and so does JIT compilation.

The code for LoRA:

  • LoRA wrapper layer with the math: 18 lines of code
  • Patching the model: 7 lines of code
  • Unpatching and folding LoRA weights back: 10 lines of code

For SVF (Singular Value Fine-tuning), an SVD-based variant of LoRA:

  • SVF wrapper layer with the math: 17 lines of code
  • Patching the model: 7 lines of code
  • Unpatching and folding SVF weights back: 12 lines of code

And all the code for doing this uses user-level APIs!
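For illustration, here is a minimal sketch of the wrapper approach using only public Keras 3 APIs. It is not the Colab's code: the LoraWrapper class, the rank default and the toy model are illustrative assumptions; only the clone_function pattern of keras.models.clone_model is standard.

import keras
from keras import layers, ops

class LoraWrapper(layers.Layer):
    # Hypothetical wrapper (not the Colab's exact code): wraps a pre-trained
    # Dense layer and adds a trainable low-rank update on top of it.
    def __init__(self, dense, rank=4, **kwargs):
        super().__init__(**kwargs)
        self.dense = dense              # frozen pre-trained layer
        self.dense.trainable = False
        self.rank = rank

    def build(self, input_shape):
        in_dim, out_dim = input_shape[-1], self.dense.units
        # Low-rank factors A (in_dim x rank) and B (rank x out_dim); B starts
        # at zero so the wrapper is a no-op before training.
        self.lora_a = self.add_weight(
            shape=(in_dim, self.rank), initializer="he_uniform", name="lora_a"
        )
        self.lora_b = self.add_weight(
            shape=(self.rank, out_dim), initializer="zeros", name="lora_b"
        )

    def call(self, x):
        # Frozen path plus the low-rank correction x @ A @ B.
        return self.dense(x) + ops.matmul(ops.matmul(x, self.lora_a), self.lora_b)

# A small stand-in model (the Colab patches a full LLM backbone instead).
inputs = keras.Input(shape=(16,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(8)(x)
model = keras.Model(inputs, outputs)

# Patching: clone_model's clone_function wraps every Dense layer while sharing
# the original (frozen) weights with the patched copy.
patched = keras.models.clone_model(
    model,
    clone_function=lambda layer: LoraWrapper(layer)
    if isinstance(layer, layers.Dense)
    else layer,
)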

API

A CompositeLayer can be created either from a list of layers or from a function that defines a graph of layers. There is no constructor similar to a functional Model(inputs, outputs) because inputs and outputs are usually not known when the layer is created. They will be created when the layer is built.

# Composite layer using a function
def layer_fn(x):
    x = layers.Dense(64, activation='relu')(x)
    outputs = layers.Dense(32)(x)
    return outputs

composite = layers.CompositeLayer(layer_fn)
# Composite layer from a list of layers
composite = layers.CompositeLayer([
    layers.Dense(64, activation='relu'),
    layers.Dense(32)
])
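
For example (an illustrative usage sketch, assuming the deferred-build behavior described above; the shapes are arbitrary):

import numpy as np
from keras import layers

composite = layers.CompositeLayer([
    layers.Dense(64, activation="relu"),
    layers.Dense(32),
])

# No inputs were declared up front; the internal Function is created on the
# first call, once the input shape is known.
y = composite(np.zeros((8, 16)))   # builds for input shape (None, 16)
print(y.shape)                     # (8, 32)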

Implementation notes

  1. Functional and CompositeLayer only depend on Function. There is no circular dependency between models and layers.

  2. The current implementation is an intermediate step to make reviewing (diffs) easier. It isolates 4 functions in functional.py that are used by both CompositeLayer and Functional:

    1. compute_input_spec
    2. run_through_graph_with_training_and_mask
    3. function_from_config
    4. serialize_functional_config

      With this approach, no changes are made to the Functional Model class hierarchy.

      The next step would be to move these 4 functions to CompositeLayer, then base Functional on CompositeLayer instead of Function. This will also allow the unification of Functional and Sequential models, since both will be based on CompositeLayer and have a Function once build() is called. Code explicitly testing for Functional or Sequential can then be removed throughout the code base and replaced with isinstance(obj, CompositeLayer) and obj.built.
  3. plot_model and clone_model functionality were adjusted to work with CompositeLayers

  4. Tests were added for the main edge cases, namely subclasses of CompositeLayer and Functional with the layer graph instantiated inside or outside of the class, in various nesting scenarios, tested for serialization/deserialization and for processing through clone_model.
    Three bug fixes in Functional and clone_model were needed for the tests to pass:

    1. Cleanup of return values in Functional. A list of one element was sometimes returned. Model.assert_input_compatible let this through but Function.assert_input_compatible, which is stricter, did not. Single tensors are now returned in this case.
    2. Cleanup of "is this a functional-like construct?" tests in Functional which were buggy because, surprisingly, inspect.getfullargspec(Functional.__init__) returns an empty list instead of the expected (inputs, outputs) signature (see this colab).
    3. In clone_model subclasses of Functional are typically cloned as vanilla Functional. The same conservative behavior was adopted for CompositeLayer. There was a carve-out however for subclasses of Functional with a "functional-like" constructor, again with a buggy test. This is a niche use case of the niche clone_model functionality so the carve-out was simply removed for simplicity.
  5. Passing a list of inputs to a model expecting a dictionary of inputs seems to be allowed, as long as flattening the dict does not result in reordering. There is an explicit reordering test in functional._standardize_inputs (look for "sort"). Changing this in Functional is not possible at this point but I would consider disallowing this in CompositeLayer. Tests covering this behavior are functional_test.test_list_input_with_dict_build and composite_layer_test.test_list_input_with_dict_build. (Point for discussion)

  6. In functional.py, in the serialization and deserialization functions serialize_functional_config and function_from_config, there is an explicit test for the type Functional which triggers a node_index adjustment. This was left untouched and not updated for CompositeLayer as I have not been able to find a situation where this code is useful. A test that triggers this condition was added in functions_test.py but it passes whether the conditional clauses are there or not.

  7. Optional inputs are supported although no new API was added for them. They can be declared by setting a manual input spec on the CompositeLayer:

# declare the first arg as optional
layer = CompositeLayer(...)
input_spec = [
    InputSpec(shape=(None, 2), optional=True),
    InputSpec(shape=(None, 2)),
]
layer.input_spec = input_spec
layer([None, value]) # this will now work

See composite_layer_test.py:test_optional_inputs for a working example.

Point for discussion: is this user-friendly enough or is a new API required?


codecov-commenter commented Mar 27, 2025

Codecov Report

Attention: Patch coverage is 80.83832% with 64 lines in your changes missing coverage. Please review.

Project coverage is 82.72%. Comparing base (6d26efb) to head (fbcbad5).

Files with missing lines Patch % Lines
keras/src/utils/model_visualization.py 0.00% 27 Missing ⚠️
keras/src/models/functional.py 85.61% 8 Missing and 13 partials ⚠️
keras/src/layers/core/composite_layer.py 90.40% 7 Missing and 5 partials ⚠️
keras/src/models/cloning.py 80.95% 2 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21099      +/-   ##
==========================================
+ Coverage   82.69%   82.72%   +0.02%     
==========================================
  Files         564      565       +1     
  Lines       54132    54285     +153     
  Branches     8411     8438      +27     
==========================================
+ Hits        44765    44907     +142     
- Misses       7294     7308      +14     
+ Partials     2073     2070       -3     
Flag Coverage Δ
keras 82.53% <80.83%> (+0.02%) ⬆️
keras-jax 64.15% <80.53%> (+0.07%) ⬆️
keras-numpy 59.23% <77.84%> (+0.13%) ⬆️
keras-openvino 33.72% <74.55%> (+0.84%) ⬆️
keras-tensorflow 64.43% <80.83%> (+0.06%) ⬆️
keras-torch 64.12% <80.53%> (+<0.01%) ⬆️


Comment on lines 212 to 220
# A subclassed Functional model is always cloned
# as a vanilla Functional model.
new_model = Functional(cloned_inputs, cloned_outputs,
                       name=model.name)
if model.compiled:
    compiled_config = model.get_compile_config()
    new_model.compile_from_config(compiled_config)
return new_model


@martin-gorner martin-gorner Mar 27, 2025


This piece of code was moved here from the end of _clone_functional_model when the function was renamed _clone_function_object and repurposed for all functional types.

The test for functional_like_constructor(model.__class__) was removed. See implementation note 4.iii.

Comment on lines 879 to 909
    # This test is permissive. Any argument combination that
    # could be a Functional init is allowed. This test will be
    # followed by an actual call of the Functional constructor
    # so the worst case is that args are not what they should
    # be and the constructor fails with an explicit error message.
    return (
        (len(args) >= 2)
        or (len(args) == 1 and "outputs" in kwargs)
        or ("inputs" in kwargs and "outputs" in kwargs)
    )


def functional_like_constructor(cls):
    # This test is permissive. Any constructor that could be passed
    # inputs and outputs is accepted. This test triggers Functional
    # deserialization when we know we have a functional config so
    # it's OK to try anything that could work.
    init_args = inspect.signature(cls.__init__).parameters
    funct_init_args = (
        ("inputs" in init_args and "outputs" in init_args)
        or ("args" in init_args or "kwargs" in init_args)
    )
    return funct_init_args


def strict_functional_like_constructor(cls):
    # This test is conservative. Only explicit "inputs" and "outputs"
    # arguments, with those names, are accepted. This test triggers
    # Functional serialization and we want to do that in a subclass only
    # when an explicitly functional __init__(inputs, outputs) constructor
    # exists in the subclass.
    init_args = inspect.signature(cls.__init__).parameters
    funct_init_args = ("inputs" in init_args and "outputs" in init_args)
    return funct_init_args


Cleanup of the "is functional-like" logic. See implementation note 4.ii

Comment on lines +224 to +225
def __init__(self, inputs, outputs, *args, param=1, **kwargs):
    super().__init__(inputs, outputs, *args, **kwargs)

The carve-out for functional serialization of subclasses of Functional that have a functional-like constructor was constrained to constructors with explicit inputs and outputs arguments, which are the only ones we can test for. See implementation note 4.ii

Comment on lines +254 to +258
# No way to detect that this can be serialized functionally
# since the graph could have been created inside the custom
# __init__ with the same __init__ args.
config = model.get_config()
self.assertFalse(has_functional_config_keys(config))

In all other cases of subclassing Functional, functional serialization is NOT triggered, since it is not possible to detect whether the layer graph is created outside of the class (in which case functional serialization would be useful) or inside __init__, in which case regular serialization, which calls __init__ at the end, is enough.


martin-gorner commented Apr 1, 2025

Here is a demo Colab showing how a user can patch a CompositeLayer LLM to enable LoRA:
https://colab.research.google.com/drive/1USgG4S9j3XAUqpUZlvhAbLWguV9Gjc28?usp=sharing


martin-gorner commented Apr 3, 2025

I added SVF to the demo Colab.

SVF (Singular Value Fine-tuning, explanations here) is an SVD-based variant of LoRA.

It's 36 lines of code in total, 17 for the SVF layer, 7 to patch it into a pretrained model, 12 lines to patch it out, all done with user-level APIs.

(and BTW, the code format failure is not mine; it's caused by an OpenVINO installation warning)


mattdangerw commented Apr 4, 2025

Hey @martin-gorner! Probably someone else should review this in more detail, but I have been doing some thinking on the UX (and only the UX so far). Still mulling, for now just some questions/thoughts...

Is there a performance issue with using a keras.Model as a layer that never touches the training APIs? Or is it just about simplicity? Basically do the training/saving APIs that are untouched ever get in the way?

I know that just using today's symbols won't cover all of what you are adding here--in particular building a functional from an unknown shape. And we should figure that out. But adding a Sequential as a sub-component to a larger model is already a pattern in a lot of Keras code. Are we saying that's bad with this PR or not really?

Also is there a reason we are pushing sequential and functional style layer construction onto the same symbol? Sequential is separate in our modeling APIs, why is it fused onto a single class in this PR? Seems mildly inconsistent.

# A reusable composite layer
class MyCompositeLayer(CompositeLayer):
    @staticmethod
    def my_layer_fn(inputs):
        x = layers.Dense(5)(inputs)
        return layers.Dense(4)(x)

    def __init__(self, **kwargs):
        super().__init__(MyCompositeLayer.my_layer_fn, **kwargs)

This feels off for a couple reasons. For one, recommending a staticmethod-annotated class function is a lot of cognitive load for something we'd ideally want to be as simple as possible.

And it's unclear how this would work for a reusable functional component with a lot of config. If you were writing a block layer with 5-10 arguments related to feature sizes, activations, residuals, etc., how would you pass that information to the functional building method?

Lastly, I wonder if we'd want more predictability in method names for subclasses of MyCompositeLayer. The nice thing about subclassed layers today with __init__/build/call methods is you can subclass, chain to super as needed, and either augment or overwrite one of these fixed methods and not the others. We lack that property here.

@gbaned gbaned added this to PR Queue Apr 7, 2025
@github-project-automation github-project-automation bot moved this to Assigned Reviewer in PR Queue Apr 7, 2025

martin-gorner commented Apr 7, 2025

Hi Matt,

"Are we saying that[using a Model as a composite layer]'s bad with this PR or not really?"

One downside is that a functional model must have pre-defined inputs, whereas a Keras Layer is allowed to have its input shape deferred until build() is called. This makes using layers much more practical. CompositeLayer retains this property.

But apart from that, there is nothing wrong with using a Model as a composite layer today. Initially, it just did not feel very clean, so I set out to try what a proper CompositeLayer implementation would look like. In the process, I discovered that having a CompositeLayer could also clarify the internal class hierarchy. This is the class hierarchy I am suggesting down the line:

Functional(CompositeLayer, Trainer)
CompositeLayer(Layer) has a Function
Sequential being just syntactic sugar above Functional

Note: this PR does not implement this class hierarchy. It does not change the Functional class hierarchy at all, to make diffs and reviews easier, but once CompositeLayer is accepted, rebasing Functional and Sequential on CompositeLayer is easy.

This refactoring could unify Functional and Sequential in a seamless way. The artificial separation between these two classes is something power users building frameworks on top of Keras have complained about in the past. There is also ugly code in a number of places (for example cloning.py) that specifically tests for Sequential and offers an inferior implementation of graph cloning in that case (no recursion).

is there a reason we are pushing sequential and functional style layer construction onto the same symbol

Initially I just had CompositeLayer(layer_fn), but then extending it to a sequence of layers seemed so straightforward that I added that. This then led to what I think is an architecturally sound way of unifying Functional and Sequential down the line (my remarks above). So I kept it.

reusable functional component with a lot of config

I don't have preconceived ideas about how to package CompositeLayers as a custom class. I'm currently trying to write a full LLM in this style to see what works best. Would you like to suggest something?

@martin-gorner

In trying to write a full LLM functionally, so as to make LoRA and its modern variants easily implementable by the user, the biggest challenge right now is that it is not possible to have a layer that accepts optional inputs AND is defined in a functional way as a graph of layers (whether the container is Functional or CompositeLayer does not matter). This is different from a graph of layers that contains custom layers with optional inputs, which is what currently exists in unit tests.

This kind of construct is not possible:

    x, cache = inputs
    if cache is not None:
        outputs = MyCacheLookupLayer()(x, cache)
    else:
        outputs = Dense(n)(x)
    # then make a functional model/layer from (inputs, outputs)

This means that the LLM must implement two separate backbones, one with caches, one without, which introduces complexity. Ideally, the LLM would have a single backbone implementation, that could be compiled with or without caches. I don't see how to do this yet, with an implementation that is purely functional all the way to the Dense layers, so that they can easily be swapped/wrapped for their LoRA-ified versions. Any ideas?


mattdangerw commented Apr 10, 2025

One downside is that a functional model must have pre-defined inputs, whereas a Keras Layer is allowed to have its input shape deferred until build() is called.

Yeah this seems like a cool thing to solve. But you wouldn't necessarily need to solve it via a layer/model symbol split I think.

Sequential being just syntactic sugar above Functional

I think this has always been desired actually. If we could make Sequential be a subclass of Functional today we would, but IIRC there might have been some old but common uses of Sequential that made it hard to do that? But if we can make that change in a backwards compat friendly way I think this is a standalone piece of work we could grab any time.

This means that the LLM must implement two separate backbones, one with caches, one without, which introduces complexity.

Yeah agreed working with caches is a big pain point. I haven't even been worried about "functional all the way down," I think there's enough issues with optional inputs today for functional that would make it not fully work as a solution for cached inputs. (but also, being able to write this functional all the way down would be cool too!)

Any ideas?

Ignoring sequential and whether we'd want a separate layer-not-model solution for now (I think they are largely orthogonal), here's an idea.

class MLP(keras.Functional):
    def __init__(self, inner_units):
        super().__init__()
        self.inner_units = inner_units

    def functional_build(self, inputs):
        features = inputs.shape[-1]
        x = layers.Dense(self.inner_units, activation="relu")(inputs)
        x = layers.Dense(features)(x)  # project back to the input width
        return x + inputs  # Residual.

    def get_config(self):
        config = super().get_config()
        config.update({"inner_units": self.inner_units})
        return config

inputs = keras.Input(shape=(8,))
symbolic_outputs = MLP(32)(inputs)
real_outputs = MLP(32)(np.ones((16, 4)))

functional_build would be called on symbolic inputs only when the layer is called. We'd still support passing inputs and outputs to __init__ for backwards compat. You could also call self.functional_build(inputs) in __init__ if you want a subclassed functional with fixed inputs that builds on construction. I guess we should only support a single positional argument when calling (with arbitrary tree structure) to keep in line with current functional behavior.

Nice things -- this code looks really similar to current UX with subclassed layers. It also separates config from call graph in different methods, which makes it easier to subclass your subclass of functional (a real problem for KerasHub).

It does not support writing "build" logic that is dependent on input shape without making a subclass, which is a key difference from your design I think. That's consistent with the keras.Layer UX perhaps, but debatable.


martin-gorner commented Apr 10, 2025

Here you go, LLM, functional all the way down (colab):
CompositeLayer Transformer, Functional end-to-end with text gen, fine-tuning and graph rewiring for custom PEFT (public)

This is a working example. The backbone is entirely Functional, which means that the layer graph is editable, and the Colab shows two custom LoRA-like techniques. KV caches and text generation work, as well as training. JIT compilation too.

WDYT?

Note: I did not think much about class-packaging the layers in this example. It's mostly just function calling. Your functional_build suggestion above is interesting. Let me think through it with this real-world example.

@mattdangerw

WDYT?

Still thinking and trying to organize my thoughts.

In the meantime, I've been playing with many little prototypes to try and make this simpler from the end user side. Code is probably not complete, but that's not really the point. I'm trying to find a good UX.

https://colab.research.google.com/gist/mattdangerw/3ea95c6d1fc9fa8c6cad34ff3b3c457e/functional-layer-demo.ipynb

I'm not fully sure what I think of it yet, but food for thought.


martin-gorner commented Apr 13, 2025

I managed to get rid of the double backbone, by using optional inputs: functional LLM all the way down with a single backbone.

The trick was to put all the if cache is not None statements in custom layers and do the rest functionally. Luckily, it's easy to do in the Transformer architecture with just three small custom layers: AttentionMask, AttentionDotProduct and KVCacheLookup.
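
For example, the cache lookup might look something like the sketch below (illustrative only, not the Colab's code; of the names, only KVCacheLookup comes from the list above):

from keras import layers, ops

class KVCacheLookup(layers.Layer):
    # Keeps the "cache is None?" branching inside a custom layer so that the
    # surrounding graph can stay purely functional.
    def call(self, new_kv, cache=None):
        if cache is None:
            # No cache (e.g. training): use the freshly computed K/V only.
            return new_kv
        # Generation with a cache: append the new K/V along the sequence axis.
        return ops.concatenate([cache, new_kv], axis=1)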

Thank you for your colab with UX ideas! - I'll send comments tomorrow, it's getting a bit late.


martin-gorner commented Apr 14, 2025

Food for thought regarding functional layer packaging with config etc.: I added backbone saving to the Colab. It is done using plain model.save and models.load_model, since nothing else is required (for the backbone), and it restores the backbone with all of its config like embedding_dim, nb_heads, etc. Maybe this could be a recipe for publishing community pre-trained models (backbones are loadable with just load_model, no need to also publish a code library)?

So regarding your experiments, I wonder if packaging a functional layer with config attributes like filters or residual is really necessary since the same attributes will also be saved in the functional graph. Isn't the custom class doing more harm (config duplication) than good here?

I'm not saying I have a definitive recommendation though. People will want to give a class name to their backbone implementation etc.

@martin-gorner

Adding some implementation notes and thoughts:

The API that clones the layer graph, called clone_model, needs to handle functional sub-layers or sub-models (instances of Functional or CompositeLayer) in a well-defined way. The layer graph inside them gets copied, which is not problematic, but what should happen to the Functional or CompositeLayer container class?

  • The currently implemented behavior is to re-create a fresh Functional or CompositeLayer object when running clone_model (see the sketch after this list). This is deterministic, but may be surprising: subclasses of Functional or CompositeLayer will be turned into vanilla Functional or CompositeLayer objects, losing their class name and any attributes or methods the subclass had.
  • Another possible behavior is to clone the Functional or CompositeLayer objects. Doing this through the Keras traditional get_config/from_config mechanism is problematic since users would most probably implement from_config as something that recreates the functional layer graph. Here, we already know how to copy the layer graph so calling a function that will overwrite this copy is not helpful. This would have to be some other form of cloning.
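
To make the first option concrete, here is a hypothetical sketch (MyBlock and the wiring are illustrative, and the behavior shown is the one described above for this PR):

import keras
from keras import layers

class MyBlock(layers.CompositeLayer):
    def __init__(self, **kwargs):
        super().__init__([layers.Dense(8, activation="relu"), layers.Dense(4)], **kwargs)

inputs = keras.Input(shape=(4,))
outputs = MyBlock(name="my_block")(inputs)
model = keras.Model(inputs, outputs)

clone = keras.models.clone_model(model)
# The layer graph inside the block is copied, but the container comes back as
# a vanilla CompositeLayer rather than a MyBlock.
print(type(clone.layers[-1]).__name__)   # "CompositeLayer"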


mattdangerw commented Apr 17, 2025

The currently implemented behavior is to re-create a fresh Functional or CompositeLayer object when running clone_model. This is deterministic, but may be surprising. Subclasses of Functional or CompositeLayer will be turned into vanilla Functional or CompositeLayer objects, losing their class name and any attributes or methods the subclass had.

Agreed, this is bad. For KerasHub this is to the point of being something we would never use, I think. There is a lot of stuff attached to model and layer class implementations today. A very non-exhaustive list -- stable diffusion generation subroutines that are not part of the graph, conv net pyramid output helpers, preset and lora saving/loading routines (conceptually, loading parts of a saved model), vae helpers, helpers to get certain layer subcomponents without needing the layer name/index. I'm sure some of this stuff could be pulled into the graph, and some could be gotten rid of. But IMO, at the end of the day, attaching functionality to classes is a very normal thing to do that our developers will expect, and any workflow that strips classes back to base classes will be very limited.

Another possible behavior is to clone the Functional or CompositeLayer objects. Doing this through the Keras traditional get_config/from_config mechanism is problematic since users would most probably implement from_config as something that recreates the functional layer graph. Here, we already know how to copy the layer graph so calling a function that will overwrite this copy is not helpful. This would have to be some other form of cloning.

This is the meat of the experiments I was running. Going through get_config() and from_config() is a good thing even if there's redundancy with the function representation we could make -- it makes these subclasses behave like any other subclass. We would want this for saving. I don't think it's the primary reason, but there are some in-the-weeds cases here too -- for KerasHub LLMs it's quite important to be able to load a model saved at full precision at half precision without some form of cast after load.

I'm of the opinion that having from_config() hit the functional layer graph recreation is a good thing. It gives users the same mental model for restoring as any other subclass, and an ability to handle legacy arg names, etc. The only exception is the cloning API. If we want this "clone as rewire" approach (though I'm still not sure the experience we are showing here is the right one), we need a way to recreate the object without recreating the graph. Here is one of the many benefits of having a fixed build method with a reliable name: we can skip it, and attach the cloned graph instead. This is a little gross -- we need to assume the functional_build has no side effects besides graph creation. But I don't see a great alternative. The most compelling alternative I see is to rethink the whole workflow. Let cloning be actual cloning, have it hit the same serialize/deserialize code paths as saving, and explore a separate API for mutating the graph and/or replacing layers in place.

Projects
Status: Assigned Reviewer
5 participants