 import torch.autograd
 
 from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce
-from fast_llm.functional.config import CrossEntropyImpl
+from fast_llm.functional.config import CrossEntropyImpl, TargetFormat
 from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward
 from fast_llm.utils import Assert
 
@@ -12,34 +12,67 @@ def torch_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
     grad_output: float | None,
-    logits_scale_factor: float = 1.0,
+    logits_scale_factor: float,
+    target_format: TargetFormat,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     A wrapper for the pytorch implementation of cross-entropy.
     The cross-entropy kernels themselves are well-optimized, but the need for explicit casting
     and separate forward and backward kernels lead to poor performance.
-    TODO: loss masking only works for this method if the masking index is set to -100.
+    TODO: loss masking only works with the labels format and if the masking index is set to -100.
     """
     # Torch compile doesn't understand this.
-    with torch.enable_grad():
-        logits_ = logits.float().detach().requires_grad_()
-        if logits_scale_factor != 1.0:
-            logits_ *= logits_scale_factor
+    with torch.set_grad_enabled(grad_output is not None):
+        logits_ = logits.float().detach().requires_grad_(grad_output is not None)
+        if target_format == TargetFormat.logits:
+            if logits_scale_factor != 1.0:
+                target = target * logits_scale_factor
+            target = torch.softmax(target, dim=-1)
+        loss = torch.nn.functional.cross_entropy(
+            logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target
+        ).mean()
         if grad_output is None:
-            loss = None
+            grad = None
         else:
-            loss = torch.nn.functional.cross_entropy(logits_, target).mean()
            loss.backward(torch.full_like(loss, grad_output))
-            loss.detach_()
-    return loss.detach(), logits_.grad.detach().to(logits.dtype)
+            grad = logits_.grad.detach().to(logits.dtype)
+    return loss.detach_(), grad
+
+
+# @torch.compile
+def _fused_softmax_base(
+    logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    logits = logits.float()
+    if logits_scale_factor != 1.0:
+        logits *= logits_scale_factor
+    logits_max = torch.max(logits, dim=dim, keepdim=True)[0]
+    if group is not None:
+        all_reduce(logits_max, op=ReduceOp.MAX, group=group)
+    logits_norm = (logits - logits_max).float()
+    exp_logits = logits_norm.exp()
+    sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True)
+    if group is not None:
+        all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group)
+    return logits_norm, exp_logits, sum_exp_logits
+
+
+# @torch.compile
+def fused_softmax(
+    logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup = None, dim: int = -1
+) -> torch.Tensor:
+    _, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group, dim)
+    return exp_logits / sum_exp_logits
 
 
 @torch.compile
 def fused_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
     grad_output: float | None,
-    logits_scale_factor: float = 1.0,
+    logits_scale_factor: float,
+    target_format: TargetFormat,
+    group: ProcessGroup | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     A fused implementation of cross-entropy with torch compile.
@@ -48,82 +81,67 @@ def fused_cross_entropy_forward_backward(
     """
     # Do the forward and backward passes all at once, and fused with dtype conversion.
     # Way faster and more memory-efficient than the pytorch version.
-    loss_mask = target >= 0
-    # Ignore_index can go out of bounds, so set masked values to zero.
-    target = (target * loss_mask).unsqueeze(1)
-    logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float()
-    if logits_scale_factor != 1.0:
-        logits_norm *= logits_scale_factor
-    exp_logits = logits_norm.exp()
-    sum_exp_logits = exp_logits.sum(dim=-1)
-
-    if grad_output is None:
-        grad = None
-    else:
-        exp_logits = exp_logits.scatter(1, target, exp_logits.gather(1, target) - sum_exp_logits.unsqueeze(dim=-1))
-        # exp_logits[torch.arange(0, logits.size(0), device=logits.device), target.squeeze(dim=-1)]-=sum_exp_logits
-        exp_logits = exp_logits.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))
-
-        if logits_scale_factor != 1.0:
-            exp_logits *= logits_scale_factor
-
-        grad = torch.where(loss_mask.unsqueeze(1), exp_logits.to(logits.dtype), 0)
-
-    per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask
-
-    return per_sample_loss.mean(), grad
 
+    logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group)
 
-@torch.compile
-def parallel_cross_entropy_forward_backward(
-    logits: torch.Tensor,
-    target: torch.Tensor,
-    grad_output: float | None,
-    group: ProcessGroup,
-    logits_scale_factor: float = 1.0,
-) -> tuple[torch.Tensor, torch.Tensor | None]:
-    """
-    A fused implementation of cross-entropy with torch compile, with support for tensor parallelism.
-    Comes with a noticeable overhead, but reduces memory usage.
-    """
-    # TODO: Compiled version incorrect for some inputs (32 bit indexing issue?).
-    # TODO: Optimize, overlap/combine reductions
-    loss_mask = target >= 0
-    target = target.unsqueeze(1)
-
-    logits_max = torch.max(logits, dim=-1)[0]
-    all_reduce(logits_max, op=ReduceOp.MAX, group=group)
-    logits_norm = logits.sub(logits_max.unsqueeze(dim=-1)).float()
-    if logits_scale_factor != 1.0:
-        logits_norm *= logits_scale_factor
-
-    exp_logits = logits_norm.exp()
-    sum_exp_logits = exp_logits.sum(dim=-1)
-    all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group)
+    if target_format == TargetFormat.logits:
+        target = fused_softmax(target, logits_scale_factor, group)
 
-    # Mask the target (fused)
-    # TODO: Could mask earlier on cpu or overlap with reduce?
-    vocab_start_index = logits.size(-1) * group.rank()
-    target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1))
-    target = (target - vocab_start_index) * target_mask
+    if target_format == TargetFormat.labels:
+        target = target.unsqueeze(-1)
+        loss_mask = target >= 0
+        if group is None:
+            # Keep values within range for scatter and gather ops to work.
+            target = target * loss_mask
+            target_mask = None
+        else:
+            # Mask the target (fused)
+            # TODO: Could mask earlier on cpu or overlap with reduce?
+            vocab_start_index = logits.size(-1) * group.rank()
+            target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1))
+            target = (target - vocab_start_index) * target_mask
+    else:
+        # TODO: Support masking
+        loss_mask = None
+        # Target should be tensor-parallel already, no further manipulation needed.
+        target_mask = None
 
     if grad_output is None:
         grad = None
     else:
-        exp_logits1 = exp_logits.scatter(
-            1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1)
-        )
-        exp_logits2 = exp_logits1.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))
+        # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities.
+        if target_format == TargetFormat.labels:
+            grad_base = exp_logits.scatter_add(
+                1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits)
+            )
+        else:
+            grad_base = exp_logits - sum_exp_logits * target
+
+        grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits)
         if logits_scale_factor != 1.0:
-            exp_logits2 *= logits_scale_factor
+            grad *= logits_scale_factor
+        grad = grad.to(logits.dtype)
+        if loss_mask is not None:
+            grad = torch.where(loss_mask, grad.to(logits.dtype), 0)
+
+    # loss = mean(log(sum_exp_logits) - sum(probabilities * logits))
+    if target_format == TargetFormat.labels:
+        predicted_logits = logits_norm.gather(1, target)
+        if group is not None:
+            predicted_logits = target_mask * predicted_logits
+            all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
+    else:
+        predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True)
 
-        grad = torch.where(loss_mask.unsqueeze(1), exp_logits2.to(logits.dtype), 0)
+    per_sample_loss = sum_exp_logits.log() - predicted_logits
+    if loss_mask is not None:
+        per_sample_loss = per_sample_loss * loss_mask
 
-    predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1)
-    all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
-    per_sample_loss = sum_exp_logits.log().sub(predicted_logits) * loss_mask
+    loss = per_sample_loss.mean()
+    if target_format != TargetFormat.labels and group is not None:
+        all_reduce(loss, op=ReduceOp.MEAN, group=group)
 
-    return per_sample_loss.mean(), grad
+    return loss, grad
 
 
 _CROSS_ENTROPY_IMPLEMENTATIONS = {
@@ -134,25 +152,32 @@ def parallel_cross_entropy_forward_backward(
 
 
 def cross_entropy_forward_backward(
-    logits,
-    target,
+    logits: torch.Tensor,
+    target: torch.Tensor,
     grad_output: float | None,
-    group: ProcessGroup | None,
+    group: ProcessGroup | None = None,
     implementation: CrossEntropyImpl = CrossEntropyImpl.fused,
     logits_scale_factor: float = 1.0,
+    target_format: TargetFormat = TargetFormat.labels,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Select the appropriate implementation of cross-entropy.
     The triton implementation from the triton submodule is the fastest and recommended one.
     It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way,
     which is faster and has a relatively small memory overhead.
     """
+    if target_format == TargetFormat.labels:
+        Assert.eq(target.shape, logits.shape[:-1])
+        Assert.eq(target.dtype, torch.int64)
+    else:
+        Assert.eq(target.shape, logits.shape)
+        assert target.dtype.is_floating_point, target.dtype
     if group:
         Assert.eq(implementation, CrossEntropyImpl.fused)
-        return parallel_cross_entropy_forward_backward(
-            logits, target, grad_output, group, logits_scale_factor=logits_scale_factor
+        return fused_cross_entropy_forward_backward(
+            logits, target, grad_output, logits_scale_factor, target_format, group
         )
     else:
         return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
-            logits, target, grad_output, logits_scale_factor=logits_scale_factor
+            logits, target, grad_output, logits_scale_factor, target_format
        )
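
The fused path above computes the loss from the log-sum-exp identity and the gradient from softmax(logits) minus the target probabilities, as noted in the "grad / grad_output" and "loss = mean(...)" comments in the diff. As a quick, self-contained sanity check of those identities (independent of Fast-LLM, with logits_scale_factor fixed at 1, no loss masking and no tensor parallelism), the same math can be compared against plain PyTorch autograd:

import torch

batch, vocab = 4, 11
logits = torch.randn(batch, vocab, dtype=torch.float64, requires_grad=True)
labels = torch.randint(0, vocab, (batch,))
grad_output = 1.0

# Reference loss and gradient from the standard PyTorch kernels.
ref_loss = torch.nn.functional.cross_entropy(logits, labels)
ref_loss.backward(torch.full_like(ref_loss, grad_output))

# Fused-style math: log-sum-exp loss, softmax-minus-one-hot gradient, scaled by grad_output / batch.
logits_norm = logits.detach() - logits.detach().max(dim=-1, keepdim=True)[0]
exp_logits = logits_norm.exp()
sum_exp_logits = exp_logits.sum(dim=-1, keepdim=True)
loss = (sum_exp_logits.log() - logits_norm.gather(1, labels.unsqueeze(-1))).mean()
grad = (exp_logits / sum_exp_logits - torch.nn.functional.one_hot(labels, vocab)) * (grad_output / batch)

torch.testing.assert_close(loss, ref_loss)
torch.testing.assert_close(grad, logits.grad)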
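For context on how the new target_format argument is meant to be used, the sketch below is a hypothetical call site, not part of this change. The fast_llm.functional.cross_entropy module path is an assumption based on the file being edited; CrossEntropyImpl and TargetFormat come from fast_llm.functional.config as shown in the imports above. With TargetFormat.labels the target is an int64 tensor of shape logits.shape[:-1]; with TargetFormat.logits it is a floating-point tensor with the same shape as the logits (for example, teacher logits in a distillation setup), matching the assertions added to cross_entropy_forward_backward.

import torch

from fast_llm.functional.config import CrossEntropyImpl, TargetFormat
from fast_llm.functional.cross_entropy import cross_entropy_forward_backward  # assumed module path

logits = torch.randn(8, 1000)

# Hard labels: int64, shape logits.shape[:-1]. Negative labels are masked out in the fused
# implementation; the torch implementation expects -100 as the masking index.
labels = torch.randint(0, 1000, (8,))
loss, grad = cross_entropy_forward_backward(
    logits, labels, grad_output=1.0, implementation=CrossEntropyImpl.fused
)

# Soft targets given as logits: floating point, same shape as the model logits.
teacher_logits = torch.randn_like(logits)
loss, grad = cross_entropy_forward_backward(
    logits, teacher_logits, grad_output=1.0, implementation=CrossEntropyImpl.fused, target_format=TargetFormat.logits
)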