
Commit 5180937

Knowledge distillation, fix and improve cross-entropy (#229)
1 parent 1550bd1 commit 5180937

File tree: 11 files changed (+378, -175 lines)

fast_llm/config.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ def __enter__(self):
         global _AUTO_VALIDATE
         self._old_value = _AUTO_VALIDATE
         _AUTO_VALIDATE = False
+        return _AUTO_VALIDATE

     def __exit__(self, exc_type, exc_val, exc_tb):
         global _AUTO_VALIDATE
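For context, `NoAutoValidate` flips a module-level auto-validation flag on entry and restores it on exit; the added `return` simply exposes the (now disabled) state to `with ... as` callers. A minimal sketch of the pattern, assuming the `__exit__` body restores the saved value (only the lines shown in the hunk above are from the commit):

# Minimal sketch of a flag-toggling context manager; illustrative, not the full Fast-LLM class.
_AUTO_VALIDATE = True


class NoAutoValidate:
    """Temporarily disable automatic config validation within a `with` block."""

    def __enter__(self):
        global _AUTO_VALIDATE
        self._old_value = _AUTO_VALIDATE
        _AUTO_VALIDATE = False
        return _AUTO_VALIDATE  # The commit adds this return of the (disabled) state.

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _AUTO_VALIDATE
        _AUTO_VALIDATE = self._old_value  # Assumed: restore the previous flag value.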

fast_llm/engine/distributed/config.py

Lines changed: 2 additions & 1 deletion
@@ -293,7 +293,6 @@ def _validate(self) -> None:
             if self.reference_config.reference_config is not None:
                 self.reference_config = self.reference_config.reference_config
             assert self.reference_config.reference_config is None
-            self.compare(self.reference_config, ValueError)
             self.distributed_dims = self.reference_config.distributed_dims
         else:
             self.distributed_dims = {}
@@ -368,6 +367,8 @@ def _validate(self) -> None:

         super()._validate()

+        if self.reference_config is not None:
+            self.compare(self.reference_config, ValueError)
         Assert.in_range(self.rank, 0, self.world_size)
         Assert.in_range(self.local_rank, 0, self.local_world_size)
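The comparison against the reference config now runs after `super()._validate()`, so both configs are fully populated (defaults filled in) before being compared. As an illustration only, a hypothetical `compare`-style helper might behave like the sketch below; the real method and its fields are defined in Fast-LLM, not here:

# Hypothetical sketch: compare a config against a reference and raise the given error type on mismatch.
def compare_config_dicts(config: dict, reference: dict, error_type: type[Exception] = ValueError) -> None:
    mismatched = {
        key: (config.get(key), reference.get(key))
        for key in set(config) | set(reference)
        if config.get(key) != reference.get(key)
    }
    if mismatched:
        raise error_type(f"Config does not match its reference: {mismatched}")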

fast_llm/engine/multi_stage/config.py

Lines changed: 2 additions & 1 deletion
@@ -305,7 +305,8 @@ def _validate(self) -> None:
         self.pretrained.setup(self.model)
         self.pretrained.validate()
         if self.pretrained.path is not None:
-            self.model = self.model.from_pretrained(self.pretrained, self.model)
+            with NoAutoValidate():
+                self.model = self.model.from_pretrained(self.pretrained, self.model)
         self._setup()
         super()._validate()

fast_llm/engine/training/config.py

Lines changed: 6 additions & 9 deletions
@@ -385,17 +385,15 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig):

     def _validate(self) -> None:
         self.training.export.setup(self.model)
-        self.model.validate()
+        for reference_model in self.reference_models.values():
+            _add_reference_distributed_to_pretrained(reference_model, self.model.distributed)
+        super()._validate()
         if self.reference_models:
             # TODO: Add support.
             Assert.eq(self.model.distributed.pipeline_parallel, 1)
             # TODO: Check if these work.
             Assert.eq(self.model.distributed.tensor_parallel, 1)
             Assert.eq(self.model.distributed.sequence_data_parallel, 1)
-
-        for reference_model in self.reference_models.values():
-            _add_reference_distributed_to_pretrained(reference_model, self.model.distributed)
-        super()._validate()
         if self.run.experiment_dir is None:
             assert not self.training.checkpoint.enabled()

@@ -431,13 +429,12 @@ def _add_reference_distributed_to_pretrained(pretrained: PretrainedFastLLMModelC

     def new_setup():
         # Make sure the distributed config isn't set
-        # TODO!!!!!!!!!!!!!: Uncomment after #205
-        # pretrained.model.distributed.validate()
-        # Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"})
+        pretrained.model.distributed.validate()
+        Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"})
         with NoAutoValidate():
             pretrained.model.distributed = distributed.to_copy()
         # Allow sharing the `Distributed` instance.
         pretrained.model.distributed.reference_config = distributed
         old_setup()

-    pretrained._setup = new_setup
+    object.__setattr__(pretrained, "_setup", new_setup)
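The last change swaps `pretrained._setup = new_setup` for `object.__setattr__`, which bypasses any custom `__setattr__` the config class defines (for example, one that blocks mutation after validation). A generic illustration of the difference, using a hypothetical `FrozenConfig` class rather than the Fast-LLM one:

# Why object.__setattr__: it skips a class's own __setattr__ guard (hypothetical example).
class FrozenConfig:
    def __setattr__(self, name, value):
        raise AttributeError(f"Config is read-only, cannot set {name!r}")


config = FrozenConfig()
# config._setup = lambda: None          # would raise AttributeError
object.__setattr__(config, "_setup", lambda: None)  # bypasses the guard
assert callable(config._setup)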

fast_llm/functional/config.py

Lines changed: 6 additions & 0 deletions
@@ -91,3 +91,9 @@ class CrossEntropyImpl(str, enum.Enum):
     torch = "torch"
     fused = "fused"
     triton = "triton"
+
+
+class TargetFormat(enum.StrEnum):
+    labels = "labels"
+    logits = "logits"
+    probabilities = "probabilities"
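`TargetFormat` tells the cross-entropy functions how to interpret `target`: hard token ids (`labels`), raw teacher logits to be softmaxed internally (`logits`), or an already-normalized distribution (`probabilities`), which is what knowledge distillation needs. A hedged usage sketch of the new signature introduced later in this commit; tensor shapes, device, and the ability to run the fused path outside a distributed setup are assumptions:

import torch

from fast_llm.functional.config import CrossEntropyImpl, TargetFormat
from fast_llm.functional.cross_entropy import cross_entropy_forward_backward

batch, vocab = 4, 32
student_logits = torch.randn(batch, vocab)

# Standard language-model loss: integer labels (the default target format).
labels = torch.randint(0, vocab, (batch,))
loss, grad = cross_entropy_forward_backward(
    student_logits, labels, grad_output=1.0, implementation=CrossEntropyImpl.fused
)

# Distillation-style loss: the teacher's logits as soft targets.
teacher_logits = torch.randn(batch, vocab)
distill_loss, distill_grad = cross_entropy_forward_backward(
    student_logits,
    teacher_logits,
    grad_output=1.0,
    implementation=CrossEntropyImpl.fused,
    target_format=TargetFormat.logits,
)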

fast_llm/functional/cross_entropy.py

Lines changed: 108 additions & 83 deletions
@@ -3,7 +3,7 @@
 import torch.autograd

 from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce
-from fast_llm.functional.config import CrossEntropyImpl
+from fast_llm.functional.config import CrossEntropyImpl, TargetFormat
 from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward
 from fast_llm.utils import Assert

@@ -12,34 +12,67 @@ def torch_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
     grad_output: float | None,
-    logits_scale_factor: float = 1.0,
+    logits_scale_factor: float,
+    target_format: TargetFormat,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     A wrapper for the pytorch implementation of cross-entropy.
     The cross-entropy kernels themselves are well-optimized, but the need for explicit casting
     and separate forward and backward kernels leads to poor performance.
-    TODO: loss masking only works for this method if the masking index is set to -100.
+    TODO: loss masking only works with the labels format and if the masking index is set to -100.
     """
     # Torch compile doesn't understand this.
-    with torch.enable_grad():
-        logits_ = logits.float().detach().requires_grad_()
-        if logits_scale_factor != 1.0:
-            logits_ *= logits_scale_factor
+    with torch.set_grad_enabled(grad_output is not None):
+        logits_ = logits.float().detach().requires_grad_(grad_output is not None)
+        if target_format == TargetFormat.logits:
+            if logits_scale_factor != 1.0:
+                target = target * logits_scale_factor
+            target = torch.softmax(target, dim=-1)
+        loss = torch.nn.functional.cross_entropy(
+            logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target
+        ).mean()
         if grad_output is None:
-            loss = None
+            grad = None
         else:
-            loss = torch.nn.functional.cross_entropy(logits_, target).mean()
             loss.backward(torch.full_like(loss, grad_output))
-            loss.detach_()
-    return loss.detach(), logits_.grad.detach().to(logits.dtype)
+            grad = logits_.grad.detach().to(logits.dtype)
+    return loss.detach_(), grad
+
+
+# @torch.compile
+def _fused_softmax_base(
+    logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    logits = logits.float()
+    if logits_scale_factor != 1.0:
+        logits *= logits_scale_factor
+    logits_max = torch.max(logits, dim=dim, keepdim=True)[0]
+    if group is not None:
+        all_reduce(logits_max, op=ReduceOp.MAX, group=group)
+    logits_norm = (logits - logits_max).float()
+    exp_logits = logits_norm.exp()
+    sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True)
+    if group is not None:
+        all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group)
+    return logits_norm, exp_logits, sum_exp_logits
+
+
+# @torch.compile
+def fused_softmax(
+    logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup = None, dim: int = -1
+) -> torch.Tensor:
+    _, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group, dim)
+    return exp_logits / sum_exp_logits


 @torch.compile
 def fused_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
     grad_output: float | None,
-    logits_scale_factor: float = 1.0,
+    logits_scale_factor: float,
+    target_format: TargetFormat,
+    group: ProcessGroup | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     A fused implementation of cross-entropy with torch compile.
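The hunk above factors the numerically stable softmax into `_fused_softmax_base` (max-shifted logits, their exponentials, and the optionally all-reduced normalizer), so `fused_softmax` with `group=None` should match `torch.softmax` up to floating-point error. A quick sanity check under that assumption:

import torch

from fast_llm.functional.cross_entropy import fused_softmax

logits = torch.randn(8, 100)
# Without a process group this reduces to a plain (numerically stable) softmax.
torch.testing.assert_close(fused_softmax(logits), torch.softmax(logits.float(), dim=-1))
# The scale factor is applied before the softmax, i.e. it acts like a temperature of 1 / scale.
torch.testing.assert_close(
    fused_softmax(logits, logits_scale_factor=0.5),
    torch.softmax(0.5 * logits.float(), dim=-1),
)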
@@ -48,82 +81,67 @@ def fused_cross_entropy_forward_backward(
     """
     # Do the forward and backward passes all at once, and fused with dtype conversion.
     # Way faster and more memory-efficient than the pytorch version.
-    loss_mask = target >= 0
-    # Ignore_index can go out of bounds, so set masked values to zero.
-    target = (target * loss_mask).unsqueeze(1)
-    logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float()
-    if logits_scale_factor != 1.0:
-        logits_norm *= logits_scale_factor
-    exp_logits = logits_norm.exp()
-    sum_exp_logits = exp_logits.sum(dim=-1)
-
-    if grad_output is None:
-        grad = None
-    else:
-        exp_logits = exp_logits.scatter(1, target, exp_logits.gather(1, target) - sum_exp_logits.unsqueeze(dim=-1))
-        # exp_logits[torch.arange(0, logits.size(0), device=logits.device), target.squeeze(dim=-1)]-=sum_exp_logits
-        exp_logits = exp_logits.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))
-
-        if logits_scale_factor != 1.0:
-            exp_logits *= logits_scale_factor
-
-        grad = torch.where(loss_mask.unsqueeze(1), exp_logits.to(logits.dtype), 0)
-
-    per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask
-
-    return per_sample_loss.mean(), grad

+    logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group)

-@torch.compile
-def parallel_cross_entropy_forward_backward(
-    logits: torch.Tensor,
-    target: torch.Tensor,
-    grad_output: float | None,
-    group: ProcessGroup,
-    logits_scale_factor: float = 1.0,
-) -> tuple[torch.Tensor, torch.Tensor | None]:
-    """
-    A fused implementation of cross-entropy with torch compile, with support for tensor parallelism.
-    Comes with a noticeable overhead, but reduces memory usage.
-    """
-    # TODO: Compiled version incorrect for some inputs (32 bit indexing issue?).
-    # TODO: Optimize, overlap/combine reductions
-    loss_mask = target >= 0
-    target = target.unsqueeze(1)
-
-    logits_max = torch.max(logits, dim=-1)[0]
-    all_reduce(logits_max, op=ReduceOp.MAX, group=group)
-    logits_norm = logits.sub(logits_max.unsqueeze(dim=-1)).float()
-    if logits_scale_factor != 1.0:
-        logits_norm *= logits_scale_factor
-
-    exp_logits = logits_norm.exp()
-    sum_exp_logits = exp_logits.sum(dim=-1)
-    all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group)
+    if target_format == TargetFormat.logits:
+        target = fused_softmax(target, logits_scale_factor, group)

-    # Mask the target (fused)
-    # TODO: Could mask earlier on cpu or overlap with reduce?
-    vocab_start_index = logits.size(-1) * group.rank()
-    target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1))
-    target = (target - vocab_start_index) * target_mask
+    if target_format == TargetFormat.labels:
+        target = target.unsqueeze(-1)
+        loss_mask = target >= 0
+        if group is None:
+            # Keep values within range for scatter and gather ops to work.
+            target = target * loss_mask
+            target_mask = None
+        else:
+            # Mask the target (fused)
+            # TODO: Could mask earlier on cpu or overlap with reduce?
+            vocab_start_index = logits.size(-1) * group.rank()
+            target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1))
+            target = (target - vocab_start_index) * target_mask
+    else:
+        # TODO: Support masking
+        loss_mask = None
+        # Target should be tensor-parallel already, no further manipulation needed.
+        target_mask = None

     if grad_output is None:
         grad = None
     else:
-        exp_logits1 = exp_logits.scatter(
-            1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1)
-        )
-        exp_logits2 = exp_logits1.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))
+        # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities.
+        if target_format == TargetFormat.labels:
+            grad_base = exp_logits.scatter_add(
+                1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits)
+            )
+        else:
+            grad_base = exp_logits - sum_exp_logits * target
+
+        grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits)
         if logits_scale_factor != 1.0:
-            exp_logits2 *= logits_scale_factor
+            grad *= logits_scale_factor
+        grad = grad.to(logits.dtype)
+        if loss_mask is not None:
+            grad = torch.where(loss_mask, grad.to(logits.dtype), 0)
+
+    # loss = mean(log(sum_exp_logits) - sum(probabilities * logits))
+    if target_format == TargetFormat.labels:
+        predicted_logits = logits_norm.gather(1, target)
+        if group is not None:
+            predicted_logits = target_mask * predicted_logits
+            all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
+    else:
+        predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True)

-        grad = torch.where(loss_mask.unsqueeze(1), exp_logits2.to(logits.dtype), 0)
+    per_sample_loss = sum_exp_logits.log() - predicted_logits
+    if loss_mask is not None:
+        per_sample_loss = per_sample_loss * loss_mask

-    predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1)
-    all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
-    per_sample_loss = sum_exp_logits.log().sub(predicted_logits) * loss_mask
+    loss = per_sample_loss.mean()
+    if target_format != TargetFormat.labels and group is not None:
+        all_reduce(loss, op=ReduceOp.MEAN, group=group)

-    return per_sample_loss.mean(), grad
+    return loss, grad


 _CROSS_ENTROPY_IMPLEMENTATIONS = {
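The two comments added in this hunk summarize the math being fused. Writing $z$ for the (max-shifted, scaled) logits and $p$ for the target probabilities (a one-hot vector in the labels case, the softmaxed teacher logits in the distillation case), the per-sample loss and its gradient are, as a sketch:

\mathcal{L} = \log \sum_j e^{z_j} - \sum_j p_j z_j,
\qquad
\frac{\partial \mathcal{L}}{\partial z_i} = \frac{e^{z_i}}{\sum_j e^{z_j}} - p_i ,

which is exactly `exp_logits / sum_exp_logits - target_probabilities`, scaled in the code by `grad_output / logits.size(0)` and, via the chain rule, by `logits_scale_factor`.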
@@ -134,25 +152,32 @@ def parallel_cross_entropy_forward_backward(


 def cross_entropy_forward_backward(
-    logits,
-    target,
+    logits: torch.Tensor,
+    target: torch.Tensor,
     grad_output: float | None,
-    group: ProcessGroup | None,
+    group: ProcessGroup | None = None,
     implementation: CrossEntropyImpl = CrossEntropyImpl.fused,
     logits_scale_factor: float = 1.0,
+    target_format: TargetFormat = TargetFormat.labels,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Select the appropriate implementation of cross-entropy.
     The triton implementation from the triton submodule is the fastest and recommended one.
     It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way,
     which is faster and has a relatively small memory overhead.
     """
+    if target_format == TargetFormat.labels:
+        Assert.eq(target.shape, logits.shape[:-1])
+        Assert.eq(target.dtype, torch.int64)
+    else:
+        Assert.eq(target.shape, logits.shape)
+        assert target.dtype.is_floating_point, target.dtype
     if group:
         Assert.eq(implementation, CrossEntropyImpl.fused)
-        return parallel_cross_entropy_forward_backward(
-            logits, target, grad_output, group, logits_scale_factor=logits_scale_factor
+        return fused_cross_entropy_forward_backward(
+            logits, target, grad_output, logits_scale_factor, target_format, group
         )
     else:
         return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
-            logits, target, grad_output, logits_scale_factor=logits_scale_factor
+            logits, target, grad_output, logits_scale_factor, target_format
         )
