@@ -32,54 +32,54 @@ public struct RepetitionPenaltyLogitsProcessor: LogitsProcessor {
     public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor {
         guard penalty != 1.0 else { return scores }
 
-        // Implementation approach (following transformers):
-        // 1. Get unique token IDs from inputIds
-        // 2. For each unique token, gather its logit value
-        // 3. Apply conditional penalty: if logit < 0: *= penalty, else: /= penalty
-        // 4. Scatter penalized values back to original positions
+        // Optimized implementation following transformers:
+        // 1. Gather scores for tokens that appear in input_ids
+        // 2. Apply conditional penalty: if score < 0: *= penalty, else: /= penalty
+        // 3. Scatter penalized values back to original positions
 
-        // Convert to CPU for gather/scatter operations
-        let scoresArray = await scores.shapedArray(of: Float.self)
-        let inputIdsArray = await inputIds.shapedArray(of: Int32.self)
-
-        // Process each batch item
-        var scoresData = scoresArray.scalars
-        let shape = scores.shape
-        precondition(!shape.isEmpty, "scores tensor must have at least one dimension")
-
-        let batchSize = shape[0]
-        let vocabSize = shape[shape.count - 1]
-        let elementsPerBatch = shape.dropFirst().reduce(1, *)
-        let vocabBlocksPerBatch = max(elementsPerBatch / max(vocabSize, 1), 1)
-
-        for batchIdx in 0..<batchSize {
-            let batchOffset = batchIdx * elementsPerBatch
+        // Gather scores for tokens that appear in input_ids
+        let gatheredScores = scores.gathering(atIndices: inputIds, alongAxis: -1)
 
-            // Get unique token IDs from this sequence
-            let seqStartIds = batchIdx * inputIds.shape[1]
-            let seqEndIds = seqStartIds + inputIds.shape[1]
-            let tokenIds = Set(inputIdsArray.scalars[seqStartIds..<seqEndIds].map { Int($0) })
+        // Apply conditional penalty based on sign (vectorized)
+        let negativeScores = gatheredScores .< 0.0
+        let penalizedScores = negativeScores.cast(to: Float.self) * (gatheredScores * penalty) +
+            (1.0 - negativeScores.cast(to: Float.self)) * (gatheredScores / penalty)
 
-            // Apply penalty to each token that appeared in the sequence across all vocab blocks
-            for blockIdx in 0..<vocabBlocksPerBatch {
-                let blockOffset = batchOffset + blockIdx * vocabSize
+        // Scatter penalized values back to original positions
+        // Note: MLTensor doesn't have direct scatter, so we use CPU operations for this step
+        let vocabSize = scores.shape[scores.rank - 1]
+        let batchSize = scores.shape[0]
 
-                for tokenId in tokenIds {
-                    guard tokenId >= 0 && tokenId < vocabSize else { continue }
+        let inputIdsArray = await inputIds.shapedArray(of: Int32.self)
+        let penalizedArray = await penalizedScores.shapedArray(of: Float.self)
+        var scoresArray = await scores.shapedArray(of: Float.self)
 
-                    let scoreIdx = blockOffset + tokenId
-                    guard scoreIdx < scoresData.count else { continue }
+        for batchIdx in 0..<batchSize {
+            let seqStart = batchIdx * inputIds.shape[1]
+            let seqEnd = seqStart + inputIds.shape[1]
+            let batchOffset = batchIdx * scoresArray.shape.dropFirst().reduce(1, *)
 
-                    let score = scoresData[scoreIdx]
+            for (tokenIdx, inputIdxInSeq) in (seqStart..<seqEnd).enumerated() {
+                let tokenId = Int(inputIdsArray.scalars[inputIdxInSeq])
+                guard tokenId >= 0 && tokenId < vocabSize else { continue }
 
-                    // Apply penalty based on sign (following transformers implementation)
-                    scoresData[scoreIdx] = score < 0 ? score * penalty : score / penalty
+                // For rank-2: [batch_size, vocab_size]
+                if scores.rank == 2 {
+                    let scoreIdx = batchOffset + tokenId
+                    let penalizedIdx = seqStart + tokenIdx
+                    scoresArray.scalars[scoreIdx] = penalizedArray.scalars[penalizedIdx]
+                }
+                // For rank-3: [batch_size, seq_len, vocab_size] - update last position
+                else if scores.rank == 3 {
+                    let lastSeqPos = scores.shape[1] - 1
+                    let scoreIdx = batchOffset + lastSeqPos * vocabSize + tokenId
+                    let penalizedIdx = seqStart + tokenIdx
+                    scoresArray.scalars[scoreIdx] = penalizedArray.scalars[penalizedIdx]
                 }
             }
         }
 
-        // Create new tensor with penalized scores
-        return MLTensor(shape: scores.shape, scalars: scoresData, scalarType: Float.self)
+        return MLTensor(shape: scores.shape, scalars: scoresArray.scalars, scalarType: Float.self)
     }
 }
 #endif
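
For reference (not part of the diff above): a minimal CPU-only sketch of the same conditional penalty rule on plain Swift arrays, which could serve as a test oracle for the tensor path. The function name `applyRepetitionPenalty` and its parameters are illustrative assumptions and do not appear in this change.

```swift
// Reference sketch only: the same conditional repetition-penalty rule on plain
// Swift arrays. Names here (applyRepetitionPenalty, logits, tokenIds) are
// illustrative and are not part of the diff above.
func applyRepetitionPenalty(logits: [Float], tokenIds: [Int], penalty: Float) -> [Float] {
    var out = logits
    // Penalize every token that already appeared in the sequence.
    for tokenId in Set(tokenIds) where tokenId >= 0 && tokenId < out.count {
        let score = out[tokenId]
        // transformers rule: negative logits become more negative, positive logits shrink.
        out[tokenId] = score < 0 ? score * penalty : score / penalty
    }
    return out
}

// Example: with penalty = 1.5, a seen token's logit 3.0 becomes 2.0 and -1.0 becomes -1.5,
// making previously generated tokens less likely to be sampled again.
let penalized = applyRepetitionPenalty(logits: [3.0, -1.0, 0.5], tokenIds: [0, 1], penalty: 1.5)
```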