Skip to content

Commit d50f9da

Browse files
committed
improve latency of matrix-exp
1 parent 4225f1b commit d50f9da

File tree

1 file changed

+47
-46
lines changed

1 file changed

+47
-46
lines changed

src/expm.jl

+47-46
Original file line numberDiff line numberDiff line change
@@ -77,53 +77,54 @@ function _exp(::Size, _A::StaticMatrix{<:Any,<:Any,T}) where T
7777
A = S.(_A)
7878
# omitted: matrix balancing, i.e., LAPACK.gebal!
7979
nA = maximum(sum(abs.(A); dims=Val(1))) # marginally more performant than norm(A, 1)
80-
## For sufficiently small nA, use lower order Padé-Approximations
81-
if (nA <= 2.1)
82-
A2 = A*A
83-
if nA > 0.95
84-
U = @evalpoly(A2, S(8821612800)*I, S(302702400)*I, S(2162160)*I, S(3960)*I, S(1)*I)
85-
U = A*U
86-
V = @evalpoly(A2, S(17643225600)*I, S(2075673600)*I, S(30270240)*I, S(110880)*I, S(90)*I)
87-
elseif nA > 0.25
88-
U = @evalpoly(A2, S(8648640)*I, S(277200)*I, S(1512)*I, S(1)*I)
89-
U = A*U
90-
V = @evalpoly(A2, S(17297280)*I, S(1995840)*I, S(25200)*I, S(56)*I)
91-
elseif nA > 0.015
92-
U = @evalpoly(A2, S(15120)*I, S(420)*I, S(1)*I)
93-
U = A*U
94-
V = @evalpoly(A2, S(30240)*I, S(3360)*I, S(30)*I)
95-
else
96-
U = @evalpoly(A2, S(60)*I, S(1)*I)
97-
U = A*U
98-
V = @evalpoly(A2, S(120)*I, S(12)*I)
99-
end
100-
expA = (V - U) \ (V + U)
80+
81+
if (nA 2.1) # for sufficiently small nA, use lower order Padé-Approximations
82+
return _pade_exp(S, A, nA)
10183
else
102-
s = log2(nA/5.4) # power of 2 later reversed by squaring
103-
if s > 0
104-
si = ceil(Int,s)
105-
A = A / S(2^si)
106-
end
107-
108-
A2 = A*A
109-
A4 = A2*A2
110-
A6 = A2*A4
111-
112-
U = A6*(S(1)*A6 + S(16380)*A4 + S(40840800)*A2) +
113-
(S(33522128640)*A6 + S(10559470521600)*A4 + S(1187353796428800)*A2) +
114-
S(32382376266240000)*I
115-
U = A*U
116-
V = A6*(S(182)*A6 + S(960960)*A4 + S(1323241920)*A2) +
117-
(S(670442572800)*A6 + S(129060195264000)*A4 + S(7771770303897600)*A2) +
118-
S(64764752532480000)*I
119-
expA = (V - U) \ (V + U)
120-
121-
if s > 0 # squaring to reverse dividing by power of 2
122-
for t=1:si
123-
expA = expA*expA
124-
end
125-
end
84+
return _rescaled_exp(S, A, nA)
12685
end
86+
end
12787

128-
expA
88+
function _pade_exp(S, A, nA)
89+
A2 = A*A
90+
U, V = if nA > 0.95
91+
@evalpoly(A2, S(8821612800)*I, S(302702400)*I, S(2162160)*I, S(3960)*I, S(1)*I),
92+
@evalpoly(A2, S(17643225600)*I, S(2075673600)*I, S(30270240)*I, S(110880)*I, S(90)*I)
93+
elseif nA > 0.25
94+
@evalpoly(A2, S(8648640)*I, S(277200)*I, S(1512)*I, S(1)*I),
95+
@evalpoly(A2, S(17297280)*I, S(1995840)*I, S(25200)*I, S(56)*I)
96+
elseif nA > 0.015
97+
@evalpoly(A2, S(15120)*I, S(420)*I, S(1)*I),
98+
@evalpoly(A2, S(30240)*I, S(3360)*I, S(30)*I)
99+
else
100+
@evalpoly(A2, S(60)*I, S(1)*I),
101+
@evalpoly(A2, S(120)*I, S(12)*I)
102+
end
103+
U = A*U
104+
return (V - U) \ (V + U)
129105
end
106+
107+
function _rescaled_exp(S, A, nA)
108+
si = ceil(Int, log2(nA/5.4)) # power of 2 later reversed by squaring
109+
if si > 0
110+
A /= S(2^si)
111+
end
112+
113+
A2 = A*A
114+
A4 = A2*A2
115+
A6 = A2*A4
116+
117+
U = A6*(S(1)*A6 + S(16380)*A4 + S(40840800)*A2) +
118+
(S(33522128640)*A6 + S(10559470521600)*A4 + S(1187353796428800)*A2) +
119+
S(32382376266240000)*I
120+
U = A*U
121+
V = A6*(S(182)*A6 + S(960960)*A4 + S(1323241920)*A2) +
122+
(S(670442572800)*A6 + S(129060195264000)*A4 + S(7771770303897600)*A2) +
123+
S(64764752532480000)*I
124+
expA = (V - U) \ (V + U)
125+
126+
for _ in 1:si # squaring to reverse dividing by power of 2
127+
expA *= expA
128+
end
129+
return expA
130+
end

0 commit comments

Comments
 (0)