Skip to content

Commit 28a2931

Browse files
Scheduler performance optimizations in Swift
1 parent 412c679 commit 28a2931

File tree

3 files changed

+36
-27
lines changed

3 files changed

+36
-27
lines changed

swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,20 @@ public final class DPMSolverMultistepScheduler: Scheduler {
8383
/// Convert the model output to the corresponding type the algorithm needs.
8484
/// This implementation is for second-order DPM-Solver++ assuming epsilon prediction.
8585
func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
86-
assert(modelOutput.scalars.count == sample.scalars.count)
86+
assert(modelOutput.scalarCount == sample.scalarCount)
87+
let scalarCount = modelOutput.scalarCount
8788
let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep])
88-
89-
// This could be optimized with a Metal kernel if we find we need to
90-
let x0_scalars = zip(modelOutput.scalars, sample.scalars).map { m, s in
91-
(s - m * sigma_t) / alpha_t
89+
90+
return MLShapedArray(unsafeUninitializedShape: modelOutput.shape) { scalars, _ in
91+
assert(scalars.count == scalarCount)
92+
modelOutput.withUnsafeShapedBufferPointer { modelOutput, _, _ in
93+
sample.withUnsafeShapedBufferPointer { sample, _, _ in
94+
for i in 0 ..< scalarCount {
95+
scalars.initializeElement(at: i, to: (sample[i] - modelOutput[i] * sigma_t) / alpha_t)
96+
}
97+
}
98+
}
9299
}
93-
return MLShapedArray(scalars: x0_scalars, shape: modelOutput.shape)
94100
}
95101

96102
/// One step for the first-order DPM-Solver (equivalent to DDIM).

swift/StableDiffusion/pipeline/Scheduler.swift

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// For licensing see accompanying LICENSE.md file.
22
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
33

4+
import Accelerate
45
import CoreML
56

67
@available(iOS 16.2, macOS 13.1, *)
@@ -59,18 +60,21 @@ public extension Scheduler {
5960
/// - values: The arrays to be weighted and summed
6061
/// - Returns: sum_i weights[i]*values[i]
6162
func weightedSum(_ weights: [Double], _ values: [MLShapedArray<Float32>]) -> MLShapedArray<Float32> {
63+
let scalarCount = values.first!.scalarCount
6264
assert(weights.count > 1 && values.count == weights.count)
63-
assert(values.allSatisfy({ $0.scalarCount == values.first!.scalarCount }))
64-
var w = Float(weights.first!)
65-
var scalars = values.first!.scalars.map({ $0 * w })
66-
for next in 1 ..< values.count {
67-
w = Float(weights[next])
68-
let nextScalars = values[next].scalars
69-
for i in 0 ..< scalars.count {
70-
scalars[i] += w * nextScalars[i]
65+
assert(values.allSatisfy({ $0.scalarCount == scalarCount }))
66+
67+
return MLShapedArray(unsafeUninitializedShape: values.first!.shape) { scalars, _ in
68+
scalars.initialize(repeating: 0.0)
69+
for i in 0 ..< values.count {
70+
let w = Float(weights[i])
71+
values[i].withUnsafeShapedBufferPointer { buffer, _, _ in
72+
assert(buffer.count == scalarCount)
73+
// scalars[j] += w * values[i].scalars[j]  (saxpy accumulates: y = a*x + y)
74+
cblas_saxpy(Int32(scalarCount), w, buffer.baseAddress, 1, scalars.baseAddress, 1)
75+
}
7176
}
7277
}
73-
return MLShapedArray(scalars: scalars, shape: values.first!.shape)
7478
}
7579

7680
func addNoise(

swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -286,20 +286,19 @@ public struct StableDiffusionPipeline: ResourceManaging {
286286
}
287287

288288
func performGuidance(_ noise: MLShapedArray<Float32>, _ guidanceScale: Float) -> MLShapedArray<Float32> {
289-
290-
let blankNoiseScalars = noise[0].scalars
291-
let textNoiseScalars = noise[1].scalars
292-
293-
var resultScalars = blankNoiseScalars
294-
295-
for i in 0..<resultScalars.count {
296-
// unconditioned + guidance*(text - unconditioned)
297-
resultScalars[i] += guidanceScale*(textNoiseScalars[i]-blankNoiseScalars[i])
298-
}
299-
300289
var shape = noise.shape
301290
shape[0] = 1
302-
return MLShapedArray<Float32>(scalars: resultScalars, shape: shape)
291+
return MLShapedArray<Float>(unsafeUninitializedShape: shape) { result, _ in
292+
noise.withUnsafeShapedBufferPointer { scalars, _, strides in
293+
for i in 0 ..< result.count {
294+
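// unconditioned + guidance*(text - unconditioned), where scalars[i] is the unconditioned (batch 0) value and scalars[strides[0] + i] the text-conditioned (batch 1) value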
295+
result.initializeElement(
296+
at: i,
297+
to: scalars[i] + guidanceScale * (scalars[strides[0] + i] - scalars[i])
298+
)
299+
}
300+
}
301+
}
303302
}
304303

305304
func decodeToImages(_ latents: [MLShapedArray<Float32>], configuration config: Configuration) throws -> [CGImage?] {

0 commit comments

Comments
 (0)