
Commit 15564c8: [Algorithm] Expert Iteration and SFT
Parent: 92b52a0


46 files changed (+4209, -132 lines)

docs/source/reference/llms.rst
Lines changed: 25 additions & 0 deletions

@@ -200,6 +200,7 @@ transforms).

     DataLoadingPrimer
     KLRewardTransform
+    RetrieveLogProb
     MCPToolTransform
     BrowserTransform
     PythonInterpreter
@@ -256,6 +257,9 @@ LLM post training require some appropriate versions of the losses implemented in
 GRPO
 ~~~~

+The :class:`~torchrl.objectives.llm.GRPOLoss` class is a thin wrapper around the :class:`~torchrl.objectives.PPOLoss` class
+that implements the LLM-specific functionalities.
+
 .. currentmodule:: torchrl.objectives.llm

 .. autosummary::
@@ -265,3 +269,24 @@ GRPO
     GRPOLoss
     GRPOLossOutput
     MCAdvantage
+
+
+SFT
+~~~
+
+.. currentmodule:: torchrl.objectives.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    SFTLoss
+    SFTLossOutput
+
+.. currentmodule:: torchrl.data.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    TopKRewardSelector
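
To make the GRPO note above concrete: GRPO keeps the clipped-ratio PPO objective and mainly changes how advantages are obtained, normalising each completion's reward against the other completions sampled for the same prompt instead of relying on a learned critic, which is why a GRPO loss can stay a thin wrapper around a PPO loss. The snippet below is a minimal, illustrative sketch of that advantage step in plain PyTorch; the helper name `group_relative_advantage` is made up for this example and is not part of the TorchRL API.

```python
# Illustrative sketch only, not TorchRL's implementation: GRPO-style
# group-relative advantages, computed per prompt over its sampled completions.
import torch


def group_relative_advantage(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Standardise rewards within each group of completions for one prompt.

    rewards: shape (num_prompts, num_completions_per_prompt)
    returns: advantages with zero mean and (roughly) unit std inside each group.
    """
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)


# Example: 2 prompts, 4 sampled completions each.
rewards = torch.tensor([[0.0, 1.0, 0.5, 0.5],
                        [2.0, 2.0, 0.0, 4.0]])
adv = group_relative_advantage(rewards)  # per-group standardised advantages
```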

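The new SFT entries are what the "Expert Iteration" part of this commit builds on: TopKRewardSelector (which, going by its name, selects the highest-reward completions per prompt) feeds a supervised, token-level cross-entropy objective of the kind SFTLoss documents. The sketch below illustrates that idea in plain PyTorch under those assumptions; `select_top_k` and `masked_sft_loss` are hypothetical helpers written for this example, not the TorchRL classes.

```python
# Illustrative sketch only, not the TorchRL API: expert-iteration-style SFT,
# assuming it means "keep the k highest-reward completions per prompt, then
# minimise a cross-entropy restricted to the tokens we want to train on".
import torch
import torch.nn.functional as F


def select_top_k(rewards: torch.Tensor, k: int) -> torch.Tensor:
    """Indices of the k highest-reward completions in one group (hypothetical helper)."""
    return torch.topk(rewards, k=min(k, rewards.numel())).indices


def masked_sft_loss(
    logits: torch.Tensor,      # (batch, seq_len, vocab) model outputs
    target_ids: torch.Tensor,  # (batch, seq_len) next-token targets
    train_mask: torch.Tensor,  # (batch, seq_len) True where the loss applies
) -> torch.Tensor:
    """Token-level cross-entropy averaged over the masked (e.g. assistant) tokens."""
    vocab = logits.size(-1)
    per_token = F.cross_entropy(
        logits.reshape(-1, vocab), target_ids.reshape(-1), reduction="none"
    ).reshape(target_ids.shape)
    mask = train_mask.float()
    return (per_token * mask).sum() / mask.sum().clamp_min(1.0)


# Toy usage: 4 completions sampled for one prompt, keep the 2 best, train on them.
torch.manual_seed(0)
rewards = torch.tensor([0.1, 0.9, 0.4, 0.7])      # one scalar reward per completion
keep = select_top_k(rewards, k=2)                  # -> tensor([1, 3])
logits = torch.randn(keep.numel(), 8, 32)          # stand-in for model logits
targets = torch.randint(0, 32, (keep.numel(), 8))  # stand-in for token ids
mask = torch.ones_like(targets, dtype=torch.bool)  # here: train on every token
loss = masked_sft_loss(logits, targets, mask)
```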
sota-implementations/cql/online_config.yaml
Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ optim:
   critic_lr: 3e-4
   weight_decay: 0.0
   batch_size: 256
-  optim_steps_per_batch: 200
+  optim_dialog_turns_per_batch: 200

 # Policy and model
 model:

sota-implementations/dreamer/config.yaml
Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ optimization:
   value_lr: 8e-5
   kl_scale: 1.0
   free_nats: 3.0
-  optim_steps_per_batch: 80
+  optim_dialog_turns_per_batch: 80
   gamma: 0.99
   lmbda: 0.95
   imagination_horizon: 15

sota-implementations/dreamer/dreamer.py
Lines changed: 2 additions & 2 deletions

@@ -137,7 +137,7 @@ def main(cfg: DictConfig): # noqa: F821
     scaler3 = GradScaler()

     init_random_frames = cfg.collector.init_random_frames
-    optim_steps_per_batch = cfg.optimization.optim_steps_per_batch
+    optim_dialog_turns_per_batch = cfg.optimization.optim_dialog_turns_per_batch
     grad_clip = cfg.optimization.grad_clip
     eval_iter = cfg.logger.eval_iter
     eval_rollout_steps = cfg.logger.eval_rollout_steps
@@ -179,7 +179,7 @@ def compile_rssms(module):
         t_loss_actor = 0.0
         t_loss_critic = 0.0
         t_loss_model = 0.0
-        for _ in range(optim_steps_per_batch):
+        for _ in range(optim_dialog_turns_per_batch):
             # sample from replay buffer
             t_sample_init = time.time()
             sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length)
