Skip to content

Commit 30e237f

Browse files
Moelf authored and m-fila committed
diff implementation on :aarch64
1 parent 6a6a6c8 commit 30e237f

File tree

3 files changed

+48
-4
lines changed

3 files changed

+48
-4
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1212
LorentzVectorHEP = "f612022c-142a-473f-8cfd-a09cf3793c6c"
1313
LorentzVectors = "3f54b04b-17fc-5cd4-9758-90c048d965e3"
1414
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
15+
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
1516
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1617

1718
[weakdeps]
@@ -35,6 +36,7 @@ LorentzVectorHEP = "0.1.6"
3536
LorentzVectors = "0.4.3"
3637
Makie = "0.20, 0.21, 0.22"
3738
MuladdMacro = "0.2.4"
39+
SIMD = "3.7.1"
3840
StructArrays = "0.6.18, 0.7"
3941
Test = "1.9"
4042
julia = "1.9"

src/JetReconstruction.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ module JetReconstruction
1818
using LorentzVectorHEP
1919
using MuladdMacro
2020
using StructArrays
21+
using SIMD
2122

2223
# Import from LorentzVectorHEP methods for those 4-vector types
2324
# Forward `pt2` for `LorentzVector` arguments to the implementation provided by
# LorentzVectorHEP, so the same function name can be used inside this module.
function pt2(p::LorentzVector)
    return LorentzVectorHEP.pt2(p)
end

src/Utils.jl

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,49 @@ array. The use of `@turbo` macro gives a significant performance boost.
146146
- `dij_min`: The minimum value in the first `n` elements of the `dij` array.
147147
- `best`: The index of the minimum value in the `dij` array.
148148
"""
149-
fast_findmin(dij, n) = begin
150-
x = @fastmath foldl(min, @view(dij[begin:n]))
151-
i = findfirst(==(x), dij)::Int
152-
x, i
149+
function fast_findmin end

# Portable scalar search. This is the method used on :aarch64 for every input,
# and on all other architectures for any `dij` that is not a `DenseVector`
# (views, ranges, ...), so that such inputs keep working instead of raising a
# `MethodError`.
#
# Returns `(dij_min, best)`: the minimum of the first `n` elements and the
# index of its first occurrence. `n` must be at least 1 (reducing over an empty
# range throws).
function fast_findmin(dij, n)
    x = @fastmath foldl(min, @view(dij[begin:n]))
    # `x` was taken from the first `n` entries, so its first occurrence in
    # `dij` is guaranteed to lie within that prefix.
    i = findfirst(==(x), dij)::Int
    return x, i
end

if Sys.ARCH != :aarch64
    # Vectorised specialisation for dense, 1-based vectors: scans the first `n`
    # elements 8 lanes at a time, then finishes the `n % 8` tail with a scalar
    # loop.
    #
    # Returns `(dij_min, best)` where `best` is the index of the first
    # occurrence of the minimum, matching the scalar method above.
    # `n` must be at least 1; for `n == 0` the result `(typemax(T), 1)` is
    # meaningless.
    function fast_findmin(dij::DenseVector{T}, n) where {T}
        # The loops below run under @inbounds, so validate the range once.
        @boundscheck checkbounds(dij, 1:n)

        laneIndices = SIMD.Vec{8, Int}((1, 2, 3, 4, 5, 6, 7, 8))
        # typemax(T) is `Inf` for floating-point T and also works for integers.
        minvals = SIMD.Vec{8, T}(typemax(T))
        # Start from valid indices: if no element is ever strictly smaller than
        # the sentinel (e.g. every entry is Inf) the answer is still in range.
        min_indices = laneIndices

        n_batches, remainder = divrem(n, 8)
        lane = VecRange{8}(0)
        i = 1
        @inbounds @fastmath for _ in 1:n_batches
            dijs = dij[lane + i]
            # Strict `<` keeps the earliest index seen in each lane.
            predicate = dijs < minvals
            minvals = vifelse(predicate, dijs, minvals)
            min_indices = vifelse(predicate, laneIndices, min_indices)

            i += 8
            laneIndices += 8
        end

        min_value = SIMD.minimum(minvals)
        # Several lanes may hold the minimum; take the smallest index among
        # them so ties resolve to the first occurrence.
        is_min = minvals == SIMD.Vec{8, T}(min_value)
        no_index = SIMD.Vec{8, Int}(typemax(Int))
        min_index = SIMD.minimum(vifelse(is_min, min_indices, no_index))

        # Scalar tail; strict `<` again preserves the first occurrence.
        @inbounds @fastmath for _ in 1:remainder
            xi = dij[i]
            pred = xi < min_value
            min_value = ifelse(pred, xi, min_value)
            min_index = ifelse(pred, i, min_index)
            i += 1
        end
        return min_value, min_index
    end
end

0 commit comments

Comments
 (0)