Skip to content

Commit 4d720fa

Browse files
fabianpOptaxDev
authored and
OptaxDev
committed
Fix init_value and end_value in cosine decay
Problem: * the starting value was not always init_value * the last value was not always end_value Solution: * changed the formulas (they are now compatible with the pytorch implementation https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html) * added test to check that starting and end value coincide with `init_value` and `end_value` respectively Misc: renamed alpha -> end_value in warmup_cosine_decay_schedule PiperOrigin-RevId: 619109148
1 parent 5d8c7a3 commit 4d720fa

File tree

2 files changed

+38
-17
lines changed

2 files changed

+38
-17
lines changed

optax/schedules/_schedule.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,11 @@ def cosine_decay_schedule(
266266
267267
.. math::
268268
269-
\frac{I (1 - E)}{2}(1+\cos(\pi\,\frac{t}{T})^p) + E\,,
269+
\frac{(I - E)}{2}(1+\cos(\pi\,\frac{t}{T})^p) + E\,,
270270
271271
where :math:`T` is the number of decay steps (``decay_steps``), :math:`p` is
272272
the ``exponent``, :math:`I` is the initial value (``init_value``) and
273-
:math:`E` is the end value,.
273+
:math:`E` is the end value (``end_value``).
274274
275275
References:
276276
Loshchilov et al., `SGDR: Stochastic Gradient Descent with Warm Restarts
@@ -286,8 +286,8 @@ def cosine_decay_schedule(
286286
``t`` is the current timestep and ``T`` is the ``decay_steps``. The
287287
exponent modifies this to be ``(0.5 * (1 + cos(pi * t/T))) ** exponent``.
288288
Defaults to 1.0.
289-
alpha: The minimum value of the multiplier used to adjust the
290-
learning rate. Defaults to 0.0.
289+
alpha: Deprecated, use end_value instead. The minimum value of the
290+
multiplier used to adjust the learning rate. Defaults to 0.0.
291291
292292
Returns:
293293
schedule
@@ -316,8 +316,7 @@ def cosine_decay_schedule(
316316
def schedule(count):
317317
count = jnp.minimum(count, decay_steps)
318318
cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * count / decay_steps))
319-
decayed = (1 - end_value) * cosine_decay ** exponent + end_value
320-
return init_value * decayed
319+
return (init_value - end_value) * cosine_decay ** exponent + end_value
321320

322321
return schedule
323322

@@ -501,7 +500,6 @@ def warmup_cosine_decay_schedule(
501500
schedule
502501
A function that maps step counts to values
503502
"""
504-
alpha = 0. if peak_value == 0. else end_value / peak_value
505503
schedules = [
506504
linear_schedule(
507505
init_value=init_value,
@@ -511,7 +509,7 @@ def warmup_cosine_decay_schedule(
511509
cosine_decay_schedule(
512510
init_value=peak_value,
513511
decay_steps=decay_steps - warmup_steps,
514-
alpha=alpha,
512+
end_value=end_value,
515513
exponent=exponent,
516514
),
517515
]

optax/schedules/_schedule_test.py

+32-9
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,24 @@ def test_immutable_count(self):
300300

301301
class CosineDecayTest(chex.TestCase):
302302

303+
@chex.all_variants
304+
def test_init_value_end_value(self):
305+
"""Check cosine schedule decay for the entire training schedule."""
306+
initial_value = 1.5
307+
end_value = 0.2
308+
num_steps = 10
309+
schedule_fn = self.variant(
310+
_schedule.cosine_decay_schedule(initial_value, num_steps, end_value))
311+
# Test that generated values equal the expected schedule values.
312+
generated_vals = []
313+
for count in range(num_steps + 1):
314+
# Compute next value.
315+
generated_vals.append(schedule_fn(count))
316+
317+
# Test that the first and last values are correct.
318+
self.assertAlmostEqual(generated_vals[0], initial_value)
319+
self.assertAlmostEqual(generated_vals[-1], end_value)
320+
303321
@chex.all_variants
304322
def test_decay_count_smaller_count(self):
305323
"""Check cosine schedule decay for the entire training schedule."""
@@ -345,23 +363,28 @@ def test_decay_count_greater_count(self):
345363
def test_decay_count_greater_count_with_end_value(self):
346364
"""Check cosine schedule decay for a part of the training schedule."""
347365
# Get schedule function.
348-
initial_value = 0.1
366+
initial_value = 0.2
367+
end_value = 0.1
368+
num_steps = 5
349369
schedule_fn = self.variant(
350-
_schedule.cosine_decay_schedule(initial_value, 5, 0.1))
370+
_schedule.cosine_decay_schedule(initial_value, num_steps, end_value))
351371
# Test that generated values equal the expected schedule values.
352372
generated_vals = []
353-
for count in range(12):
373+
for count in range(2 * num_steps):
354374
# Compute next value.
355375
generated_vals.append(schedule_fn(count))
356376

357377
# Test output.
358-
expected_multipliers = np.array(
359-
0.5 + 0.5 * np.cos(
360-
np.pi * np.array(
361-
[0.0, 0.2, 0.4, 0.6, 0.8, 1., 1., 1., 1., 1., 1., 1.])))
362-
expected_multipliers = 0.9 * expected_multipliers + 0.1
378+
cos_values = 0.5 * (1 + np.cos(np.pi * np.linspace(0, 1, num_steps + 1)))
379+
expected_values = (
380+
(initial_value - end_value) * cos_values + end_value
381+
)
382+
# padd with [end_value] at the end.
383+
expected_values = np.concatenate(
384+
(expected_values, [end_value] * (num_steps - 1))
385+
)
363386
np.testing.assert_allclose(
364-
initial_value * expected_multipliers,
387+
expected_values,
365388
np.array(generated_vals), atol=1e-3)
366389

367390
def test_cosine_alpha_exception(self):

0 commit comments

Comments
 (0)