From f4f9504abd01cdf42d296cf52fc14ca43d7945c5 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 25 Sep 2020 17:58:01 +0200 Subject: [PATCH 1/2] rearrange order of function definitions in matrix multiplication This should safeguard against generated functions being generated earlier than method definitions they need. --- src/matrix_multiply.jl | 118 ------------------------------------- src/matrix_multiply_add.jl | 116 ++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 118 deletions(-) diff --git a/src/matrix_multiply.jl b/src/matrix_multiply.jl index 767e5eca..8807bf7b 100644 --- a/src/matrix_multiply.jl +++ b/src/matrix_multiply.jl @@ -202,99 +202,6 @@ function mul_result_structure(::SDiagonal, ::SDiagonal) return Diagonal end -""" - uplo_access(sa, asym, k, j, uplo) - -Generate code for matrix element access, for a matrix of size `sa` locally referred to -as `asym` in the context where the result will be used. Both indices `k` and `j` need to be -statically known for this function to work. `uplo` is the access pattern mode generated -by the `gen_by_access` function. 
-""" -function uplo_access(sa, asym, k, j, uplo) - TAsym = Symbol("T"*string(asym)) - if uplo == :any - return :($asym[$(LinearIndices(sa)[k, j])]) - elseif uplo == :up - if k < j - return :($asym[$(LinearIndices(sa)[k, j])]) - elseif k == j - return :(LinearAlgebra.symmetric($asym[$(LinearIndices(sa)[k, j])], :U)) - else - return :(transpose($asym[$(LinearIndices(sa)[j, k])])) - end - elseif uplo == :lo - if k > j - return :($asym[$(LinearIndices(sa)[k, j])]) - elseif k == j - return :(LinearAlgebra.symmetric($asym[$(LinearIndices(sa)[k, j])], :L)) - else - return :(transpose($asym[$(LinearIndices(sa)[j, k])])) - end - elseif uplo == :up_herm - if k < j - return :($asym[$(LinearIndices(sa)[k, j])]) - elseif k == j - return :(LinearAlgebra.hermitian($asym[$(LinearIndices(sa)[k, j])], :U)) - else - return :(adjoint($asym[$(LinearIndices(sa)[j, k])])) - end - elseif uplo == :lo_herm - if k > j - return :($asym[$(LinearIndices(sa)[k, j])]) - elseif k == j - return :(LinearAlgebra.hermitian($asym[$(LinearIndices(sa)[k, j])], :L)) - else - return :(adjoint($asym[$(LinearIndices(sa)[j, k])])) - end - elseif uplo == :upper_triangular - if k <= j - return :($asym[$(LinearIndices(sa)[k, j])]) - else - return :(zero($TAsym)) - end - elseif uplo == :lower_triangular - if k >= j - return :($asym[$(LinearIndices(sa)[k, j])]) - else - return :(zero($TAsym)) - end - elseif uplo == :unit_upper_triangular - if k < j - return :($asym[$(LinearIndices(sa)[k, j])]) - elseif k == j - return :(oneunit($TAsym)) - else - return :(zero($TAsym)) - end - elseif uplo == :unit_lower_triangular - if k > j - return :($asym[$(LinearIndices(sa)[k, j])]) - elseif k == j - return :(oneunit($TAsym)) - else - return :(zero($TAsym)) - end - elseif uplo == :upper_hessenberg - if k <= j+1 - return :($asym[$(LinearIndices(sa)[k, j])]) - else - return :(zero($TAsym)) - end - elseif uplo == :transpose - return :(transpose($asym[$(LinearIndices(reverse(sa))[j, k])])) - elseif uplo == :adjoint - return 
:(adjoint($asym[$(LinearIndices(reverse(sa))[j, k])])) - elseif uplo == :diagonal - if k == j - return :($asym[$k]) - else - return :(zero($TAsym)) - end - else - error("Unknown uplo: $uplo") - end -end - # Implementations function mul_smat_vec_exprs(sa, access_a) @@ -369,31 +276,6 @@ for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTria @eval _unstatic_array(::Type{$TWR{T,TSA}}) where {S, T, N, TSA<:StaticArray{S,T,N}} = $TWR{T,<:AbstractArray{T,N}} end -function combine_products(expr_list) - filtered = filter(expr_list) do expr - if expr.head != :call || expr.args[1] != :* - error("expected call to *") - end - for arg in expr.args[2:end] - if isa(arg, Expr) && arg.head == :call && arg.args[1] == :zero - return false - end - end - return true - end - if isempty(filtered) - return :(zero(T)) - else - return reduce(filtered) do ex1, ex2 - if ex2.head != :call || ex2.args[1] != :* - error("expected call to *") - end - - return :(muladd($(ex2.args[2]), $(ex2.args[3]), $ex1)) - end - end -end - @generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} S = Size(sa[1], sb[2]) # Heuristic choice for amount of codegen diff --git a/src/matrix_multiply_add.jl b/src/matrix_multiply_add.jl index e3eac95e..a33d8ba9 100644 --- a/src/matrix_multiply_add.jl +++ b/src/matrix_multiply_add.jl @@ -91,6 +91,99 @@ function check_dims(::Size{sc}, ::Size{sa}, ::Size{sb}) where {sa,sb,sc} return true end +""" + uplo_access(sa, asym, k, j, uplo) + +Generate code for matrix element access, for a matrix of size `sa` locally referred to +as `asym` in the context where the result will be used. Both indices `k` and `j` need to be +statically known for this function to work. `uplo` is the access pattern mode generated +by the `gen_by_access` function. 
+""" +function uplo_access(sa, asym, k, j, uplo) + TAsym = Symbol("T"*string(asym)) + if uplo == :any + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif uplo == :up + if k < j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(LinearAlgebra.symmetric($asym[$(LinearIndices(sa)[k, j])], :U)) + else + return :(transpose($asym[$(LinearIndices(sa)[j, k])])) + end + elseif uplo == :lo + if k > j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(LinearAlgebra.symmetric($asym[$(LinearIndices(sa)[k, j])], :L)) + else + return :(transpose($asym[$(LinearIndices(sa)[j, k])])) + end + elseif uplo == :up_herm + if k < j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(LinearAlgebra.hermitian($asym[$(LinearIndices(sa)[k, j])], :U)) + else + return :(adjoint($asym[$(LinearIndices(sa)[j, k])])) + end + elseif uplo == :lo_herm + if k > j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(LinearAlgebra.hermitian($asym[$(LinearIndices(sa)[k, j])], :L)) + else + return :(adjoint($asym[$(LinearIndices(sa)[j, k])])) + end + elseif uplo == :upper_triangular + if k <= j + return :($asym[$(LinearIndices(sa)[k, j])]) + else + return :(zero($TAsym)) + end + elseif uplo == :lower_triangular + if k >= j + return :($asym[$(LinearIndices(sa)[k, j])]) + else + return :(zero($TAsym)) + end + elseif uplo == :unit_upper_triangular + if k < j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(oneunit($TAsym)) + else + return :(zero($TAsym)) + end + elseif uplo == :unit_lower_triangular + if k > j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(oneunit($TAsym)) + else + return :(zero($TAsym)) + end + elseif uplo == :upper_hessenberg + if k <= j+1 + return :($asym[$(LinearIndices(sa)[k, j])]) + else + return :(zero($TAsym)) + end + elseif uplo == :transpose + return :(transpose($asym[$(LinearIndices(reverse(sa))[j, k])])) + elseif uplo == :adjoint + return 
:(adjoint($asym[$(LinearIndices(reverse(sa))[j, k])])) + elseif uplo == :diagonal + if k == j + return :($asym[$k]) + else + return :(zero($TAsym)) + end + else + error("Unknown uplo: $uplo") + end +end + """ Combine left and right sides of an assignment expression, short-cutting lhs = α * rhs + β * lhs, element-wise. @@ -123,7 +216,30 @@ function _lind(var::Symbol, A::Type{TSize{sa,tA}}, k::Int, j::Int) where {sa,tA} return ula end +function combine_products(expr_list) + filtered = filter(expr_list) do expr + if expr.head != :call || expr.args[1] != :* + error("expected call to *") + end + for arg in expr.args[2:end] + if isa(arg, Expr) && arg.head == :call && arg.args[1] == :zero + return false + end + end + return true + end + if isempty(filtered) + return :(zero(T)) + else + return reduce(filtered) do ex1, ex2 + if ex2.head != :call || ex2.args[1] != :* + error("expected call to *") + end + return :(muladd($(ex2.args[2]), $(ex2.args[3]), $ex1)) + end + end +end # Matrix-vector multiplication @generated function _mul!(Sc::TSize{sc}, c::StaticVecOrMatLike, Sa::Size{sa}, Sb::Size{sb}, From 7df98ef566f955404ef4d0e804c0f4403d04db89 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Mon, 28 Sep 2020 13:13:39 +0200 Subject: [PATCH 2/2] moving gen_by_access as well --- src/matrix_multiply.jl | 144 ------------------------------------- src/matrix_multiply_add.jl | 144 +++++++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 144 deletions(-) diff --git a/src/matrix_multiply.jl b/src/matrix_multiply.jl index 8807bf7b..2b08ca4a 100644 --- a/src/matrix_multiply.jl +++ b/src/matrix_multiply.jl @@ -15,150 +15,6 @@ import LinearAlgebra: BlasFloat, matprod, mul! 
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N} = vec(A) * B @inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B -""" - gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :wrapped_a) - -Statically generate outer code for fully unrolled multiplication loops. -Returned code does wrapper-specific tests (for example if a symmetric matrix view is -`U` or `L`) and the body of the if expression is then generated by function `expr_gen`. -The function `expr_gen` receives access pattern description symbol as its argument -and this symbol is then consumed by uplo_access to generate the right code for matrix -element access. - -The name of the matrix to test is indicated by `asym`. -""" -function gen_by_access(expr_gen, a::Type{<:StaticVecOrMat}, asym = :wrapped_a) - return expr_gen(:any) -end -function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, asym = :wrapped_a) - return quote - if $(asym).uplo == 'U' - $(expr_gen(:up)) - else - $(expr_gen(:lo)) - end - end -end -function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, asym = :wrapped_a) - return quote - if $(asym).uplo == 'U' - $(expr_gen(:up_herm)) - else - $(expr_gen(:lo_herm)) - end - end -end -function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) - return expr_gen(:upper_triangular) -end -function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) - return expr_gen(:lower_triangular) -end -function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) - return expr_gen(:unit_upper_triangular) -end -function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) - return expr_gen(:unit_lower_triangular) -end -function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticVecOrMat}}, asym = 
:wrapped_a) - return expr_gen(:transpose) -end -function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticVecOrMat}}, asym = :wrapped_a) - return expr_gen(:adjoint) -end -function gen_by_access(expr_gen, a::Type{<:SDiagonal}, asym = :wrapped_a) - return expr_gen(:diagonal) -end -""" - gen_by_access(expr_gen, a::Type{<:AbstractArray}, b::Type{<:AbstractArray}) - -Simiar to gen_by_access with only one type argument. The difference is that tests for both -arrays of type `a` and `b` are generated and `expr_gen` receives two access arguments, -first for matrix `a` and the second for matrix `b`. -""" -function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type) - return quote - return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:any, access_b) - end) - end -end -function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, b::Type) - return quote - if wrapped_a.uplo == 'U' - return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:up, access_b) - end) - else - return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:lo, access_b) - end) - end - end -end -function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, b::Type) - return quote - if wrapped_a.uplo == 'U' - return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:up_herm, access_b) - end) - else - return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:lo_herm, access_b) - end) - end - end -end -function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, b::Type) - return quote - return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:upper_triangular, access_b) - end) - end -end -function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, b::Type) - return quote - return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:lower_triangular, access_b) - end) - end -end -function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, b::Type) - return quote 
- return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:unit_upper_triangular, access_b) - end) - end -end -function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, b::Type) - return quote - return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:unit_lower_triangular, access_b) - end) - end -end -function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticMatrix}}, b::Type) - return quote - return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:transpose, access_b) - end) - end -end -function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, b::Type) - return quote - return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:adjoint, access_b) - end) - end -end -function gen_by_access(expr_gen, a::Type{<:SDiagonal}, b::Type) - return quote - return $(gen_by_access(b, :wrapped_b) do access_b - expr_gen(:diagonal, access_b) - end) - end -end - """ mul_result_structure(a::Type, b::Type) diff --git a/src/matrix_multiply_add.jl b/src/matrix_multiply_add.jl index a33d8ba9..b170185a 100644 --- a/src/matrix_multiply_add.jl +++ b/src/matrix_multiply_add.jl @@ -30,6 +30,150 @@ const StaticMatMulLike{s1, s2, T} = Union{ Transpose{T, <:StaticMatrix{s1, s2, T}}, SDiagonal{s1, T}} +""" + gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :wrapped_a) + +Statically generate outer code for fully unrolled multiplication loops. +Returned code does wrapper-specific tests (for example if a symmetric matrix view is +`U` or `L`) and the body of the if expression is then generated by function `expr_gen`. +The function `expr_gen` receives access pattern description symbol as its argument +and this symbol is then consumed by uplo_access to generate the right code for matrix +element access. + +The name of the matrix to test is indicated by `asym`. 
+""" +function gen_by_access(expr_gen, a::Type{<:StaticVecOrMat}, asym = :wrapped_a) + return expr_gen(:any) +end +function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return quote + if $(asym).uplo == 'U' + $(expr_gen(:up)) + else + $(expr_gen(:lo)) + end + end +end +function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return quote + if $(asym).uplo == 'U' + $(expr_gen(:up_herm)) + else + $(expr_gen(:lo_herm)) + end + end +end +function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return expr_gen(:upper_triangular) +end +function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return expr_gen(:lower_triangular) +end +function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return expr_gen(:unit_upper_triangular) +end +function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return expr_gen(:unit_lower_triangular) +end +function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticVecOrMat}}, asym = :wrapped_a) + return expr_gen(:transpose) +end +function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticVecOrMat}}, asym = :wrapped_a) + return expr_gen(:adjoint) +end +function gen_by_access(expr_gen, a::Type{<:SDiagonal}, asym = :wrapped_a) + return expr_gen(:diagonal) +end +""" + gen_by_access(expr_gen, a::Type{<:AbstractArray}, b::Type{<:AbstractArray}) + +Similar to gen_by_access with only one type argument. The difference is that tests for both +arrays of type `a` and `b` are generated and `expr_gen` receives two access arguments, +first for matrix `a` and the second for matrix `b`. 
+""" +function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:any, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, b::Type) + return quote + if wrapped_a.uplo == 'U' + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:up, access_b) + end) + else + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:lo, access_b) + end) + end + end +end +function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, b::Type) + return quote + if wrapped_a.uplo == 'U' + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:up_herm, access_b) + end) + else + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:lo_herm, access_b) + end) + end + end +end +function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:upper_triangular, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:lower_triangular, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:unit_upper_triangular, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:unit_lower_triangular, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:transpose, access_b) + end) + end +end +function gen_by_access(expr_gen, 
a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:adjoint, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:SDiagonal}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:diagonal, access_b) + end) + end +end + """ Size that stores whether a Matrix is a Transpose Useful when selecting multiplication methods, and avoiding allocations when dealing with