Skip to content

Non-mutating versions of pop, popfirst, etc. (#66) #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 14, 2020
Merged
99 changes: 98 additions & 1 deletion src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Requires
using LinearAlgebra
using SparseArrays

using Base: OneTo
using Base: OneTo, @propagate_inbounds

Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
Expand Down Expand Up @@ -543,6 +543,103 @@ function restructure(x::Array,y)
reshape(convert(Array,y),size(x)...)
end

"""
insert(collection, index, item)

Return a new instance of `collection` with `item` inserted into at the given `index`.
"""
Base.@propagate_inbounds function insert(collection, index, item)
@boundscheck checkbounds(collection, index)
ret = similar(collection, length(collection) + 1)
@inbounds for i in firstindex(ret):(index - 1)
ret[i] = collection[i]
end
@inbounds ret[index] = item
@inbounds for i in (index + 1):lastindex(ret)
ret[i] = collection[i - 1]
end
return ret
end

function insert(x::Tuple, index::Integer, item)
@boundscheck if !checkindex(Bool, static_first(x):static_last(x), index)
throw(BoundsError(x, index))
end
return unsafe_insert(x, Int(index), item)
end

@inline function unsafe_insert(x::Tuple, i::Int, item)
if i === 1
return (item, x...)
else
return (first(x), unsafe_insert(Base.tail(x), i - 1, item)...)
end
end

"""
deleteat(collection, index)

Return a new instance of `collection` with the item at the given `index` removed.
"""
@propagate_inbounds function deleteat(collection::AbstractVector, index)
@boundscheck if !checkindex(Bool, eachindex(collection), index)
throw(BoundsError(collection, index))
end
return unsafe_deleteat(collection, index)
end
@propagate_inbounds function deleteat(collection::Tuple, index)
@boundscheck if !checkindex(Bool, static_first(collection):static_last(collection), index)
throw(BoundsError(collection, index))
end
return unsafe_deleteat(collection, index)
end

function unsafe_deleteat(src::AbstractVector, index::Integer)
dst = similar(src, length(src) - 1)
@inbounds for i in indices(dst)
if i < index
dst[i] = src[i]
else
dst[i] = src[i + 1]
end
end
return dst
end

@inline function unsafe_deleteat(src::AbstractVector, inds::AbstractVector)
dst = similar(src, length(src) - length(inds))
dst_index = firstindex(dst)
@inbounds for src_index in indices(src)
if !in(src_index, inds)
dst[dst_index] = src[src_index]
dst_index += one(dst_index)
end
end
return dst
end

@inline function unsafe_deleteat(src::Tuple, inds::AbstractVector)
dst = Vector{eltype(src)}(undef, length(src) - length(inds))
dst_index = firstindex(dst)
@inbounds for src_index in OneTo(length(src))
if !in(src_index, inds)
dst[dst_index] = src[src_index]
dst_index += one(dst_index)
end
end
return Tuple(dst)
end

@inline function unsafe_deleteat(x::Tuple, i::Integer)
if i === one(i)
return Base.tail(x)
elseif i == length(x)
return Base.front(x)
else
return (first(x), unsafe_deleteat(Base.tail(x), i - one(i))...)
end
end

function __init__()

@require SuiteSparse="4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin
Expand Down
61 changes: 21 additions & 40 deletions src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)

# add methods to support ArrayInterface

_get(x) = x
_get(::Static{V}) where {V} = V
_get(::Type{Static{V}}) where {V} = V
_convert(::Type{T}, x) where {T} = convert(T, x)
_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V))

"""
OptionallyStaticUnitRange{T<:Integer}(start, stop) <: OrdinalRange{T,T}

Expand All @@ -57,28 +51,23 @@ at compile time. An `OptionallyStaticUnitRange` is intended to be constructed in
from other valid indices. Therefore, users should not expect the same checks are used
to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
"""
struct OptionallyStaticUnitRange{T <: Integer, F <: Integer, L <: Integer} <: AbstractUnitRange{T}
struct OptionallyStaticUnitRange{F <: Integer, L <: Integer} <: AbstractUnitRange{Int}
start::F
stop::L

function OptionallyStaticUnitRange{T}(start, stop) where {T<:Real}
if _get(start) isa T
if _get(stop) isa T
return new{T,typeof(start),typeof(stop)}(start, stop)
function OptionallyStaticUnitRange(start, stop)
if eltype(start) <: Int
if eltype(stop) <: Int
return new{typeof(start),typeof(stop)}(start, stop)
else
return OptionallyStaticUnitRange{T}(start, _convert(T, stop))
return OptionallyStaticUnitRange(start, Int(stop))
end
else
return OptionallyStaticUnitRange{T}(_convert(T, start), stop)
return OptionallyStaticUnitRange(Int(start), stop)
end
end

function OptionallyStaticUnitRange(start, stop)
T = promote_type(typeof(_get(start)), typeof(_get(stop)))
return OptionallyStaticUnitRange{T}(start, stop)
end

function OptionallyStaticUnitRange(x::AbstractRange)
function OptionallyStaticUnitRange(x::AbstractRange)
if step(x) == 1
fst = static_first(x)
lst = static_last(x)
Expand All @@ -94,12 +83,12 @@ Base.:(:)(::Static{L}, U::Integer) where {L} = OptionallyStaticUnitRange(Static(
Base.:(:)(::Static{L}, ::Static{U}) where {L,U} = OptionallyStaticUnitRange(Static(L), Static(U))

Base.first(r::OptionallyStaticUnitRange) = r.start
Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T)
Base.step(::OptionallyStaticUnitRange) = Static(1)
Base.last(r::OptionallyStaticUnitRange) = r.stop

known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Static{F}}}) where {F} = F
known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T)
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Static{L}}}) where {L} = L
known_first(::Type{<:OptionallyStaticUnitRange{Static{F}}}) where {F} = F
known_step(::Type{<:OptionallyStaticUnitRange}) = 1
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,Static{L}}}) where {L} = L

function Base.isempty(r::OptionallyStaticUnitRange)
if known_first(r) === oneunit(eltype(r))
Expand All @@ -112,10 +101,8 @@ end
unsafe_isempty_one_to(lst) = lst <= zero(lst)
unsafe_isempty_unit_range(fst, lst) = fst > lst

unsafe_isempty_unit_range(fst::T, lst::T) where {T} = Integer(lst - fst + one(T))

unsafe_length_one_to(lst::T) where {T<:Int} = T(lst)
unsafe_length_one_to(lst::T) where {T} = Integer(lst - zero(lst))
unsafe_length_one_to(lst::Int) = lst
unsafe_length_one_to(::Static{L}) where {L} = lst

Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
if known_first(r) === oneunit(r)
Expand Down Expand Up @@ -144,15 +131,15 @@ end
@inline _try_static(::Static{M}, ::Static{N}) where {M, N} = @assert false "Unequal Indices: Static{$M}() != Static{$N}()"
function _try_static(::Static{N}, x) where {N}
@assert N == x "Unequal Indices: Static{$N}() != x == $x"
Static{N}()
return Static{N}()
end
function _try_static(x, ::Static{N}) where {N}
@assert N == x "Unequal Indices: x == $x != Static{$N}()"
Static{N}()
return Static{N}()
end
function _try_static(x, y)
@assert x == y "Unequal Indicess: x == $x != $y == y"
x
return x
end

###
Expand All @@ -172,24 +159,19 @@ end
end
end

function Base.length(r::OptionallyStaticUnitRange{T}) where {T}
function Base.length(r::OptionallyStaticUnitRange)
if isempty(r)
return zero(T)
return 0
else
if known_one(r) === one(T)
if known_first(r) === 0
return unsafe_length_one_to(last(r))
else
return unsafe_length_unit_range(first(r), last(r))
end
end
end

function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{Int,Int64,Int128}}
return Base.checked_add(Base.checked_sub(lst, fst), one(T))
end
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{UInt,UInt64,UInt128}}
return Base.checked_add(lst - fst, one(T))
end
unsafe_length_unit_range(start::Integer, stop::Integer) = Int(start - stop + 1)

"""
indices(x[, d])
Expand Down Expand Up @@ -231,4 +213,3 @@ end
lst = _try_static(static_last(x), static_last(y))
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
end

49 changes: 28 additions & 21 deletions src/static.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ Use `Static(N)` instead of `Val(N)` when you want it to behave like a number.
struct Static{N} <: Integer
Static{N}() where {N} = new{N::Int}()
end

const Zero = Static{0}
const One = Static{1}

Base.@pure Static(N::Int) = Static{N}()
Static(N::Integer) = Static(convert(Int, N))
Static(::Static{N}) where {N} = Static{N}()
Expand Down Expand Up @@ -33,41 +37,44 @@ end
Base.promote_rule(::Type{<:Static}, ::Type{<:Static}) = Int
Base.:(%)(::Static{N}, ::Type{Integer}) where {N} = N

Base.iszero(::Static{0}) = true
Base.eltype(::Type{T}) where {T<:Static} = Int
Base.iszero(::Zero) = true
Base.iszero(::Static) = false
Base.isone(::Static{1}) = true
Base.isone(::One) = true
Base.isone(::Static) = false
Base.zero(::Type{T}) where {T<:Static} = Zero()
Base.one(::Type{T}) where {T<:Static} = One()

for T = [:Real, :Rational, :Integer]
@eval begin
@inline Base.:(+)(i::$T, ::Static{0}) = i
@inline Base.:(+)(i::$T, ::Zero) = i
@inline Base.:(+)(i::$T, ::Static{M}) where {M} = i + M
@inline Base.:(+)(::Static{0}, i::$T) = i
@inline Base.:(+)(::Zero, i::$T) = i
@inline Base.:(+)(::Static{M}, i::$T) where {M} = M + i
@inline Base.:(-)(i::$T, ::Static{0}) = i
@inline Base.:(-)(i::$T, ::Zero) = i
@inline Base.:(-)(i::$T, ::Static{M}) where {M} = i - M
@inline Base.:(*)(i::$T, ::Static{0}) = Static{0}()
@inline Base.:(*)(i::$T, ::Static{1}) = i
@inline Base.:(*)(i::$T, ::Zero) = Zero()
@inline Base.:(*)(i::$T, ::One) = i
@inline Base.:(*)(i::$T, ::Static{M}) where {M} = i * M
@inline Base.:(*)(::Static{0}, i::$T) = Static{0}()
@inline Base.:(*)(::Static{1}, i::$T) = i
@inline Base.:(*)(::Zero, i::$T) = Zero()
@inline Base.:(*)(::One, i::$T) = i
@inline Base.:(*)(::Static{M}, i::$T) where {M} = M * i
end
end
@inline Base.:(+)(::Static{0}, ::Static{0}) = Static{0}()
@inline Base.:(+)(::Static{0}, ::Static{M}) where {M} = Static{M}()
@inline Base.:(+)(::Static{M}, ::Static{0}) where {M} = Static{M}()
@inline Base.:(+)(::Zero, ::Zero) = Zero()
@inline Base.:(+)(::Zero, ::Static{M}) where {M} = Static{M}()
@inline Base.:(+)(::Static{M}, ::Zero) where {M} = Static{M}()

@inline Base.:(-)(::Static{M}, ::Static{0}) where {M} = Static{M}()
@inline Base.:(-)(::Static{M}, ::Zero) where {M} = Static{M}()

@inline Base.:(*)(::Static{0}, ::Static{0}) = Static{0}()
@inline Base.:(*)(::Static{1}, ::Static{0}) = Static{0}()
@inline Base.:(*)(::Static{0}, ::Static{1}) = Static{0}()
@inline Base.:(*)(::Static{1}, ::Static{1}) = Static{1}()
@inline Base.:(*)(::Static{M}, ::Static{0}) where {M} = Static{0}()
@inline Base.:(*)(::Static{0}, ::Static{M}) where {M} = Static{0}()
@inline Base.:(*)(::Static{M}, ::Static{1}) where {M} = Static{M}()
@inline Base.:(*)(::Static{1}, ::Static{M}) where {M} = Static{M}()
@inline Base.:(*)(::Zero, ::Zero) = Zero()
@inline Base.:(*)(::One, ::Zero) = Zero()
@inline Base.:(*)(::Zero, ::One) = Zero()
@inline Base.:(*)(::One, ::One) = One()
@inline Base.:(*)(::Static{M}, ::Zero) where {M} = Zero()
@inline Base.:(*)(::Zero, ::Static{M}) where {M} = Zero()
@inline Base.:(*)(::Static{M}, ::One) where {M} = Static{M}()
@inline Base.:(*)(::One, ::Static{M}) where {M} = Static{M}()
for f ∈ [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :(⊻)]
@eval @generated Base.$f(::Static{M}, ::Static{N}) where {M,N} = Expr(:call, Expr(:curly, :Static, $f(M, N)))
end
Expand Down
22 changes: 22 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ end
@testset "Static" begin
@test iszero(Static(0))
@test !iszero(Static(1))
@test @inferred(one(Static)) === Static(1)
@test @inferred(zero(Static)) === Static(0)
@test eltype(one(Static)) <: Int
# test for ambiguities and correctness
for i ∈ [Static(0), Static(1), Static(2), 3]
for j ∈ [Static(0), Static(1), Static(2), 3]
Expand All @@ -271,3 +274,22 @@ end
end
end

@testset "insert/deleteat" begin
@test @inferred(ArrayInterface.insert([1,2,3], 2, -2)) == [1, -2, 2, 3]
@test @inferred(ArrayInterface.deleteat([1, 2, 3], 2)) == [1, 3]

@test @inferred(ArrayInterface.deleteat([1, 2, 3], [1, 2])) == [3]
@test @inferred(ArrayInterface.deleteat([1, 2, 3], [1, 3])) == [2]
@test @inferred(ArrayInterface.deleteat([1, 2, 3], [2, 3])) == [1]


@test @inferred(ArrayInterface.insert((1,2,3), 1, -2)) == (-2, 1, 2, 3)
@test @inferred(ArrayInterface.insert((1,2,3), 2, -2)) == (1, -2, 2, 3)
@test @inferred(ArrayInterface.insert((1,2,3), 3, -2)) == (1, 2, -2, 3)

@test @inferred(ArrayInterface.deleteat((1, 2, 3), 1)) == (2, 3)
@test @inferred(ArrayInterface.deleteat((1, 2, 3), 2)) == (1, 3)
@test @inferred(ArrayInterface.deleteat((1, 2, 3), 3)) == (1, 2)
@test ArrayInterface.deleteat((1, 2, 3), [1, 2]) == (3,)
end