diff --git a/Sources/SwiftFusion/Core/TypeKeyedArrayBuffers.swift b/Sources/SwiftFusion/Core/TypeKeyedArrayBuffers.swift index 5d0a4a9a..a65be30f 100644 --- a/Sources/SwiftFusion/Core/TypeKeyedArrayBuffers.swift +++ b/Sources/SwiftFusion/Core/TypeKeyedArrayBuffers.swift @@ -87,6 +87,21 @@ extension TypeKeyedArrayBuffers { return _storage.allSatisfy { kv in other._storage[kv.key]?.count == kv.value.count } } + /// Returns a mapping from each key `k` of `self` into `bufferTransform(self[k])`. + public func mapBuffers( + _ bufferTransform: (AnyArrayBuffer) throws -> AnyArrayBuffer + ) rethrows -> TypeKeyedArrayBuffers { + try .init(_storage: _storage.mapValues(bufferTransform)) + } + + /// Returns a mapping from each key `k` of `self` into the corresponding array, transformed by + /// `bufferTransform`. + public func compactMapBuffers( + _ bufferTransform: (AnyArrayBuffer) throws -> AnyArrayBuffer? + ) rethrows -> TypeKeyedArrayBuffers { + try .init(_storage: _storage.compactMapValues(bufferTransform)) + } + /// Returns a mapping from each key `k` of `self` into `bufferTransform(self[k])`. public func mapBuffers( _ bufferTransform: (AnyArrayBuffer) throws -> AnyArrayBuffer @@ -102,6 +117,27 @@ extension TypeKeyedArrayBuffers { try .init(_storage: _storage.compactMapValues(bufferTransform)) } + /// Returns the first key in `self` such that `predicate(self[k], other[k]) == true`, or `nil` if + /// no such key exists. + /// + /// - Requires: `self.hasSameStructure(as: parameter)` + public func firstBufferKey( + homomorphicArgument other: TypeKeyedArrayBuffers, + where predicate: ( + _ myBuffer: AnyArrayBuffer, + _ otherBuffer: AnyArrayBuffer) throws -> Bool + ) rethrows -> TypeID? { + precondition( + _storage.count == other._storage.count, + "parameter must have same structure as `self`") + return try _storage.first { kv in + guard let v1 = other._storage[kv.key] else { + fatalError("parameter must have same structure as `self`") + } + return try predicate(kv.value, v1) + }.map { $0.key } + } + /// Invokes `update` on each buffer of self, passing the buffer having the same `key` in /// `parameter` as a second argument. /// diff --git a/Sources/SwiftFusion/Inference/TypeKeyedArrayBuffers+Vector.swift b/Sources/SwiftFusion/Inference/TypeKeyedArrayBuffers+Vector.swift new file mode 100644 index 00000000..4528dd2a --- /dev/null +++ b/Sources/SwiftFusion/Inference/TypeKeyedArrayBuffers+Vector.swift @@ -0,0 +1,57 @@ +// Copyright 2020 The SwiftFusion Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import _Differentiation +import PenguinStructures + +extension TypeKeyedArrayBuffers: Equatable where ElementAPI: VectorArrayDispatch { + public static func == (lhs: Self, rhs: Self) -> Bool { + lhs.firstBufferKey(homomorphicArgument: rhs) { $0 != $1 } == nil + } +} + +extension TypeKeyedArrayBuffers: AdditiveArithmetic where ElementAPI == VectorArrayDispatch { + /// Returns the vector sum of `lhs` with `rhs`, where `lhs` and `rhs` are viewed as vectors in the + /// vector space direct sum of all the variables. + /// + /// Precondition: `lhs` and `rhs` have assignments for exactly the same sets of variables. + public static func + (lhs: Self, rhs: Self) -> Self { + lhs.updatedBuffers(homomorphicArgument: rhs) { $0 + $1 } + } + + public static func += (lhs: inout Self, rhs: Self) { + lhs.updateBuffers(homomorphicArgument: rhs) { $0 += $1 } + } + + public static func - (lhs: Self, rhs: Self) -> Self { + lhs.updatedBuffers(homomorphicArgument: rhs) { $0 - $1 } + } + + public static func -= (lhs: inout Self, rhs: Self) { + lhs.updateBuffers(homomorphicArgument: rhs) { $0 -= $1 } + } + + public static var zero: Self { .init() } +} + +extension TypeKeyedArrayBuffers: Differentiable where ElementAPI: DifferentiableArrayDispatch { + public typealias TangentVector = MappedArrayBuffers + public mutating func move(along offset: TangentVector) { + updateBuffers(homomorphicArgument: offset) { $0.move(along: $1) } + } + + public var zeroTangentVectorInitializer: () -> TangentVector { + { .init() } + } +}