MatrixSign.jl is a Julia package for computing the matrix sign, or more generally, the polar factor of a matrix, which is the orthogonal component of the polar decomposition. Three methods are provided:
- SVD: computes the exact polar factor $UV^T$ from the SVD $X = U\Sigma V^T$.
- Polar Express[^1]: a variable-step Newton-Schulz polynomial iteration with optimized coefficients that computes the polar factor more accurately.
- Keller Jordan's 5-step[^2]: a 5-step quintic Newton-Schulz polynomial iteration that computes a noisy approximation of the polar factor.
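For reference, the SVD route follows directly from the decomposition: if $X = U\Sigma V^T$, the polar factor is $UV^T$. A minimal sketch using only the LinearAlgebra standard library (the `polar_svd` name is hypothetical, not part of the package):

```julia
using LinearAlgebra

# Exact polar factor via SVD: X = UΣVᵀ  ⇒  polar factor = UVᵀ.
# Minimal illustration; the package's SVDMethod is the optimized equivalent.
function polar_svd(X::AbstractMatrix)
    F = svd(X)
    return F.U * F.Vt
end

X = randn(256, 256)
Q = polar_svd(X)
Q' * Q ≈ I  # the polar factor is orthogonal
```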
- Newton-Schulz-based methods rely entirely on in-place generic matrix multiplication, minimizing memory usage and making them fast on GPUs.
- Differentiable via custom chain rules.
- Fixed memory w.r.t. the number of steps (except for checkpoints in chain rule).
- Fused polynomial iteration to optimize performance for rectangular matrices (see Algorithm 4 in Polar Express[^1]) with only slightly increased memory usage.
- Batched matrix sign for inputs with more than 2 dimensions.
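The batched behavior can be sketched as applying the polar factor to each trailing-dimension slice independently (a sketch of the presumed semantics; an SVD-based polar factor stands in for `msign` so the snippet is self-contained):

```julia
using LinearAlgebra

# Presumed batched semantics: each 2-D slice X[:, :, k] of a
# 3-D input is treated as an independent matrix.
polar(A) = (F = svd(A); F.U * F.Vt)   # stand-in for msign

X = randn(Float32, 16, 16, 4)
S = similar(X)
for k in axes(X, 3)
    S[:, :, k] = polar(X[:, :, k])
end

all(S[:, :, k]' * S[:, :, k] ≈ I for k in axes(X, 3))
```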
The main interface is the `msign` function, which dispatches to the `PolarExpress` method by default:
```julia
julia> using MatrixSign

julia> X = randn(Float64, 1024, 1024);

julia> msign(X, steps=16) ≈ msign(X, SVDMethod)
true

julia> msign(Float32.(X), steps=14) ≈ msign(Float32.(X), SVDMethod)
true

julia> @b msign($(Float32.(X)), SVDMethod)
186.368 ms (21 allocs: 28.094 MiB, 2.29% gc time)

julia> @b msign($(Float32.(X)), steps=14)
65.203 ms (20 allocs: 20.001 MiB, 0.96% gc time)
```

For Newton-Schulz-based methods, `Float16` matrices would underflow, so they are converted to `BFloat16` by default, which may only perform well on hardware with native `BFloat16` support.
```julia
julia> using CUDA, BFloat16s, PrettyChairmarks

julia> X = CUDA.randn(BFloat16, 1024, 4096); # 1x4 aspect ratio

julia> @b CUDA.@sync msign(X, SVDMethod)
67.891 ms (713 allocs: 32.013 MiB, 1.19% gc time)

julia> CUDA.@allocated msign(X, SVDMethod)
71307272

julia> @b CUDA.@sync msign(X, PolarExpress, steps=8)
979.261 μs (2205 allocs: 59.531 KiB)

julia> @b CUDA.@sync msign(X, PolarExpress, steps=8, fused=3)
863.068 μs (2571 allocs: 77.516 KiB)

julia> @b CUDA.@sync msign(X, PolarExpress, steps=6)
771.161 μs (1759 allocs: 47.031 KiB)

julia> @b CUDA.@sync msign(X, PolarExpress, steps=6, fused=3)
684.425 μs (2024 allocs: 61.078 KiB)

julia> CUDA.@allocated msign(X, PolarExpress, steps=8)
29360890

julia> CUDA.@allocated msign(X, PolarExpress, steps=8, fused=3)
33555234
```

Note that the allocation counts and memory usage shown next to the benchmark times in the `@b` calls are CPU-only; GPU memory usage is shown by the `CUDA.@allocated` calls.
- Use `PolarExpress` for most cases.
  - In many cases, particularly with `randn`-like inputs, the majority of singular values get close to 1 in just ~5-7 steps. A few extra steps may sometimes be needed to bring up the smallest singular values, but this is not always important.
  - Take advantage of the `fused` keyword to optimize performance for rectangular matrices with aspect ratios greater than 2.5.
  - For smaller data types like `BFloat16`, `fused` beyond 3 can lead to error accumulation.
- Use `JordanMethod` for parity with the original Muon[^2].
- Use `SVDMethod` for robust, but slow, comparison.
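For intuition on why ~5-7 steps often suffice, here is the classic cubic Newton-Schulz iteration (a minimal fixed-coefficient sketch, not the package's optimized variable-step schedule): small singular values grow by roughly 1.5x per step, then converge quadratically once near 1.

```julia
using LinearAlgebra, Random

Random.seed!(0)  # deterministic example

# Classic cubic Newton-Schulz: Y ← 1.5Y - 0.5·Y(YᵀY).
# Each step pushes every singular value of Y toward 1.
function newton_schulz(X; steps=10)
    Y = X / opnorm(X)              # normalize so all singular values ≤ 1
    for _ in 1:steps
        Y = 1.5Y - 0.5 * (Y * (Y' * Y))
    end
    return Y
end

X = randn(128, 128)
Y = newton_schulz(X; steps=30)
opnorm(Y' * Y - I) < 1e-6          # Y is numerically orthogonal
```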
- Complex matrices are not supported.