diff --git a/src/implementations/BigInt.jl b/src/implementations/BigInt.jl index b2cfbb7..83d3d6f 100644 --- a/src/implementations/BigInt.jl +++ b/src/implementations/BigInt.jl @@ -27,6 +27,14 @@ function operate_to!(output::BigInt, ::typeof(+), a::BigInt, b::BigInt) return Base.GMP.MPZ.add!(output, a, b) end +function operate_to!(output::BigInt, ::typeof(copy), a::BigInt) + return Base.GMP.MPZ.set!(output, a) +end + +function operate_to!(output::BigInt, ::typeof(copy), a::Int) + return Base.GMP.MPZ.set_si!(output, a) +end + # - promote_operation(::typeof(-), ::Vararg{Type{BigInt},N}) where {N} = BigInt @@ -35,6 +43,10 @@ function operate_to!(output::BigInt, ::typeof(-), a::BigInt, b::BigInt) return Base.GMP.MPZ.sub!(output, a, b) end +function operate_to!(output::BigInt, ::typeof(-), a::BigInt) + return Base.GMP.MPZ.neg!(output, a) +end + # * promote_operation(::typeof(*), ::Vararg{Type{BigInt},N}) where {N} = BigInt diff --git a/src/implementations/Rational.jl b/src/implementations/Rational.jl index 8bdd6d5..0e0aa44 100644 --- a/src/implementations/Rational.jl +++ b/src/implementations/Rational.jl @@ -30,13 +30,29 @@ end # + function promote_operation( - ::typeof(+), + ::Union{typeof(+),typeof(-)}, ::Type{Rational{S}}, ::Type{Rational{T}}, ) where {S,T} return Rational{promote_sum_mul(S, T)} end +function promote_operation( + op::Union{typeof(+),typeof(-)}, + ::Type{Rational{S}}, + ::Type{I}, +) where {S,I<:Integer} + return promote_operation(op, Rational{S}, Rational{I}) +end + +function promote_operation( + op::Union{typeof(+),typeof(-)}, + ::Type{I}, + ::Type{Rational{S}}, +) where {S,I<:Integer} + return promote_operation(op, Rational{S}, Rational{I}) +end + function operate_to!(output::Rational, ::typeof(+), x::Rational, y::Rational) xd, yd = Base.divgcd(promote(x.den, y.den)...) # TODO: Use `checked_mul` and `checked_add` like in Base @@ -46,16 +62,28 @@ function operate_to!(output::Rational, ::typeof(+), x::Rational, y::Rational) return output end -# - +function operate_to!(output::Rational, ::typeof(+), x::Rational, y::Integer) + # TODO Use `checked_mul` and `checked_add` like in Base + operate_to!(output.num, *, x.den, y) + operate!(+, output.num, x.num) + operate_to!(output.den, *, x.den, oftype(x.den, 1)) + return output +end -function promote_operation( - ::typeof(-), - ::Type{Rational{S}}, - ::Type{Rational{T}}, -) where {S,T} - return Rational{promote_sum_mul(S, T)} +function operate_to!(output::Rational, ::typeof(+), y::Integer, x::Rational) + return operate_to!(output, +, x, y) +end + +# unary - + +function operate_to!(output::Rational, ::typeof(-), x::Rational) + operate_to!(output.num, -, x.num) + operate_to!(output.den, copy, x.den) + return output end +# binary - + function operate_to!(output::Rational, ::typeof(-), x::Rational, y::Rational) xd, yd = Base.divgcd(promote(x.den, y.den)...) # TODO: Use `checked_mul` and `checked_sub` like in Base @@ -65,6 +93,22 @@ function operate_to!(output::Rational, ::typeof(-), x::Rational, y::Rational) return output end +function operate_to!(output::Rational, ::typeof(-), x::Rational, y::Integer) + # TODO Use `checked_mul` and `checked_sub` like in Base + operate_to!(output.num, *, x.den, y) + operate!(-, output.num) + operate!(+, output.num, x.num) + operate_to!(output.den, copy, x.den) + return output +end + +function operate_to!(output::Rational, ::typeof(-), y::Integer, x::Rational) + # TODO Use `checked_mul` and `checked_sub` like in Base + operate_to!(output, -, x, y) + operate_to!(output, -, output) + return output +end + # * function promote_operation( @@ -75,6 +119,22 @@ function promote_operation( return Rational{promote_operation(*, S, T)} end +function promote_operation( + ::typeof(*), + ::Type{Rational{S}}, + ::Type{I}, +) where {S,I<:Integer} + return promote_operation(*, Rational{S}, Rational{I}) +end + +function promote_operation( + ::typeof(*), + ::Type{I}, + ::Type{Rational{S}}, +) where {S,I<:Integer} + return promote_operation(*, Rational{S}, Rational{I}) +end + function operate_to!(output::Rational, ::typeof(*), x::Rational, y::Rational) xn, yd = Base.divgcd(promote(x.num, y.den)...) xd, yn = Base.divgcd(promote(x.den, y.num)...) @@ -83,6 +143,69 @@ function operate_to!(output::Rational, ::typeof(*), x::Rational, y::Rational) return output end +function operate_to!(output::Rational, ::typeof(*), x::Rational, y::Integer) + xn = x.num + xd, yn = Base.divgcd(promote(x.den, y)...) + operate_to!(output.num, *, xn, yn) + operate_to!(output.den, copy, x.den) + return output +end + +function operate_to!(output::Rational, ::typeof(*), y::Integer, x::Rational) + return operate_to!(output, *, x, y) +end + +# // + +function operate_to!( + output::Rational, + op::Union{typeof(/),typeof(//)}, + x::Rational, + y::Rational, +) + xn, yn = Base.divgcd(promote(x.num, y.num)...) + xd, yd = Base.divgcd(promote(x.den, y.den)...) + operate_to!(output.num, *, xn, yd) + operate_to!(output.den, *, xd, yn) + return output +end + +function operate_to!( + output::Rational, + op::Union{typeof(/),typeof(//)}, + x::Rational, + y::Integer, +) + xn, yn = Base.divgcd(promote(x.num, y)...) + operate_to!(output.num, copy, xn) + operate_to!(output.den, *, x.den, yn) + return output +end + +function operate_to!( + output::Rational, + op::Union{typeof(/),typeof(//)}, + x::Integer, + y::Rational, +) + xn, yd = Base.divgcd(promote(x, y.den)...) + operate_to!(output.num, *, xn, yd) + operate_to!(output.den, copy, y.num) + return output +end + +function operate_to!( + output::Rational, + op::Union{typeof(/),typeof(//)}, + x::Integer, + y::Integer, +) + n, d = Base.divgcd(promote(x, y)...) + operate_to!(output.num, copy, n) + operate_to!(output.den, copy, d) + return output +end + # gcd function promote_operation( diff --git a/src/interface.jl b/src/interface.jl index fc1c5f5..dc24275 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -20,28 +20,22 @@ function promote_operation_fallback( end function promote_operation_fallback( - ::typeof(/), + op::Function, ::Type{S}, ::Type{T}, ) where {S,T} - return typeof(zero(S) / oneunit(T)) + U = Base.promote_op(op, S, T) + return return U == Union{} ? typeof(op(oneunit(S), oneunit(T))) : U end # Julia v1.0.x has trouble with inference with the `Vararg` method, see # https://travis-ci.org/jump-dev/JuMP.jl/jobs/617606373 -function promote_operation_fallback( - op::F, - ::Type{S}, - ::Type{T}, -) where {F<:Function,S,T} - return typeof(op(zero(S), zero(T))) -end - function promote_operation_fallback( op::F, args::Vararg{Type,N}, ) where {F<:Function,N} - return typeof(op(zero.(args)...)) + U = Base.promote_op(op, args...) + return return U == Union{} ? typeof(op(oneunit.(args)...)) : U end promote_operation_fallback(::typeof(*), ::Type{T}) where {T} = T @@ -172,9 +166,7 @@ function operate( ) where {N} return op(x, y, args...) end - -operate(op::Union{typeof(-),typeof(/)}, x, y) where {N} = op(x, y) - +operate(op::Union{typeof(-),typeof(/),typeof(//)}, x, y) = op(x, y) operate(::typeof(convert), ::Type{T}, x) where {T} = convert(T, x) operate(::typeof(convert), ::Type{T}, x::T) where {T} = copy_if_mutable(x) diff --git a/test/broadcast.jl b/test/broadcast.jl index 37ea93b..5be1836 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -28,7 +28,7 @@ end if VERSION >= v"1.5" # FIXME This should not allocate but I couldn't figure out where these # 240 come from. - alloc_test(() -> MA.broadcast!!(+, a, b), 240) + alloc_test(() -> MA.broadcast!!(+, a, b), 80) alloc_test(() -> MA.broadcast!!(+, a, c), 0) end end diff --git a/test/rational.jl b/test/rational.jl new file mode 100644 index 0000000..9fb00e6 --- /dev/null +++ b/test/rational.jl @@ -0,0 +1,21 @@ +for op in (+, -, *, //) + for (a,b) in ( + (2 // 3, 5), + (2, 3 // 5), + (2 // 3, 5 // 7), + (big(2) // 3, 5), + (big(2), 3 // 5), + (big(2) // 3, 5 // 7), + ) + @test MA.operate_to!!(MA.copy_if_mutable(op(a, b)), op, a, b) == + op(a, b) + @test MA.operate_to!!(MA.copy_if_mutable(op(b, a)), op, b, a) == + op(b, a) + end +end + +op = // +for (a, b) in ((2, 3), (big(2), 3), (2, big(3))) + @test MA.operate_to!!(MA.copy_if_mutable(op(a, b)), op, a, b) == op(a, b) + @test MA.operate_to!!(MA.copy_if_mutable(op(b, a)), op, b, a) == op(b, a) +end diff --git a/test/runtests.jl b/test/runtests.jl index 442c208..c45d22d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,11 @@ end @testset "BigInt" begin include("big.jl") end + +@testset "Rational" begin + include("rational.jl") +end + @testset "Broadcast" begin include("broadcast.jl") end