diff --git a/docs/src/misc.md b/docs/src/misc.md index 235997724a..3bb961a261 100644 --- a/docs/src/misc.md +++ b/docs/src/misc.md @@ -202,7 +202,7 @@ The available algorithms/normal forms are To select a normal form type for rings of type `NewRing`, implement the function ```julia -Solve.matrix_normal_form_type(::NewRing) = Bla() +Solve.matrix_normal_form_type(::Type{<:NewRing}) = Bla() ``` where `Bla <: MatrixNormalFormTrait`. A new type trait can be added via @@ -262,8 +262,8 @@ object. First of all, one needs to implement the function ```julia -function Solve.solve_context_type(::NewRing) - return Solve.solve_context_type(::NormalFormTrait, elem_type(NewRing)) +function Solve.solve_context_type(T::Type{<:NewRing}) + return Solve.solve_context_type(::NormalFormTrait, elem_type(T)) end ``` diff --git a/src/Solve.jl b/src/Solve.jl index de9147fe85..b3c206f5fa 100644 --- a/src/Solve.jl +++ b/src/Solve.jl @@ -88,22 +88,25 @@ struct LUTrait <: MatrixNormalFormTrait end # LU factoring for fields struct FFLUTrait <: MatrixNormalFormTrait end # "fraction free" LU factoring for fraction fields struct MatrixInterpolateTrait <: MatrixNormalFormTrait end # interpolate in fraction fields of polynomial rings -function matrix_normal_form_type(R::Ring) - if is_domain_type(elem_type(R)) +function matrix_normal_form_type(T::Type{<:Ring}) + if is_domain_type(T) return HermiteFormTrait() else return HowellFormTrait() end end -matrix_normal_form_type(::Field) = RREFTrait() +matrix_normal_form_type(::Type{<:Field}) = RREFTrait() # The fflu approach is the fastest over a fraction field (see benchmarks on PR 661) -matrix_normal_form_type(::FracField) = FFLUTrait() -matrix_normal_form_type(::AbstractAlgebra.Rationals{BigInt}) = FFLUTrait() -matrix_normal_form_type(::FracField{T}) where {T <: PolyRingElem} = MatrixInterpolateTrait() +matrix_normal_form_type(::Type{<:FracField}) = FFLUTrait() +matrix_normal_form_type(::Type{<:AbstractAlgebra.Rationals{BigInt}}) = FFLUTrait() +matrix_normal_form_type(::Type{<:FracField{T}}) where {T <: PolyRingElem} = MatrixInterpolateTrait() -matrix_normal_form_type(A::MatElem) = matrix_normal_form_type(base_ring(A)) +matrix_normal_form_type(T::Type{<:MatElem}) = matrix_normal_form_type(base_ring_type(T)) + +matrix_normal_form_type(x) = matrix_normal_form_type(typeof(x)) +matrix_normal_form_type(T::DataType) = throw(MethodError(matrix_normal_form_type, (T,))) ################################################################################ # @@ -207,29 +210,29 @@ function solve_init(NF::MatrixNormalFormTrait, A::MatElem) return solve_context_type(NF, base_ring(A))(A) end -# For a ring R, the following signatures of `solve_context_type` need to be +# For a ring R of type T, the following signatures of `solve_context_type` need to be # implemented: -# 1) solve_context_type(R) -# 2) solve_context_type(::MatrixNormalFormTrait, elem_type(R)) +# 1) solve_context_type(::Type{<:T}) +# 2) solve_context_type(::MatrixNormalFormTrait, elem_type(T)) # Version 1 should pick a matrix_normal_form_type and call 2 -function solve_context_type(R::NCRing) - return solve_context_type(matrix_normal_form_type(R), elem_type(R)) +function solve_context_type(T::Type{<:NCRing}) + return solve_context_type(matrix_normal_form_type(T), elem_type(T)) end -function solve_context_type(K::Field) - # matrix_normal_form_type(K) would be RREFTrait, but we want to use +function solve_context_type(T::Type{<:Field}) + # matrix_normal_form_type(T) would be RREFTrait, but we want to use # LU in solve contexts - return solve_context_type(LUTrait(), elem_type(K)) + return solve_context_type(LUTrait(), elem_type(T)) end -function solve_context_type(K::Union{AbstractAlgebra.Rationals{BigInt}, FracField}) +function solve_context_type(T::Type{<:Union{AbstractAlgebra.Rationals{BigInt}, FracField}}) # In this case, we use FFLU - return solve_context_type(FFLUTrait(), elem_type(K)) + return solve_context_type(FFLUTrait(), elem_type(T)) end -function solve_context_type(A::MatElem) - return solve_context_type(base_ring(A)) +function solve_context_type(T::Type{<:MatElem}) + return solve_context_type(base_ring_type(T)) end function solve_context_type(NF::MatrixNormalFormTrait, ::Type{T}) where {T <: NCRingElement} @@ -237,6 +240,10 @@ function solve_context_type(NF::MatrixNormalFormTrait, ::Type{T}) where {T <: NC return SolveCtx{T, typeof(NF), MatType, MatType, LazyTransposeMatElem{T, MatType}} end +function solve_context_type(NF::MatrixNormalFormTrait, T::Type{<:MatElem}) + return solve_context_type(NF, base_ring_type(T)) +end + function solve_context_type(::FFLUTrait, ::Type{T}) where {T <: NCRingElement} # We assume that the ring in question is a fraction field and have to get the # type of "integral" matrices, that is, matrices over the base ring of this @@ -246,13 +253,15 @@ function solve_context_type(::FFLUTrait, ::Type{T}) where {T <: NCRingElement} return SolveCtx{T, FFLUTrait, dense_matrix_type(T), IntMatT, IntMatT} end -solve_context_type(NF::MatrixNormalFormTrait, ::T) where {T <: NCRingElement} = solve_context_type(NF, T) solve_context_type(NF::MatrixNormalFormTrait, ::Type{T}) where {T <: NCRing} = solve_context_type(NF, elem_type(T)) -solve_context_type(NF::MatrixNormalFormTrait, ::T) where {T <: NCRing} = solve_context_type(NF, elem_type(T)) -solve_context_type(NF::MatrixNormalFormTrait, ::Type{<: MatElem{T}}) where T = solve_context_type(NF, T) -solve_context_type(NF::MatrixNormalFormTrait, ::MatElem{T}) where T = solve_context_type(NF, T) -matrix_normal_form_type(C::SolveCtx{T, NF}) where {T, NF} = NF() +solve_context_type(x) = solve_context_type(typeof(x)) +solve_context_type(NF::MatrixNormalFormTrait, x) = solve_context_type(NF, typeof(x)) +solve_context_type(T::DataType) = throw(MethodError(solve_context_type, (T,))) +solve_context_type(NF::MatrixNormalFormTrait, T::DataType) = throw(MethodError(solve_context_type, (NF, T))) + + +matrix_normal_form_type(::Type{<:SolveCtx{T, NF}}) where {T, NF} = NF() matrix(C::SolveCtx) = C.A diff --git a/test/Solve-test.jl b/test/Solve-test.jl index 849d3d37f7..93446e92e2 100644 --- a/test/Solve-test.jl +++ b/test/Solve-test.jl @@ -103,13 +103,17 @@ end if is_default C = solve_init(M) + @test AbstractAlgebra.Solve.matrix_normal_form_type(typeof(C)) === NFTrait() @test AbstractAlgebra.Solve.matrix_normal_form_type(C) === NFTrait() + @test C isa AbstractAlgebra.solve_context_type(typeof(R)) @test C isa AbstractAlgebra.solve_context_type(R) + @test C isa AbstractAlgebra.solve_context_type(typeof(M)) @test C isa AbstractAlgebra.solve_context_type(M) end C = solve_init(NFTrait(), M) + @test AbstractAlgebra.Solve.matrix_normal_form_type(typeof(C)) === NFTrait() @test AbstractAlgebra.Solve.matrix_normal_form_type(C) === NFTrait() @test C isa AbstractAlgebra.solve_context_type(NFTrait(), elem_type(R)) @test C isa AbstractAlgebra.solve_context_type(NFTrait(), R(1))