Skip to content

Commit ccdc063

Browse files
authored
Merge pull request #25 from ericphanson/eph/fcollect
Fix `fcollect` to respect object identity instead of `==`; document order
2 parents 059a5dc + bd72671 commit ccdc063

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Functors"
22
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
33
authors = ["Mike J Innes <[email protected]>"]
4-
version = "0.2.5"
4+
version = "0.2.6"
55

66
[compat]
77
julia = "1"

src/functor.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x))
149149
fcollect(x; exclude = v -> false)
150150
151151
Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref)
152-
and collecting the results into a flat array.
152+
and collecting the results into a flat array, ordered by a breadth-first
153+
traversal of `x`, respecting the iteration order of `children` calls.
153154
154155
Doesn't recurse inside branches rooted at nodes `v`
155156
for which `exclude(v) == true`.
@@ -192,11 +193,15 @@ julia> fcollect(m, exclude = v -> Functors.isleaf(v))
192193
Bar([1, 2, 3])
193194
```
194195
"""
195-
function fcollect(x; cache = [], exclude = v -> false)
196-
x in cache && return cache
197-
if !exclude(x)
198-
push!(cache, x)
199-
foreach(y -> fcollect(y; cache = cache, exclude = exclude), children(x))
200-
end
201-
return cache
196+
function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
197+
# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache
198+
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
199+
# for the results, to preserve traversal order (important downstream!).
200+
x in cache && return output
201+
if !exclude(x)
202+
push!(cache, x)
203+
push!(output, x)
204+
foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x))
205+
end
206+
return output
202207
end

test/basics.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ end
8484
m3 = Foo(m2, m0)
8585
m4 = Bar(m3)
8686
@test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
87+
88+
m1 = [1, 2, 3]
89+
m2 = [1, 2, 3]
90+
m3 = Foo(m1, m2)
91+
@test all(fcollect(m3) .=== [m3, m1, m2])
8792
end
8893

8994
struct FFoo

0 commit comments

Comments
 (0)