66
66
# #### map
67
67
# ####
68
68
69
- # `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`,
70
- # will be useful for the gradient of `map` etc.
71
-
72
-
73
69
"""
74
- unzip_map(f, args...)
70
+ unzip_map(f, args...)
75
71
76
72
For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
77
73
but performed using `StructArrays` for efficiency.
@@ -86,40 +82,36 @@ function unzip_map(f::F, args...) where {F}
86
82
end
87
83
88
84
unzip_map (f:: F , args:: Tuple... ) where {F} = unzip (map (f, args... ))
85
+ # unzip_map(f::F, args::NamedTuple...) where {F} = unzip(map(f, args...))
89
86
90
87
unzip_map (f:: F , args:: AbstractGPUArray... ) where {F} = unzip (map (f, args... ))
91
88
89
+ """
90
+ unzip_map_reversed(f, args...)
91
+
92
+ For a pure function `f` which returns a tuple, this is `== unzip(map(f, args...))`.
93
+ But the order of evaluation is should be the reverse.
94
+ Does NOT handle `zip`-like behaviour.
95
+ """
92
96
function unzip_map_reversed (f:: F , args... ) where {F}
93
97
T = Broadcast. combine_eltypes (f, args)
94
98
if isconcretetype (T)
95
99
T <: Tuple || throw (ArgumentError (""" unzip_map_reversed(f, args) only works on functions returning a tuple,
96
100
but f = $(sprint (show, f)) returns type T = $T """ ))
97
101
end
98
102
len1 = length (first (args))
99
- if all (a -> length (a)== len1, args)
100
- rev_args = map (Iterators. reverse, args)
101
- outs = StructArrays. components (StructArray (Iterators. map (f, rev_args... )))
102
- else
103
- len = minimum (length, args)
104
- rev_args = map (a -> Iterators. reverse (@view a[begin : begin + len- 1 ]), args)
105
- outs = StructArrays. components (StructArray (Iterators. map (f, rev_args... )))
106
- end
107
- return map (reverse!!, outs)
103
+ all (a -> length (a)== len1, args) || error (" unzip_map_reversed does not handle zip-like behaviour." )
104
+ return map (reverse!!, unzip_map (f, map (_safereverse, args)... ))
108
105
end
109
106
107
+ # This avoids MethodError: no method matching iterate(::Base.Iterators.Reverse{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}) on 1.6
108
+ _safereverse (x) = VERSION > v " 1.7" ? Iterators. reverse (x) : reverse (x)
109
+
110
110
function unzip_map_reversed (f:: F , args:: Tuple... ) where {F}
111
- len = minimum (length, args)
112
- rev_args = map (a -> reverse (a[1 : len]), args)
113
- # vlen = Val(len)
114
- # rev_args = map(args) do a
115
- # reverse(ntuple(i -> a[i], vlen)) # does not infer better
116
- # end
117
- return map (reverse, unzip (map (f, rev_args... )))
111
+ len1 = length (first (args))
112
+ all (a -> length (a)== len1, args) || error (" unzip_map_reversed does not handle zip-like behaviour." )
113
+ return map (reverse, unzip (map (f, map (reverse, args)... )))
118
114
end
119
- # function unzip_map_reversed(f::F, args::Tuple{Vararg{Any, N}}...) where {F,N}
120
- # rev_args = map(reverse, args)
121
- # return map(reverse, unzip(map(f, rev_args...)))
122
- # end
123
115
124
116
"""
125
117
reverse!!(x)
@@ -135,10 +127,11 @@ function reverse!!(x::AbstractArray)
135
127
end
136
128
end
137
129
reverse!! (x:: AbstractArray{<:AbstractZero} ) = x
130
+ reverse!! (x) = reverse (x)
138
131
139
- frule ((_, xdot), :: typeof (reverse!!), x:: AbstractArray ) = reverse!! (x), reverse!! (xdot)
132
+ frule ((_, xdot), :: typeof (reverse!!), x) = reverse!! (x), reverse!! (xdot)
140
133
141
- function rrule (:: typeof (reverse!!), x:: AbstractArray )
134
+ function rrule (:: typeof (reverse!!), x)
142
135
reverse!!_back (dy) = (NoTangent (), reverse (unthunk (dy)))
143
136
return reverse!! (x), reverse!!_back
144
137
end
@@ -181,10 +174,16 @@ end
181
174
Expr (:tuple , each... )
182
175
end
183
176
184
- unzip (xs:: AbstractArray{Tuple{T}} ) where {T} = (reinterpret (T, xs),) # best case, no copy
177
+ function unzip (xs:: AbstractArray{Tuple{T}} ) where {T}
178
+ if isbitstype (T)
179
+ (reinterpret (T, xs),) # best case, no copy
180
+ else
181
+ (map (only, xs),)
182
+ end
183
+ end
185
184
186
185
@generated function unzip (xs:: AbstractArray{Ts} ) where {Ts<: Tuple }
187
- each = if count (! Base. issingletontype, Ts. parameters) < 2
186
+ each = if count (! Base. issingletontype, Ts. parameters) < 2 && all (isbitstype, Ts . parameters)
188
187
# good case, no copy of data, some trivial arrays
189
188
[Base. issingletontype (T) ? :(similar (xs, $ T)) : :(reinterpret ($ T, xs)) for T in Ts. parameters]
190
189
else
0 commit comments