Skip to content

Commit 22f4e21

Browse files
eliascarvjuliohm
andauthored
Use the inverse function instead of the Base.inv function (#4)
* Use the 'inverse' function (from InverseFunctions.jl) instead of 'inv' function (from Base) * Define the 'inverse' function * Fix code * Update docstring * Update src/interface.jl --------- Co-authored-by: Júlio Hoffimann <[email protected]>
1 parent b60d2ea commit 22f4e21

File tree

4 files changed

+28
-25
lines changed

4 files changed

+28
-25
lines changed

src/identity.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ isrevertible(::Type{Identity}) = true
1313

1414
isinvertible(::Type{Identity}) = true
1515

16-
Base.inv(::Identity) = Identity()
16+
inverse(::Identity) = Identity()
1717

1818
apply(::Identity, object) = object, nothing
1919

src/interface.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,26 @@ function isrevertible end
3232
"""
3333
isinvertible(transform)
3434
35-
Tells whether or not the `transform` is invertible, i.e. supports a
36-
`inv` function. Defaults to `false` for new transform types.
35+
Tells whether or not the `transform` is invertible, i.e. whether it
36+
implements the `inverse` function. Defaults to `false` for new transform
37+
types.
3738
3839
Transforms can be invertible in the mathematical sense, i.e., there
3940
exists a one-to-one mapping between input and output spaces.
4041
41-
See also [`isrevertible`](@ref).
42+
See also [`inverse`](@ref), [`isrevertible`](@ref).
4243
"""
4344
function isinvertible end
4445

46+
"""
47+
inverse(transform)
48+
49+
Returns the inverse transform of the `transform`.
50+
51+
See also [`isinvertible`](@ref).
52+
"""
53+
function inverse end
54+
4555
"""
4656
assertions(transform)
4757
@@ -90,23 +100,19 @@ function reapply end
90100
# TRANSFORM FALLBACKS
91101
# --------------------
92102

93-
isrevertible(transform::Transform) =
94-
isrevertible(typeof(transform))
103+
isrevertible(transform::Transform) = isrevertible(typeof(transform))
95104
isrevertible(::Type{<:Transform}) = false
96105

97-
isinvertible(transform::Transform) =
98-
isinvertible(typeof(transform))
106+
isinvertible(transform::Transform) = isinvertible(typeof(transform))
99107
isinvertible(::Type{<:Transform}) = false
100108

101109
assertions(transform::Transform) = []
102110

103111
preprocess(transform::Transform, object) = nothing
104112

105-
reapply(transform::Transform, object, cache) =
106-
apply(transform, object) |> first
113+
reapply(transform::Transform, object, cache) = apply(transform, object) |> first
107114

108-
(transform::Transform)(object) =
109-
apply(transform, object) |> first
115+
(transform::Transform)(object) = apply(transform, object) |> first
110116

111117
function Base.show(io::IO, transform::Transform)
112118
T = typeof(transform)

src/sequential.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ isrevertible(s::SequentialTransform) = all(isrevertible, s.transforms)
1515

1616
isinvertible(s::SequentialTransform) = all(isinvertible, s.transforms)
1717

18-
Base.inv(s::SequentialTransform) = SequentialTransform([inv(t) for t in reverse(s.transforms)])
18+
inverse(s::SequentialTransform) = SequentialTransform([inverse(t) for t in reverse(s.transforms)])
1919

2020
function apply(s::SequentialTransform, table)
2121
allcache = []
@@ -59,21 +59,16 @@ end
5959
Create a [`SequentialTransform`](@ref) transform with
6060
`[transform₁, transform₂, …, transformₙ]`.
6161
"""
62-
(t1::Transform, t2::Transform) =
63-
SequentialTransform([t1, t2])
64-
(t1::Transform, t2::SequentialTransform) =
65-
SequentialTransform([t1; t2.transforms])
66-
(t1::SequentialTransform, t2::Transform) =
67-
SequentialTransform([t1.transforms; t2])
68-
(t1::SequentialTransform, t2::SequentialTransform) =
69-
SequentialTransform([t1.transforms; t2.transforms])
62+
(t1::Transform, t2::Transform) = SequentialTransform([t1, t2])
63+
(t1::Transform, t2::SequentialTransform) = SequentialTransform([t1; t2.transforms])
64+
(t1::SequentialTransform, t2::Transform) = SequentialTransform([t1.transforms; t2])
65+
(t1::SequentialTransform, t2::SequentialTransform) = SequentialTransform([t1.transforms; t2.transforms])
7066

7167
# AbstractTrees interface
7268
AbstractTrees.nodevalue(::SequentialTransform) = SequentialTransform
7369
AbstractTrees.children(s::SequentialTransform) = s.transforms
7470

75-
Base.show(io::IO, s::SequentialTransform) =
76-
print(io, join(s.transforms, ""))
71+
Base.show(io::IO, s::SequentialTransform) = print(io, join(s.transforms, ""))
7772

7873
function Base.show(io::IO, ::MIME"text/plain", s::SequentialTransform)
7974
tree = AbstractTrees.repr_tree(s, context=io)

test/runtests.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ using Test
44
@testset "TransformsBase.jl" begin
55
@test TransformsBase.isrevertible(Identity())
66
@test TransformsBase.isinvertible(Identity())
7-
@test inv(Identity()) == Identity()
8-
@test inv(Identity() Identity()) == Identity()
7+
@test TransformsBase.inverse(Identity()) == Identity()
8+
@test TransformsBase.inverse(Identity() Identity()) == Identity()
99
@test (Identity() Identity()) == Identity()
1010

1111
# test fallbacks
@@ -14,6 +14,8 @@ using Test
1414
T = TestTransform()
1515
@test !TransformsBase.isrevertible(T)
1616
@test !TransformsBase.isinvertible(T)
17+
@test !TransformsBase.isrevertible(T T)
18+
@test !TransformsBase.isinvertible(T T)
1719
@test TransformsBase.assertions(T) |> isempty
1820
@test TransformsBase.preprocess(T, nothing) |> isnothing
1921
@test TransformsBase.reapply(T, 1, nothing) == 1

0 commit comments

Comments
 (0)