docs/source/reducing_memory_usage.md (30 changes: 24 additions & 6 deletions)
@@ -102,30 +102,48 @@ For more information, see [Liger Kernel Integration](liger_kernel_integration).
To use Liger for reducing peak memory usage, use the following code snippet:

<hfoptions id="liger">
<hfoption id="SFT">

```python
from trl import SFTConfig

training_args = SFTConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="DPO">

```python
from trl import DPOConfig

-training_args = DPOConfig(..., use_liger_loss=True)
+training_args = DPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="GRPO">

```python
from trl import GRPOConfig

-training_args = GRPOConfig(..., use_liger_loss=True)
+training_args = GRPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="KTO">

```python
from trl import KTOConfig

-training_args = KTOConfig(..., use_liger_loss=True)
+training_args = KTOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="GKD">

```python
from trl import GKDConfig

training_args = GKDConfig(..., use_liger_kernel=True)
```

</hfoption>
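For reference, a minimal sketch of what this change means for existing configs, based on the deprecation shim added to `__post_init__` further down (`output_dir` is an illustrative placeholder):

```python
import warnings

from trl import DPOConfig

# Old spelling: still accepted for now, but emits a FutureWarning and is
# remapped to `use_liger_kernel` (removal is planned for version 0.28.0).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    args = DPOConfig(output_dir="out", use_liger_loss=True)
assert args.use_liger_kernel is True
assert any(issubclass(w.category, FutureWarning) for w in caught)

# New spelling: no warning.
args = DPOConfig(output_dir="out", use_liger_kernel=True)
```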
tests/slow/test_grpo_slow.py (8 changes: 4 additions & 4 deletions)
@@ -67,12 +67,12 @@ def teardown_method(self):

@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
@require_liger_kernel
-def test_training_with_liger_grpo_loss(self, model_name):
+def test_training_with_liger_grpo_kernel(self, model_name):
training_args = GRPOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=3,
num_generations=3,
-use_liger_loss=True,
+use_liger_kernel=True,
max_completion_length=self.max_length,
report_to="none",
logging_strategy="no",
@@ -108,14 +108,14 @@ def test_training_with_liger_grpo_loss(self, model_name):
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
@require_liger_kernel
@require_peft
-def test_training_with_liger_grpo_loss_and_peft(self, model_name):
+def test_training_with_liger_grpo_kernel_and_peft(self, model_name):
from peft import LoraConfig, TaskType

training_args = GRPOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=3,
num_generations=3,
-use_liger_loss=True,
+use_liger_kernel=True,
max_completion_length=self.max_length,
report_to="none",
logging_strategy="no",
tests/test_dpo_trainer.py (4 changes: 2 additions & 2 deletions)
@@ -250,7 +250,7 @@ def test_train_encoder_decoder_liger(self):
per_device_train_batch_size=2,
learning_rate=9e-1,
report_to="none",
-use_liger_loss=True,
+use_liger_kernel=True,
)
trainer = DPOTrainer(
model=model,
@@ -1330,7 +1330,7 @@ def test_dpo_trainer_with_liger(self, beta, loss_type):
learning_rate=9e-1,
eval_strategy="steps",
beta=beta,
-use_liger_loss=True, # Enable Liger loss
+use_liger_kernel=True, # Enable Liger kernel
loss_type=loss_type,
report_to="none",
)
tests/test_gkd_trainer.py (13 changes: 8 additions & 5 deletions)
@@ -14,6 +14,7 @@

import os

+import pytest
import torch
import torch.nn.functional as F
from datasets import load_dataset
@@ -29,9 +30,10 @@ class TestGKDTrainerGenerateOnPolicy(TrlTestCase):
@classmethod
def setup_class(cls):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+cls.device = "cuda" if torch.cuda.is_available() else "cpu"
cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
cls.tokenizer.pad_token = cls.tokenizer.eos_token
-cls.model = AutoModelForCausalLM.from_pretrained(model_id)
+cls.model = AutoModelForCausalLM.from_pretrained(model_id).to(cls.device)
cls.generation_config = GenerationConfig(
max_new_tokens=20,
num_return_sequences=1,
@@ -44,8 +46,8 @@ def test_generate_on_policy_outputs_deterministic(self):
tokenized_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True)

inputs = {
-"prompts": tokenized_prompts["input_ids"],
-"prompt_attention_mask": tokenized_prompts["attention_mask"],
+"prompts": tokenized_prompts["input_ids"].to(self.device),
+"prompt_attention_mask": tokenized_prompts["attention_mask"].to(self.device),
}

# Set temperature to 0 for deterministic output
@@ -91,8 +93,8 @@ def test_generate_on_policy_outputs(self):
tokenized_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True)

inputs = {
-"prompts": tokenized_prompts["input_ids"],
-"attention_mask": tokenized_prompts["attention_mask"],
+"prompts": tokenized_prompts["input_ids"].to(self.device),
+"attention_mask": tokenized_prompts["attention_mask"].to(self.device),
}

outputs = GKDTrainer.generate_on_policy_outputs(
@@ -238,6 +240,7 @@ def test_gkd_trainer(self):
assert "model.safetensors" in os.listdir(self.tmp_dir + "/checkpoint-2")

@require_liger_kernel
+@pytest.mark.xfail(reason="Computing the Liger loss spikes GPU memory usage, causing the test to run OOM.")
def test_gkd_trainer_with_liger(self):
training_args = GKDConfig(
output_dir=self.tmp_dir,
tests/test_grpo_trainer.py (2 changes: 1 addition & 1 deletion)
@@ -1460,7 +1460,7 @@ def reward_func(completions, **kwargs):
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
-use_liger_loss=True, # enable Liger loss
+use_liger_kernel=True, # enable Liger kernel
loss_type="bnpo", # default dapo is not supported yet
report_to="none",
)
tests/test_kto_trainer.py (4 changes: 2 additions & 2 deletions)
@@ -361,11 +361,11 @@ def test_kto_lora_save(self):

@require_liger_kernel
def test_kto_trainer_with_liger(self):
-"""Test KTO trainer with Liger loss enabled."""
+"""Test KTO trainer with Liger kernel enabled."""
training_args = KTOConfig(
output_dir=self.tmp_dir,
report_to="none",
-use_liger_loss=True, # Enable Liger loss
+use_liger_kernel=True, # Enable Liger kernel
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")
trl/trainer/dpo_config.py (24 changes: 20 additions & 4 deletions)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import warnings
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Optional, Union
@@ -156,11 +157,17 @@ class DPOConfig(TrainingArguments):
[MPO](https://huggingface.co/papers/2411.10442)). The `loss_weights` parameter can be used to specify
corresponding weights for each loss type.

-use_liger_loss (`bool`, *optional*, defaults to `False`):
+use_liger_loss (`bool`, *optional*, defaults to `None`):
Whether to use Liger loss.

+<Deprecated version="0.25.0">
+
+Parameter `use_liger_loss` is deprecated and will be removed in version 0.28.0. Use `use_liger_kernel` instead.
+
+</Deprecated>
base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
Name of the attribute in the model that contains the base model. This is used to get the base model from
-the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`.
+the model when the model does not have a `get_decoder` method in the case when `use_liger_kernel` is `True`.
beta (`float`, *optional*, defaults to `0.1`):
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
@@ -378,15 +385,15 @@ class DPOConfig(TrainingArguments):
},
)
use_liger_loss: bool = field(
-default=False,
+default=None,
metadata={"help": "Whether to use Liger loss."},
)
base_model_attribute_name: str = field(
default="model",
metadata={
"help": "Name of the attribute in the model that contains the base model. This is used to get the base "
"model from the model when the model does not have a `get_decoder` method in the case when "
-"`use_liger_loss` is `True`."
+"`use_liger_kernel` is `True`."
},
)
beta: float = field(
@@ -510,4 +517,13 @@ def __post_init__(self):
f"Length of loss_weights list ({self.loss_weights}) must match number of loss types "
f"({loss_types})."
)

+if self.use_liger_loss:
+warnings.warn(
+"The `use_liger_loss` argument is deprecated and will be removed in version 0.28.0. Please use "
+"`use_liger_kernel` instead.",
+FutureWarning,
+stacklevel=2,
+)
+self.use_liger_kernel = self.use_liger_loss
super().__post_init__()
trl/trainer/dpo_trainer.py (8 changes: 4 additions & 4 deletions)
@@ -378,15 +378,15 @@ def __init__(
disable_dropout_in_model(self.ref_model)

# Liger kernel
-if args.use_liger_loss:
+if args.use_liger_kernel:
if not is_liger_kernel_available():
raise ImportError(
-"You set `use_liger_loss=True` but the liger kernel is not available. "
+"You set `use_liger_kernel=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
if args.loss_type not in ["sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"]:
raise ValueError(
-"You set `use_liger_loss=True` but the loss type is not from `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair`. "
+"You set `use_liger_kernel=True` but the loss type is not one of `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair]`. "
"Please set `loss_type='[sigmoid | apo_zero | apo_down | sppo_hard | nca_pair]'` to use the liger kernel."
)
self.dpo_loss_fn = LigerFusedLinearDPOLoss(
@@ -1730,7 +1730,7 @@ def get_batch_loss_metrics(
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}

-if self.args.use_liger_loss:
+if self.args.use_liger_kernel:
model_output = self._compute_loss_liger(model, batch)
losses = model_output["loss"]
chosen_rewards = model_output["chosen_rewards"]
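As the gate above shows, the fused Liger path only covers a subset of DPO loss types. A hedged configuration sketch (values are illustrative):

```python
from trl import DPOConfig

# Only these loss types are accepted together with `use_liger_kernel=True`;
# anything else raises a ValueError in `DPOTrainer.__init__`:
# sigmoid, apo_zero, apo_down, sppo_hard, nca_pair.
training_args = DPOConfig(
    output_dir="out",  # illustrative
    use_liger_kernel=True,
    loss_type="sigmoid",
)
```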
trl/trainer/grpo_config.py (26 changes: 21 additions & 5 deletions)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import warnings
from dataclasses import dataclass, field
from typing import Optional, Union

@@ -220,8 +221,14 @@ class GRPOConfig(TrainingArguments):
position, improving results. Range: `[0.0-1.0]`. A value of `0.0` masks all but the highest entropy token;
`1.0` keeps all tokens. The paper recommends a value of `0.2`. If used with
`mask_truncated_completions=True`, only tokens from non-truncated completions are considered.
-use_liger_loss (`bool`, *optional*, defaults to `False`):
-Whether to use the Liger GRPO loss.
+use_liger_loss (`bool`, *optional*, defaults to `None`):
+Whether to use Liger loss.
+
+<Deprecated version="0.25.0">
+
+Parameter `use_liger_loss` is deprecated and will be removed in version 0.28.0. Use `use_liger_kernel` instead.
+
+</Deprecated>
vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`):
Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and recomputed
logprobs. [Your Efficient RL Framework Secretly Brings You Off-Policy RL
@@ -605,7 +612,7 @@ class GRPOConfig(TrainingArguments):
},
)
use_liger_loss: bool = field(
-default=False,
+default=None,
metadata={"help": "Whether to use the Liger GRPO loss."},
)
vllm_importance_sampling_correction: bool = field(
@@ -697,5 +704,14 @@ def __post_init__(self):
f"{self.num_generations}, which is less than the minimum required."
)

-if self.delta is not None and self.use_liger_loss:
-raise ValueError("Liger loss does not support two-sided GRPO loss yet.")
+if self.use_liger_loss:
+warnings.warn(
+"The `use_liger_loss` argument is deprecated and will be removed in version 0.28.0. Please use "
+"`use_liger_kernel` instead.",
+FutureWarning,
+stacklevel=2,
+)
+self.use_liger_kernel = self.use_liger_loss
+
+if self.delta is not None and self.use_liger_kernel:
+raise ValueError("Liger kernel does not support two-sided GRPO loss yet.")
trl/trainer/grpo_trainer.py (12 changes: 6 additions & 6 deletions)
@@ -390,17 +390,17 @@ def __init__(
self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction
self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap
-self.use_liger_loss = args.use_liger_loss
+self.use_liger_kernel = args.use_liger_kernel
self.loss_type = args.loss_type
self.scale_rewards = args.scale_rewards
self.importance_sampling_level = args.importance_sampling_level
self.mask_truncated_completions = args.mask_truncated_completions
self.top_entropy_quantile = args.top_entropy_quantile
-if self.use_liger_loss and self.top_entropy_quantile < 1.0:
+if self.use_liger_kernel and self.top_entropy_quantile < 1.0:
raise NotImplementedError(
"Liger Kernels don't currently support masking token positions based on entropy."
)
-if self.use_liger_loss and not self.importance_sampling_level == "token":
+if self.use_liger_kernel and not self.importance_sampling_level == "token":
raise NotImplementedError(
"Liger Kernels currently only support token-level importance sampling. Please set"
"`importance_sampling_level` to 'token'."
@@ -478,10 +478,10 @@ def __init__(
disable_dropout_in_model(self.ref_model)

# Liger loss
-if self.use_liger_loss:
+if self.use_liger_kernel:
if not is_liger_kernel_available():
raise ImportError(
-"Liger is required to use `liger_loss` as the GRPO loss. Run `pip install liger-kernel`."
+"Liger is required to use `use_liger_kernel` as the GRPO loss. Run `pip install liger-kernel`."
)
# redirect the model.module forward to the model forward to ensure pre-forward hooks are called
self._forward_redirection = _ForwardRedirection()
@@ -1720,7 +1720,7 @@ def compute_liger_loss(self, unwrapped_model, inputs):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
-if self.use_liger_loss:
+if self.use_liger_kernel:
# Compute the loss using the liger grpo loss
unwrapped_model = self.accelerator.unwrap_model(model)
return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
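The `_forward_redirection` call above routes `compute_liger_loss` through the wrapper's forward so that pre-forward hooks (e.g. from DeepSpeed or FSDP) still fire before the loss is computed on the unwrapped model. A conceptual sketch of that pattern, assumed rather than TRL's exact implementation:

```python
import torch.nn as nn


class ForwardRedirectionSketch:
    """Call `method` through `wrapper` so the wrapper's pre-forward hooks run."""

    def __call__(self, wrapper: nn.Module, original: nn.Module, method, *args, **kwargs):
        original_forward = original.forward

        def wrapped_forward(*fargs, **fkwargs):
            # Restore the real forward before dispatching, so nested calls
            # made inside `method` behave normally.
            original.forward = original_forward
            return method(*fargs, **fkwargs)

        # Temporarily point the module's forward at `method`, then call the
        # wrapper: its hooks fire first, and its forward delegates into
        # `method` via the patched attribute.
        original.forward = wrapped_forward
        return wrapper(*args, **kwargs)
```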