@@ -139,3 +139,89 @@ function Base.join{T,N,D,Ax}(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T),
139139 return result
140140
141141end # join
142+
143+ function greatest_common_axis (As:: AxisArray... )
144+ length (As) == 1 && return ndims (first (As))
145+
146+ for (i, zip_axes) in enumerate (zip (axes .(As)... ))
147+ if ! all (ax -> ax == zip_axes[1 ], zip_axes[2 : end ])
148+ return i - 1
149+ end
150+ end
151+
152+ return minimum (map (ndims, As))
153+ end
154+
155+ function flatten_array_axes (array_name, array_axes)
156+ map (zip (repeated (array_name), product (map (Ax-> Ax. val, array_axes)... ))) do tup
157+ tup_name, tup_idx = tup
158+ return (tup_name, tup_idx... )
159+ end
160+ end
161+
162+ function flatten_axes (array_names, array_axes)
163+ collect (chain (map (flatten_array_axes, array_names, array_axes)... ))
164+ end
165+
166+ """
167+ flatten(As::AxisArray...) -> AxisArray
168+ flatten(last_dim::Integer, As::AxisArray...) -> AxisArray
169+
170+ Concatenates AxisArrays with equal leading axes into a single AxisArray.
171+ All additional axes in any of the arrays are flattened into a single additional
172+ CategoricalVector{Tuple} axis.
173+
174+ ### Arguments
175+
176+ * `last_dim::Integer`: (optional) the greatest common dimension to share between all input
177+ arrays. The remaining axes are flattened. If this argument is not
178+ provided, the greatest common axis found among the input arrays is
179+ used. All preceeding axes must also be common to each input array, at
180+ the same dimension. Values from 0 up to one more than the minimum
181+ number of dimensions across all input arrays are allowed.
182+ * `As::AxisArray...`: AxisArrays to be flattened together.
183+ """
184+ function flatten (As:: AxisArray... ; kwargs... )
185+ gca = greatest_common_axis (As... )
186+
187+ return _flatten (gca, As... ; kwargs... )
188+ end
189+
190+ function flatten (last_dim:: Integer , As:: AxisArray... ; kwargs... )
191+ last_dim >= 0 || throw (ArgumentError (" last_dim must be at least 0" ))
192+
193+ if last_dim > minimum (map (ndims, As))
194+ throw (ArgumentError (
195+ " There must be at least $last_dim (last_dim) axes in each argument"
196+ ))
197+ end
198+
199+ if last_dim > greatest_common_axis (As... )
200+ throw (ArgumentError (
201+ " The first $last_dim axes don't all match across all arguments"
202+ ))
203+ end
204+
205+ return _flatten (last_dim, As... ; kwargs... )
206+ end
207+
208+ function _flatten (
209+ last_dim:: Integer ,
210+ As:: AxisArray... ;
211+ array_names= 1 : length (As),
212+ axis_name= nothing ,
213+ )
214+ common_axes = axes (As[1 ])[1 : last_dim]
215+
216+ if axis_name === nothing
217+ axis_name = _defaultdimname (last_dim + 1 )
218+ elseif ! isa (axis_name, Symbol)
219+ throw (ArgumentError (" axis_name must be a Symbol" ))
220+ end
221+
222+ new_data = cat (last_dim + 1 , (view (A. data, repeated (:, last_dim + 1 )... ) for A in As). .. )
223+ new_axis = flatten_axes (array_names, map (A -> axes (A)[last_dim+ 1 : end ], As))
224+
225+ # TODO : Consider creating a SortedVector axis when all flattened axes are Dimensional
226+ return AxisArray (new_data, common_axes... , CategoricalVector (new_axis))
227+ end
0 commit comments