Skip to content

Commit 9492d47

Browse files
committed
use MLTensor
1 parent 4d012b9 commit 9492d47

File tree

1 file changed

+37
-37
lines changed

1 file changed

+37
-37
lines changed

Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor {
    // Fast path: a penalty of exactly 1.0 leaves every logit unchanged.
    guard penalty != 1.0 else { return scores }

    // 1. Gather the current score of every token that appears in `inputIds`.
    //    NOTE(review): the scatter step below assumes the gathered result is
    //    laid out as a flat [batchSize, seqLen] buffer — confirm this holds
    //    for rank-3 `scores` with `gathering(atIndices:alongAxis:)`.
    let gatheredScores = scores.gathering(atIndices: inputIds, alongAxis: -1)

    // 2. Apply the conditional penalty (vectorized), following the
    //    transformers rule: negative scores are multiplied by `penalty`,
    //    non-negative scores are divided by it.
    //    The boolean mask is cast to Float once and reused for both arms of
    //    the blend (the original recomputed the cast twice).
    let negativeMask = (gatheredScores .< 0.0).cast(to: Float.self)
    let penalizedScores =
        negativeMask * (gatheredScores * penalty)
        + (1.0 - negativeMask) * (gatheredScores / penalty)

    // 3. Scatter the penalized values back to their original positions.
    //    MLTensor exposes no scatter primitive, so this step runs on the CPU
    //    via shaped arrays.
    let vocabSize = scores.shape[scores.rank - 1]
    let batchSize = scores.shape[0]
    let seqLen = inputIds.shape[1]  // hoisted: loop-invariant

    let inputIdsArray = await inputIds.shapedArray(of: Int32.self)
    let penalizedArray = await penalizedScores.shapedArray(of: Float.self)
    var scoresArray = await scores.shapedArray(of: Float.self)

    for batchIdx in 0..<batchSize {
        let seqStart = batchIdx * seqLen
        let seqEnd = seqStart + seqLen
        // Flat offset of this batch item within the scores scalar buffer.
        let batchOffset = batchIdx * scoresArray.shape.dropFirst().reduce(1, *)

        for (tokenIdx, inputIdxInSeq) in (seqStart..<seqEnd).enumerated() {
            let tokenId = Int(inputIdsArray.scalars[inputIdxInSeq])
            // Skip ids outside the vocabulary (e.g. padding sentinels);
            // repeated tokens simply overwrite with the same penalized value.
            guard tokenId >= 0 && tokenId < vocabSize else { continue }

            // Position of this occurrence in the flat gathered/penalized buffer.
            let penalizedIdx = seqStart + tokenIdx

            if scores.rank == 2 {
                // Layout: [batchSize, vocabSize]
                scoresArray.scalars[batchOffset + tokenId] = penalizedArray.scalars[penalizedIdx]
            } else if scores.rank == 3 {
                // Layout: [batchSize, seqLen, vocabSize] — only the last
                // sequence position is updated, since that is what sampling
                // reads from.
                let lastSeqPos = scores.shape[1] - 1
                let scoreIdx = batchOffset + lastSeqPos * vocabSize + tokenId
                scoresArray.scalars[scoreIdx] = penalizedArray.scalars[penalizedIdx]
            }
            // Any other rank falls through with scores untouched,
            // matching the prior behavior.
        }
    }

    return MLTensor(shape: scores.shape, scalars: scoresArray.scalars, scalarType: Float.self)
}

0 commit comments

Comments
 (0)