
Commit 7004f52

Fix optimizer docstring

1 parent 6ec0f46 · commit 7004f52

4 files changed: +72 -15 lines changed

keras/src/backend/jax/optimizer.py

+8 -7

@@ -1,17 +1,18 @@
+"""A class for JAX specific optimizer logic.
+
+Its purpose is to route around statelessness
+requirements in cond ops used for EMA handling
+and gradient accumulation handling. We do this
+by skipping conditionals entirely.
+"""
+
 import jax
 from jax import numpy as jnp
 
 from keras.src.optimizers import base_optimizer
 
 
 class JaxOptimizer(base_optimizer.BaseOptimizer):
-    """A class for JAX specific optimizer logic.
-
-    Its purpose is to route around statelessness
-    requirements in cond ops used for EMA handling
-    and gradient accumulation handling. We do this
-    by skipping conditionals entirely.
-    """
 
     def _backend_apply_gradients(self, grads, trainable_variables):
         if self.gradient_accumulation_steps:
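
The moved docstring is terse, so here is a minimal sketch of the trick it describes: `jax.lax.cond` requires both branches to return identically structured state, so one can instead compute both outcomes unconditionally and blend them with `jnp.where`. All names below are illustrative; this is not the `JaxOptimizer` source.

```python
import jax.numpy as jnp

def sgd_step_with_accumulation(step, grad, acc, var, lr, accum_steps):
    # Hypothetical sketch, assuming plain SGD with gradient accumulation.
    # Rather than branching with `jax.lax.cond`, compute both the
    # "accumulate" and "apply" outcomes and select element-wise.
    is_update_step = (step + 1) % accum_steps == 0
    new_acc = acc + grad
    applied_var = var - lr * (new_acc / accum_steps)
    # On update steps, apply the averaged gradient and reset the
    # accumulator; otherwise keep the variable and carry the sum forward.
    var = jnp.where(is_update_step, applied_var, var)
    acc = jnp.where(is_update_step, jnp.zeros_like(acc), new_acc)
    return var, acc
```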

keras/src/backend/tensorflow/optimizer.py

+9 -8

@@ -1,3 +1,12 @@
+"""A class for Tensorflow specific optimizer logic.
+
+The major behavior change for this class is for tf.distribute.
+
+It will override methods from base Keras core Optimizer,
+which provide distribute specific functionality, e.g. variable
+creation, loss reduction, etc.
+"""
+
 import warnings
 
 import tensorflow as tf
@@ -9,14 +18,6 @@
 
 
 class TFOptimizer(KerasAutoTrackable, base_optimizer.BaseOptimizer):
-    """A class for Tensorflow specific optimizer logic.
-
-    The major behavior change for this class is for tf.distribute.
-
-    It will override methods from base Keras core Optimizer,
-    which provide distribute specific functionality, e.g. variable
-    creation, loss reduction, etc.
-    """
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
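
To illustrate what "distribute specific functionality" means in practice, a hedged sketch follows. The class and method are invented for the example and are not the `TFOptimizer` source; only the `tf.distribute` calls are real API.

```python
import tensorflow as tf

class DistributeAwareSGD:
    """Hypothetical sketch of a distribute-aware update path."""

    def __init__(self, learning_rate=0.01):
        self.learning_rate = learning_rate

    def apply_gradients(self, grads_and_vars):
        # Route each assignment through the active distribution strategy
        # so mirrored copies of a variable stay in sync across devices.
        # Assumes this runs in a cross-replica context.
        strategy = tf.distribute.get_strategy()
        for grad, var in grads_and_vars:
            strategy.extended.update(
                var,
                lambda v, g: v.assign_sub(self.learning_rate * g),
                args=(grad,),
            )
```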

keras/src/optimizers/base_optimizer.py

+54

@@ -12,6 +12,60 @@
 
 
 class BaseOptimizer(KerasSaveable):
+    """Abstract optimizer base class.
+
+    If you intend to create your own optimization algorithm, please inherit
+    from this class and override the following methods:
+
+    - `build`: Create your optimizer-related variables, such as momentum
+        variables in the SGD optimizer.
+    - `update_step`: Implement your optimizer's variable updating logic.
+    - `get_config`: serialization of the optimizer.
+
+    Example:
+
+    ```python
+    class SGD(Optimizer):
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+            self.momentum = 0.9
+            self.nesterov = False  # referenced by `get_config` below
+
+        def build(self, variables):
+            super().build(variables)
+            self.momentums = []
+            for variable in variables:
+                self.momentums.append(
+                    self.add_variable_from_reference(
+                        reference_variable=variable, name="momentum"
+                    )
+                )
+
+        def update_step(self, gradient, variable, learning_rate):
+            learning_rate = ops.cast(learning_rate, variable.dtype)
+            gradient = ops.cast(gradient, variable.dtype)
+            m = self.momentums[self._get_variable_index(variable)]
+            self.assign(
+                m,
+                ops.subtract(
+                    ops.multiply(m, ops.cast(self.momentum, variable.dtype)),
+                    ops.multiply(gradient, learning_rate),
+                ),
+            )
+            self.assign_add(variable, m)
+
+        def get_config(self):
+            config = super().get_config()
+            config.update(
+                {
+                    "momentum": self.momentum,
+                    "nesterov": self.nesterov,
+                }
+            )
+            return config
+    ```
+    """
+
     def __init__(
         self,
         learning_rate,
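
For readers trying the new docstring example, a small usage sketch may help. It assumes the `SGD` subclass above is in scope, defined against `keras.optimizers.Optimizer`, with `from keras import ops` available.

```python
import keras
from keras import ops

# `apply_gradients` builds the optimizer state (the momentum
# variables) on first use, so no explicit `build` call is needed.
w = keras.Variable([1.0, 2.0], name="w")
grad = ops.convert_to_tensor([0.1, -0.2])

opt = SGD(learning_rate=0.01)
opt.apply_gradients([(grad, w)])

print(ops.convert_to_numpy(w))       # variable nudged by the momentum rule
print(opt.get_config()["momentum"])  # 0.9, round-trips via get_config
```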

keras/src/optimizers/optimizer.py

+1

@@ -23,4 +23,5 @@ class Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer):
     pass
 
 
+Optimizer.__doc__ = base_optimizer.BaseOptimizer.__doc__
 base_optimizer_keyword_args = base_optimizer.base_optimizer_keyword_args
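
A note on why the added line is needed: Python class docstrings are not inherited through subclassing, so without the explicit copy, `Optimizer.__doc__` would be `None` even though `BaseOptimizer` documents the shared behavior. A minimal demonstration:

```python
class Base:
    """Detailed base docstring."""

class Child(Base):
    pass

print(Child.__doc__)          # None: the attribute is not inherited
Child.__doc__ = Base.__doc__  # the same explicit copy the commit adds
print(Child.__doc__)          # 'Detailed base docstring.'
```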
