This repository was archived by the owner on Aug 17, 2021. It is now read-only.
forked from JuliaArrays/AxisArrays.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcombine.jl
More file actions
261 lines (201 loc) · 8.89 KB
/
combine.jl
File metadata and controls
261 lines (201 loc) · 8.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
"""
    equalvalued(X::NTuple)

Return `true` when every element of `X` compares `==` to its predecessor, i.e.
when all entries of the tuple are equal. Empty and single-element tuples
yield `true`.
"""
function equalvalued(X::NTuple)
    # `all` short-circuits at the first adjacent pair that differs.
    return all(i -> X[i] == X[i - 1], 2:length(X))
end
"""
    sizes(As::AxisArray...)

Return a tuple with one entry per dimension; each entry is itself a tuple
holding the length of that dimension's index range (from `indices`) in every
array of `As`, in argument order.
"""
function sizes(As::T...) where {T<:AxisArray}
    # One tuple of dimension lengths per array ...
    lengths_per_array = map(A -> map(length, indices(A)), As)
    # ... transposed so the outer tuple runs over dimensions.
    return tuple(zip(lengths_per_array...)...)
end
"""
    matchingdims(As::NTuple{N,<:AxisArray})

Return `true` when, for every dimension reported by `sizes`, all arrays in
`As` have the same length along that dimension.
"""
function matchingdims(As::NTuple{N,T}) where {N,T<:AxisArray}
    return all(equalvalued, sizes(As...))
end
"""
    matchingdimsexcept(As::NTuple{N,<:AxisArray}, n::Int)

Like `matchingdims`, but dimension `n` is left out of the comparison: return
`true` when all arrays in `As` agree in length along every other dimension.
"""
function matchingdimsexcept(As::NTuple{N,T}, n::Int) where {N,T<:AxisArray}
    dimsizes = sizes(As...)
    # Every dimension index except `n`.
    others = [1:n-1; n+1:length(dimsizes)]
    return all(equalvalued, dimsizes[others])
end
"""
    cat(n::Integer, As::AxisArray...)

Concatenate the AxisArrays `As` along dimension `n`.

When `n` is one of the existing dimensions of the first array, all other
dimensions must have matching lengths; the axis along `n` is rebuilt from the
`vcat` of the inputs' axis values (keeping the first array's axis name) and
validated with `checkaxis`. When `n` is larger than the number of dimensions,
all dimensions must match and the first array's axes are passed on unchanged.

Raises an error (via `error`) when the dimension check fails.
"""
function Base.cat(n::Integer, As::AxisArray{T}...) where {T}
    template = As[1]
    rawdata = map(A -> A.data, As)

    # Concatenating along a dimension beyond those of the inputs.
    if n > ndims(template)
        matchingdims(As) || error("All axes must be identically-valued")
        return AxisArray(cat(n, rawdata...), template.axes)
    end

    # Concatenating along an existing dimension: validate before touching the data.
    matchingdimsexcept(As, n) || error("All non-concatenated axes must be identically-valued")
    joinedvals = vcat(map(A -> A.axes[n].val, As)...)
    newaxis = Axis{axisnames(template)[n]}(joinedvals)
    checkaxis(newaxis)
    resultaxes = (template.axes[1:n-1]..., newaxis, template.axes[n+1:end]...)
    return AxisArray(cat(n, rawdata...), resultaxes)
end
"""
    axismerge(method::Symbol, axes::Axis{name,T}...)

Combine the values of several same-named axes into a single new `Axis{name}`.

`method` selects which values are kept:

* `:inner` - the `intersect` of all axes' values
* `:left`  - the values of the first axis
* `:right` - the values of the last axis
* `:outer` - the `union` of all axes' values

Any other `method` raises an error. When `axistrait` of the selected values is
a `Dimensional`, the values are sorted. The returned axis always owns a freshly
collected vector, so the input axes are never modified.
"""
function axismerge(method::Symbol, axes::Axis{name,T}...) where {name,T}
    axisvals = if method == :inner
        intersect(axisvalues(axes...)...)
    elseif method == :left
        axisvalues(axes[1])[1]
    elseif method == :right
        axisvalues(axes[end])[1]
    elseif method == :outer
        union(axisvalues(axes...)...)
    else
        error("Join method must be one of :inner, :left, :right, :outer")
    end

    # Copy before sorting: for :left/:right `axisvals` is the very object stored
    # in the input axis, and range-valued results cannot be sorted in place.
    merged = collect(axisvals)
    isa(axistrait(axisvals), Dimensional) && sort!(merged)

    return Axis{name}(merged)
end
"""
    indexmappings(oldaxes::NTuple{N,Axis}, newaxes::NTuple{N,Axis})

Apply `indexmapping` to the values of each pair of corresponding axes and
regroup the results: the first element of the returned collection gathers the
"before" index vectors of every axis, the second gathers the "after" index
vectors.
"""
function indexmappings(oldaxes::NTuple{N,Axis}, newaxes::NTuple{N,Axis}) where {N}
    # One (before, after) pair per axis.
    pairs_per_axis = indexmapping.(axisvalues(oldaxes...), axisvalues(newaxes...))
    # Regroup into (all befores, all afters).
    return collect(zip(pairs_per_axis...))
end
"""
    indexmapping(old::AbstractVector, new::AbstractVector)

Find the values shared by `old` and `new` and return `(before, after)`, two
`Vector{Int}`s of equal length such that `old[before[k]] == new[after[k]]` for
every `k`. Matches are reported in ascending order of the shared value.

Both inputs are walked in sorted order (via `sortperm`), advancing whichever
side currently holds the smaller value.
"""
function indexmapping(old::AbstractVector, new::AbstractVector)
    oldorder = sortperm(old)
    neworder = sortperm(new)
    nold = length(oldorder)
    nnew = length(neworder)

    before = Int[]
    after = Int[]

    i = 1
    j = 1
    while i <= nold && j <= nnew
        a = old[oldorder[i]]
        b = new[neworder[j]]
        if a == b
            # Shared value: record the original positions on both sides.
            push!(before, oldorder[i])
            push!(after, neworder[j])
            i += 1
            j += 1
        elseif a < b
            i += 1
        else
            j += 1
        end
    end

    return before, after
end
"""
    merge(As::AxisArray...; fillvalue=zero(T))

Combine AxisArrays with matching axis names into a single AxisArray spanning
all of the axis values of the inputs (an outer merge of each axis).

If a coordinate is defined in more than one of the inputs, its value is taken
from the last input in which it appears. If a coordinate of the output array is
not defined in any of the inputs, it takes the value of the optional
`fillvalue` keyword argument (default: zero of the element type).
"""
function Base.merge(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T)) where {T,N,D,Ax}
    # Group the inputs' axes by dimension, then merge each group.
    axes_by_dim = map(tuple, axes.(As)...)
    merged_axes = map(dim_axes -> axismerge(:outer, dim_axes...), axes_by_dim)

    merged = AxisArray(fill(fillvalue, length.(merged_axes)...), merged_axes...)

    # Later inputs overwrite earlier ones wherever coordinates overlap.
    for A in As
        src, dest = indexmappings(A.axes, merged.axes)
        merged.data[dest...] = A.data[src...]
    end

    return merged
end
"""
    join(As::AxisArray...; fillvalue=zero(T), newaxis=..., method=:outer)

Combine AxisArrays with matching axis names into a single AxisArray. Unlike
`merge`, the inputs are stacked along a newly created trailing axis, which can
be supplied through the `newaxis` keyword argument; by default it is built by
`_nextaxistype` from the first array's axes, with values `1:length(As)`.

The `method` keyword argument selects the join type:

* `:inner` - keep only array values at axis values common to all inputs
* `:left` - keep only array values at axis values present in the first input
* `:right` - keep only array values at axis values present in the last input
* `:outer` (default) - keep all array values, spanning all input axis values

If an array value in the output is not defined in one of the inputs (possible
for left, right and outer joins), it takes the value of the optional
`fillvalue` keyword argument (default: zero of the element type).
"""
function Base.join(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T),
                   newaxis::Axis=_nextaxistype(As[1].axes)(1:length(As)),
                   method::Symbol=:outer) where {T,N,D,Ax}
    # Group the inputs' axes by dimension, then join each group.
    axes_by_dim = map(tuple, axes.(As)...)
    joined_axes = map(dim_axes -> axismerge(method, dim_axes...), axes_by_dim)
    all_axes = (joined_axes..., newaxis)

    joined = AxisArray(fill(fillvalue, length.(all_axes)...), all_axes...)

    # Each input fills its own slice along the new trailing axis.
    for (slot, A) in enumerate(As)
        src, dest = indexmappings(A.axes, joined_axes)
        joined.data[dest..., slot] = A.data[src...]
    end

    return joined
end
"""
    _flatten_array_axes(array_name, array_axes...)

Return a lazy generator with one tuple per combination of the `val` entries of
`array_axes` (their `product`); every tuple is prefixed with `array_name`.
"""
function _flatten_array_axes(array_name, array_axes...)
    axis_values = map(ax -> ax.val, array_axes)
    combos = product(axis_values...)
    return (
        (array_name, (combo isa Tuple ? combo : (combo,))...) for combo in combos
    )
end
"""
    _flatten_axes(array_names, array_axes)

Pair each entry of `array_names` with the matching entry of `array_axes`,
expand every pair with `_flatten_array_axes`, and collect all resulting tuples
into one flat collection (in input order).
"""
function _flatten_axes(array_names, array_axes)
    per_array = map(
        (name, axs) -> _flatten_array_axes(name, axs...), array_names, array_axes
    )
    return collect(Iterators.flatten(per_array))
end
"""
    _splitall(::Type{Val{N}}, As...)

Apply `Base.IteratorsMD.split(A, Val{N})` to every element of `As` and return
the results as a tuple, in argument order.
"""
function _splitall(::Type{Val{N}}, As...) where {N}
    # NOTE(review): `split` presumably yields (first N entries, remaining entries)
    # for each input - confirm against the Base version in use.
    return map(A -> Base.IteratorsMD.split(A, Val{N}), As)
end
"""
    _reshapeall(::Type{Val{N}}, As...)

Call `reshape(A, Val{N})` on every element of `As` and return the reshaped
arrays as a tuple, in argument order.
"""
function _reshapeall(::Type{Val{N}}, As...) where {N}
    return map(A -> reshape(A, Val{N}), As)
end
"""
    _check_common_axes(common_axis_tuple)

Verify that every entry of `common_axis_tuple` has the same `axisname` as the
first entry. Returns `nothing` on success and throws an `ArgumentError`
otherwise.
"""
function _check_common_axes(common_axis_tuple)
    reference_name = axisname(first(common_axis_tuple))
    other_names = axisname.(common_axis_tuple[2:end])
    all(reference_name .=== other_names) && return nothing
    throw(ArgumentError("Leading common axes must have the same name in each array"))
end
"""
    _flat_axis_eltype(LType, trailing_axes)

For each entry of `trailing_axes` build the tuple type
`Tuple{LType, eltypes...}` from the `eltype` of each of its members, and return
the `typejoin` of all those tuple types.
"""
function _flat_axis_eltype(LType, trailing_axes)
    label_tuple_types = map(
        axs -> Tuple{LType, map(eltype, axs)...}, trailing_axes
    )
    return typejoin(label_tuple_types...)
end
# Convenience method: when no labels are given, the arrays are labelled with
# their positions 1, 2, ..., NA before delegating to the labelled method.
function flatten(::Type{Val{N}}, As::Vararg{AxisArray, NA}) where {N, NA}
    default_labels = ntuple(identity, Val{NA})
    return flatten(Val{N}, default_labels, As...)
end
"""
    flatten(As::AxisArray...) -> AxisArray
    flatten(last_dim::Type{Val{N}}, As::AxisArray...) -> AxisArray
    flatten(last_dim::Type{Val{N}}, labels::Tuple, As::AxisArray...) -> AxisArray

Concatenates AxisArrays with N equal leading axes into a single AxisArray.
All additional axes in any of the arrays are flattened into a single additional
CategoricalVector{Tuple} axis.

### Arguments

* `::Type{Val{N}}`: the greatest common dimension to share between all input
                    arrays. The remaining axes are flattened. All N axes must be common
                    to each input array, at the same dimension. Values from 0 up to the
                    minimum number of dimensions across all input arrays are allowed.
* `labels::Tuple`: (optional) a label for each AxisArray in As which is used in the flat
                   axis
* `As::AxisArray...`: AxisArrays to be flattened together.

### Throws

* `ArgumentError` if `N` is negative, if `N` exceeds the minimum number of
  dimensions across the input arrays, if the leading axes do not share the
  same names, or if the leading axes do not hold identical values.
"""
@generated function flatten(
    ::Type{Val{N}}, labels::NTuple{AN, LType}, As::Vararg{AxisArray, AN}
) where {N, AN, LType}
    # Generator stage: `As` holds the argument *types*, so everything computed
    # before the `quote` is derived from type information only.
    if N < 0
        throw(ArgumentError("flatten dimension N must be at least 0"))
    end

    if N > minimum(ndims.(As))
        throw(ArgumentError(
            """
            flatten dimension N must not be greater than the minimum number of dimensions
            across all input arrays
            """
        ))
    end

    # The flattened axis becomes dimension N + 1 of the result.
    flat_dim = Val{N + 1}
    flat_dim_int = Int(N) + 1

    # Split every array's axis types into the N leading ones and the rest.
    common_axes, trailing_axes = zip(_splitall(Val{N}, axisparams.(As)...)...)

    # Leading axes must carry the same name in every array.
    foreach(_check_common_axes, zip(common_axes...))

    new_common_axes = first(common_axes)
    flat_axis_eltype = _flat_axis_eltype(LType, trailing_axes)
    flat_axis_type = CategoricalVector{flat_axis_eltype, Vector{flat_axis_eltype}}

    new_axes_type = Tuple{new_common_axes..., Axis{:flat, flat_axis_type}}
    new_eltype = Base.promote_eltype(As...)

    quote
        common_axes, trailing_axes = zip(_splitall(Val{N}, axes.(As)...)...)

        # Runtime check: the leading axes must also hold identical values.
        for common_axis_tuple in zip(common_axes...)
            if !isempty(common_axis_tuple)
                for common_axis in common_axis_tuple[2:end]
                    if !all(axisvalues(common_axis) .== axisvalues(common_axis_tuple[1]))
                        throw(ArgumentError(
                            """
                            Leading common axes must be identical across
                            all input arrays"""
                        ))
                    end
                end
            end
        end

        # Reshape every array to N + 1 dimensions and concatenate along the last.
        array_data = cat($flat_dim, _reshapeall($flat_dim, As...)...)

        axis_array_type = AxisArray{
            $new_eltype,
            $flat_dim_int,
            Array{$new_eltype, $flat_dim_int},
            $new_axes_type
        }

        new_axes = (
            first(common_axes)...,
            Axis{:flat, $flat_axis_type}($flat_axis_type(_flatten_axes(labels, trailing_axes))),
        )

        return axis_array_type(array_data, new_axes)
    end
end