2 changes: 1 addition & 1 deletion .swift-format
@@ -33,7 +33,7 @@
"AvoidRetroactiveConformances": true,
"BeginDocumentationCommentWithOneLineSummary": false,
"DoNotUseSemicolons": false,
"DontRepeatTypeInStaticProperties": true,
"DontRepeatTypeInStaticProperties": false,
Collaborator (author) comment:
Strange that the lint check fails even when this is disabled. FWIW, I also can't reproduce the lint failure locally...

Sources/Models/LanguageModel.swift:132:20: warning: [DontRepeatTypeInStaticProperties] remove the suffix 'Keys' from the name of the variable 'presentKeys'

"FileScopedDeclarationPrivacy": true,
"FullyIndirectEnum": true,
"GroupNumericLiterals": false,
1 change: 1 addition & 0 deletions Package.swift
@@ -24,6 +24,7 @@ let package = Package(
.target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings),
.target(name: "Models", dependencies: ["Tokenizers", "Generation"]),
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]),
.testTarget(name: "GenerationTests", dependencies: ["Generation"]),
.testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")], swiftSettings: swiftSettings),
.testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]),
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources")]),
Expand Down
63 changes: 62 additions & 1 deletion Sources/Generation/Decoders.swift
@@ -12,7 +12,7 @@ func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor {

@available(macOS 15.0, iOS 18.0, *)
func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, topK: Int) -> MLTensor {
let temperatureAdjustedScores = scores / temperature
let temperatureAdjustedScores = temperature == 1.0 ? scores : scores / temperature
let (topKScores, topKIndices) = temperatureAdjustedScores.topK(topK)
let topKProbs = topKScores.softmax(alongAxis: -1)
let rnd = topKProbs.sum() * Float.random(in: 0..<1)
@@ -25,4 +25,65 @@ func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float,
)
return nextTokenTensor.reshaped(to: [1, 1])
}

// MARK: Top-P (Nucleus) Sampling

/// Selects the next token using top-p (nucleus) sampling.
///
/// Top-p sampling dynamically selects from the smallest possible set of tokens
/// whose cumulative probability exceeds the threshold p. This provides more
/// diversity than top-k by adapting the size of the candidate set to the
/// probability distribution.
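///
/// For example, with sorted probabilities [0.5, 0.3, 0.15, 0.05] and p = 0.9,
/// the cumulative sums are [0.5, 0.8, 0.95, 1.0], so only the first two tokens
/// (cumulative probability <= 0.9) remain candidates. The highest-probability
/// token is always kept, even when it alone already exceeds p.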
@available(macOS 15.0, iOS 18.0, *)
func selectNextTokenUsingTopPSampling(from scores: MLTensor, temperature: Float, topP: Double) -> MLTensor {
let temperatureAdjustedScores = temperature == 1.0 ? scores : scores / temperature
let probs = temperatureAdjustedScores.softmax(alongAxis: -1)

// Sort probabilities in descending order by negating values first
let negatedProbs = -probs
let sortedIndices = negatedProbs.argsort(alongAxis: -1)
let sortedProbs = probs.gathering(atIndices: sortedIndices, alongAxis: -1)

// Calculate cumulative sum
let cumProbs = sortedProbs.cumulativeSum(alongAxis: -1)

// Find cutoff point - keep tokens where cumulative probability <= topP
let cutoffMask = cumProbs .<= Float(topP)

// Always keep at least the first (highest probability) token
let firstToken = MLTensor(repeating: 1.0, shape: Array(cutoffMask.shape.dropLast()) + [1])
if cutoffMask.shape.last! > 1 {
let restMask = cutoffMask[..., 1...]
let finalMask = MLTensor(concatenating: [firstToken, restMask], alongAxis: -1)

// Apply mask to sorted probabilities
let maskedSortedProbs = finalMask * sortedProbs

// Sample from the masked distribution
let totalMaskedProb = maskedSortedProbs.sum(alongAxes: [-1]).expandingShape(at: -1)
let normalizedProbs = maskedSortedProbs / totalMaskedProb

let rnd = Float.random(in: 0..<1)
let cumMaskedProbs = normalizedProbs.cumulativeSum(alongAxis: -1)
var accumProbs = cumMaskedProbs
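// Push entries whose cumulative sum is still below rnd far upward, so the
// smallest remaining value (argsort position 0) is the first index where
// the cumulative sum reaches rnd.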
accumProbs += (accumProbs .< rnd) * 100.0
let selectedIdx = accumProbs.argsort()[..., 0]

let nextTokenTensor = sortedIndices.gathering(
atIndices: selectedIdx,
alongAxis: sortedIndices.rank - 1
)

return nextTokenTensor.reshaped(to: [1, 1])
} else {
// Only one token, just return it
let selectedIdx = MLTensor([Int32(0)])
let nextTokenTensor = sortedIndices.gathering(
atIndices: selectedIdx,
alongAxis: sortedIndices.rank - 1
)
return nextTokenTensor.reshaped(to: [1, 1])
}
}

#endif // canImport(CoreML)
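For context, a minimal sketch of how these decoders might be driven from module code (illustrative only; `scores` stands in for a model's next-token logits of shape [1, vocabSize], and the temperature and topK values are arbitrary):

@available(macOS 15.0, iOS 18.0, *)
func sampleNextToken(from scores: MLTensor, topP: Double) -> MLTensor {
    // Mirror the dispatch used in Generation.swift below: nucleus sampling
    // when topP < 1.0, top-k sampling otherwise.
    if topP < 1.0 {
        return selectNextTokenUsingTopPSampling(from: scores, temperature: 0.7, topP: topP)
    }
    return selectNextTokenUsingTopKSampling(from: scores, temperature: 0.7, topK: 50)
}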
56 changes: 50 additions & 6 deletions Sources/Generation/Generation.swift
@@ -74,16 +74,31 @@ extension Generation
var outputTokens = MLTensor(tokens).expandingShape(at: 0)
while outputTokens.shape[1] < config.maxLength {
let nextTokenScores = await model(outputTokens, config)
// Apply logits processors (repetition penalty, etc.)
let processedLogits = applyLogitsProcessors(
inputIds: outputTokens,
logits: nextTokenScores,
config: config
)

let nextToken =
switch config.generationMode {
case .greedy:
selectNextTokenUsingGreedyDecoding(from: nextTokenScores)
selectNextTokenUsingGreedyDecoding(from: processedLogits)
case .sample:
selectNextTokenUsingTopKSampling(
from: nextTokenScores,
temperature: config.temperature,
topK: config.topK
)
if config.topP < 1.0 {
selectNextTokenUsingTopPSampling(
from: processedLogits,
temperature: config.temperature,
topP: config.topP
)
} else {
selectNextTokenUsingTopKSampling(
from: processedLogits,
temperature: config.temperature,
topK: config.topK
)
}
default:
fatalError("Generation mode \(config.generationMode) not implemented yet")
}
@@ -104,6 +119,35 @@ extension Generation {
private func tensorToGenerationOutput(_ tensor: MLTensor) async -> GenerationOutput {
await tensor.shapedArray(of: Int32.self).scalars.map { Int($0) }
}

/// Applies configured logits processors to the raw logits.
///
/// - Parameters:
/// - inputIds: The input token sequence
/// - logits: Raw logits from the model
/// - config: Generation configuration with processor settings
/// - Returns: Processed logits
private func applyLogitsProcessors(inputIds: MLTensor, logits: MLTensor, config: GenerationConfig) -> MLTensor {
var warpers: [LogitsWarper] = []

// Add temperature warper if temperature is not 1.0
if config.temperature != 1.0 {
warpers.append(TemperatureLogitsWarper(temperature: Double(config.temperature)))
}

// Add repetition penalty if configured
if config.repetitionPenalty != 1.0 {
warpers.append(RepetitionPenaltyWarper(penalty: config.repetitionPenalty))
}

// Apply all warpers if any are configured
if warpers.isEmpty {
return logits
}

let processor = LogitsProcessor(warpers: warpers)
return processor.process(inputIds: inputIds, logits: logits)
}
}

@available(macOS 15.0, iOS 18.0, *)
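To illustrate the dispatch above, a GenerationConfig tuned for nucleus sampling might be set up as follows (the initializer and the mutability of these fields are assumed for illustration; only the field names appear in this diff):

var config = GenerationConfig(maxLength: 128) // assumed initializer
config.temperature = 0.7       // != 1.0, so a TemperatureLogitsWarper is added
config.topP = 0.9              // < 1.0, so the .sample branch uses top-p sampling
config.repetitionPenalty = 1.2 // != 1.0, so a RepetitionPenaltyWarper is added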
63 changes: 63 additions & 0 deletions Sources/Generation/LogitsWarper.swift
@@ -0,0 +1,63 @@
#if canImport(CoreML)
import CoreML

/// Protocol for modifying logits before token sampling.
///
/// Logits warpers can be used to apply various transformations to the logits
/// distribution before sampling, such as temperature scaling, top-k filtering,
/// top-p (nucleus) filtering, or repetition penalties.
@available(macOS 15.0, iOS 18.0, *)
public protocol LogitsWarper {
/// Warps (modifies) the logits before sampling.
///
/// - Parameters:
/// - inputIds: The input token sequence used for context-dependent warping
/// - logits: The logits tensor to be modified
/// - Returns: The modified logits tensor
func warp(inputIds: MLTensor, logits: MLTensor) -> MLTensor

/// Alternative call syntax for convenience.
func callAsFunction(inputIds: MLTensor, logits: MLTensor) -> MLTensor
}

@available(macOS 15.0, iOS 18.0, *)
public extension LogitsWarper {
/// Default implementation of callAsFunction that delegates to warp.
func callAsFunction(inputIds: MLTensor, logits: MLTensor) -> MLTensor {
warp(inputIds: inputIds, logits: logits)
}
}

/// A collection of logits warpers that processes logits sequentially.
@available(macOS 15.0, iOS 18.0, *)
public struct LogitsProcessor {
private let warpers: [LogitsWarper]

/// Creates a new logits processor with the specified warpers.
///
/// - Parameter warpers: Array of logits warpers to apply sequentially
public init(warpers: [LogitsWarper] = []) {
self.warpers = warpers
}

/// Applies all warpers sequentially to the logits.
///
/// - Parameters:
/// - inputIds: The input token sequence
/// - logits: The logits tensor to process
/// - Returns: The processed logits tensor
public func process(inputIds: MLTensor, logits: MLTensor) -> MLTensor {
var processedLogits = logits
for warper in warpers {
processedLogits = warper.warp(inputIds: inputIds, logits: processedLogits)
}
return processedLogits
}

/// Alternative call syntax for convenience.
public func callAsFunction(inputIds: MLTensor, logits: MLTensor) -> MLTensor {
process(inputIds: inputIds, logits: logits)
}
}

#endif // canImport(CoreML)
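As a usage sketch, a custom conformance and a processor pipeline might look like the following (BiasLogitsWarper is invented here for illustration and is not part of this PR; scalar addition on MLTensor is assumed by analogy with the scalar division used above):

@available(macOS 15.0, iOS 18.0, *)
struct BiasLogitsWarper: LogitsWarper {
    let bias: Float

    // Context-independent warp: shifts every logit by a constant.
    func warp(inputIds: MLTensor, logits: MLTensor) -> MLTensor {
        logits + bias
    }
}

@available(macOS 15.0, iOS 18.0, *)
func warpLogits(inputIds: MLTensor, logits: MLTensor) -> MLTensor {
    let processor = LogitsProcessor(warpers: [
        TemperatureLogitsWarper(temperature: 0.7),
        BiasLogitsWarper(bias: 0.5),
    ])
    // Uses the callAsFunction convenience defined above.
    return processor(inputIds: inputIds, logits: logits)
}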
44 changes: 44 additions & 0 deletions Sources/Generation/RepetitionPenaltyWarper.swift
@@ -0,0 +1,44 @@
#if canImport(CoreML)
import CoreML

/// Logits warper that applies repetition penalty.
///
/// Repetition penalty reduces the likelihood of generating tokens that have
/// already appeared in the input sequence. This helps reduce repetitive text
/// generation.
///
/// - Note: Penalty > 1.0 penalizes repetition, penalty < 1.0 encourages it
@available(macOS 15.0, iOS 18.0, *)
public struct RepetitionPenaltyWarper: LogitsWarper {
/// The repetition penalty factor.
public let penalty: Float

/// Creates a new repetition penalty warper.
///
/// - Parameter penalty: Penalty factor (must be > 0). Values > 1.0 penalize repetition.
public init(penalty: Double) {
precondition(penalty > 0, "Penalty must be strictly positive")
self.penalty = Float(penalty)
}

/// Applies repetition penalty to tokens that appear in the input sequence.
///
/// - Parameters:
/// - inputIds: The input token sequence used to identify repeated tokens
/// - logits: The logits tensor to modify
/// - Returns: Logits with repetition penalty applied
public func warp(inputIds: MLTensor, logits: MLTensor) -> MLTensor {
if penalty == 1.0 {
return logits
}

// TODO: Implement repetition penalty when MLTensor API allows for easier tensor updates
// For now, we'll return the original logits to avoid compilation errors
// This functionality will need to be implemented when tensor item access and update operations are available

print("Warning: Repetition penalty is not yet implemented due to MLTensor API limitations")
return logits
}
}

#endif // canImport(CoreML)
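For reference, the intended Hugging Face-style semantics could be sketched on the CPU by round-tripping through shaped arrays. This is illustrative only: it assumes a batch size of 1 (so a flat scalar index equals the token id) and an MLTensor initializer of the form MLTensor(shape:scalars:scalarType:):

@available(macOS 15.0, iOS 18.0, *)
func repetitionPenaltyReference(inputIds: MLTensor, logits: MLTensor, penalty: Float) async -> MLTensor {
    let ids = await inputIds.shapedArray(of: Int32.self).scalars
    var scores = await logits.shapedArray(of: Float.self).scalars
    for id in Set(ids) {
        let i = Int(id)
        // Divide positive logits and multiply negative ones, so repeated
        // tokens become less likely in both cases (for penalty > 1.0).
        scores[i] = scores[i] > 0 ? scores[i] / penalty : scores[i] * penalty
    }
    return MLTensor(shape: logits.shape, scalars: scores, scalarType: Float.self)
}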
37 changes: 37 additions & 0 deletions Sources/Generation/TemperatureLogitsWarper.swift
@@ -0,0 +1,37 @@
#if canImport(CoreML)
import CoreML

/// Logits warper that applies temperature scaling.
///
/// Temperature scaling modifies the sharpness of the probability distribution:
/// - Temperature < 1.0: Makes the distribution more concentrated (less random)
/// - Temperature = 1.0: No change to the distribution
/// - Temperature > 1.0: Makes the distribution more uniform (more random)
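///
/// For example, logits [1.0, 2.0, 3.0] scaled by a temperature of 2.0 become
/// [0.5, 1.0, 1.5], which flattens the softmax distribution; a temperature of
/// 0.5 yields [2.0, 4.0, 6.0], which sharpens it.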
@available(macOS 15.0, iOS 18.0, *)
public struct TemperatureLogitsWarper: LogitsWarper {
/// The temperature value for scaling logits.
public let temperature: Float

/// Creates a new temperature logits warper.
///
/// - Parameter temperature: Temperature value (must be > 0)
public init(temperature: Double) {
precondition(temperature > 0, "Temperature must be strictly positive")
self.temperature = Float(temperature)
}

/// Applies temperature scaling to the logits.
///
/// - Parameters:
/// - inputIds: The input token sequence (unused by temperature warper)
/// - logits: The logits tensor to scale
/// - Returns: Temperature-scaled logits
public func warp(inputIds: MLTensor, logits: MLTensor) -> MLTensor {
if temperature == 1.0 {
return logits
}
return logits / temperature
}
}

#endif // canImport(CoreML)
1 change: 1 addition & 0 deletions Sources/Models/LanguageModel.swift
@@ -129,6 +129,7 @@ extension LanguageModel {
static let valueCache = "valueCache"
// Output keys
static let logits = "logits"
// swift-format-ignore: DontRepeatTypeInStaticProperties
Collaborator (author) comment:
The lint check is still failing despite disabling the rule globally in .swift-format and using this swift-format-ignore: directive...

static let presentKeys = "presentKeys"
static let presentValues = "presentValues"
}
70 changes: 70 additions & 0 deletions Tests/GenerationTests/LogitsWarperTests.swift
@@ -0,0 +1,70 @@
import CoreML
import Testing

@testable import Generation

#if canImport(CoreML)
@Suite("Logits Warper Tests")
struct LogitsWarperTests {

@Test("Temperature warper scaling")
@available(macOS 15.0, iOS 18.0, *)
func temperatureWarper() {
let logits = MLTensor([[1.0, 2.0, 3.0]])
let inputIds = MLTensor([[1, 2]])

let tempWarper = TemperatureLogitsWarper(temperature: 2.0)
let warpedLogits = tempWarper.warp(inputIds: inputIds, logits: logits)

#expect(warpedLogits.shape == logits.shape)

let identityWarper = TemperatureLogitsWarper(temperature: 1.0)
let unchangedLogits = identityWarper.warp(inputIds: inputIds, logits: logits)
#expect(unchangedLogits.shape == logits.shape)
}

@Test("LogitsProcessor with multiple warpers")
@available(macOS 15.0, iOS 18.0, *)
func logitsProcessor() {
let logits = MLTensor([[1.0, 2.0, 3.0]])
let inputIds = MLTensor([[1, 2]])

let warpers: [LogitsWarper] = [
TemperatureLogitsWarper(temperature: 2.0)
]

let processor = LogitsProcessor(warpers: warpers)
let processedLogits = processor.process(inputIds: inputIds, logits: logits)

#expect(processedLogits.shape == logits.shape)
}

@Test("LogitsProcessor with no warpers")
@available(macOS 15.0, iOS 18.0, *)
func logitsProcessorEmpty() {
let logits = MLTensor([[1.0, 2.0, 3.0]])
let inputIds = MLTensor([[1, 2]])

let processor = LogitsProcessor(warpers: [])
let processedLogits = processor.process(inputIds: inputIds, logits: logits)

#expect(processedLogits.shape == logits.shape)
}

@Test("Repetition penalty warper")
@available(macOS 15.0, iOS 18.0, *)
func repetitionPenaltyWarper() {
let logits = MLTensor([[1.0, 2.0, 3.0]])
let inputIds = MLTensor([[0, 1]])

let repWarper = RepetitionPenaltyWarper(penalty: 1.2)
let warpedLogits = repWarper.warp(inputIds: inputIds, logits: logits)

#expect(warpedLogits.shape == logits.shape)

let identityWarper = RepetitionPenaltyWarper(penalty: 1.0)
let unchangedLogits = identityWarper.warp(inputIds: inputIds, logits: logits)
#expect(unchangedLogits.shape == logits.shape)
}
}
#endif