This repository was archived by the owner on Aug 17, 2021. It is now read-only.
forked from JuliaArrays/AxisArrays.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathindexing.jl
More file actions
314 lines (277 loc) · 14.2 KB
/
indexing.jl
File metadata and controls
314 lines (277 loc) · 14.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
# Index types that the methods in this file forward to the wrapped data array
# without translation: real numbers, Colon, and arrays of Int.
const Idx = Union{Real,Colon,AbstractArray{Int}}
using Base: ViewIndex, @propagate_inbounds, tail
# A value bundled with a tolerance; both fields share the same type `T`.
immutable Value{T}
    val::T
    tol::T
end

# Promote `x` and `tol` to a common type before wrapping them. When `tol` is
# omitted it is Base's default relative tolerance for `typeof(x)` times `abs(x)`.
function Value(x, tol=Base.rtoldefault(typeof(x))*abs(x))
    promoted = promote(x, tol)
    return Value(promoted...)
end

# Build a `Value` whose tolerance is `atol + rtol*abs(x)`.
function atvalue(x; rtol=Base.rtoldefault(typeof(x)), atol=zero(x))
    tolerance = atol + rtol*abs(x)
    return Value(x, tolerance)
end
# A Value has to support the iteration protocol so that it can be carried as the
# index of a BoundsError (these methods would come for free if Value <: Number).
# It iterates as a single element, itself: the state is a Bool that is `false`
# before that element has been produced and `true` afterwards.
function Base.start(::Value)
    return false
end
function Base.next(x::Value, state)
    return (x, true)
end
function Base.done(x::Value, state)
    return state
end
# Text form of a Value (seen, for example, inside a BoundsError message)
function Base.show(io::IO, v::Value)
    print(io, "Value($(v.val), tol=$(v.tol))")
end
# Defer IndexStyle to the wrapped array: the IndexStyle of an AxisArray type is
# the IndexStyle of its third type parameter `D`.
# NOTE(review): `D` is presumably the type of the wrapped data array — confirm
# against the AxisArray type definition.
# NOTE(review): `@compat` is presumably needed so this definition also works on
# Julia versions that predate `IndexStyle`; keep the one-line form unless the
# macro is known to handle other forms.
@compat Base.IndexStyle{T,N,D,Ax}(::Type{AxisArray{T,N,D,Ax}}) = IndexStyle(D)
# Scalar indexing with plain integers: read from / write to the wrapped data
# directly, without constructing any axes.
@propagate_inbounds function Base.getindex(A::AxisArray, idxs::Int...)
    return A.data[idxs...]
end
@propagate_inbounds function Base.setindex!(A::AxisArray, v, idxs::Int...)
    return (A.data[idxs...] = v)
end
# Cartesian iteration: use the index iterator of the wrapped data
function Base.eachindex(A::AxisArray)
    return eachindex(A.data)
end
"""
    reaxis(A::AxisArray, I...)

Internal helper: compute the tuple of axes that results from indexing `A` with
the indices `I`.
"""
function reaxis(A::AxisArray, I::Idx...)
    # First line the axes up one-to-one with the indexing dimensions ...
    matched = make_axes_match(axes(A), I)
    # ... then let every index decide what becomes of its axis.
    return _reaxis(matched, I)
end
# Make the number of axes agree with the number of indexing dimensions
@inline function make_axes_match(axs, idxs)
    return _make_axes_match((), axs, Base.index_ndims(idxs...))
end

# Move one axis at a time into the accumulator `acc`, consuming one indexing
# dimension per axis, until both tuples run out together.
@inline function _make_axes_match(acc, axs::Tuple, nidxs::Tuple)
    return _make_axes_match((acc..., axs[1]), tail(axs), tail(nidxs))
end
@inline function _make_axes_match(acc, axs::Tuple{}, nidxs::Tuple{})
    return acc
end

# Indexing dimensions ran out while axes remain: the last collected axis and all
# remaining ones are replaced by a single default-named axis over the linear span.
@inline function _make_axes_match(acc, axs::Tuple, nidxs::Tuple{})
    kept = maybefront(acc)
    AxisType = _nextaxistype(acc)
    span = length(acc[end]) * prod(map(length, axs))
    return (kept..., AxisType(Base.OneTo(span)))
end

# Axes ran out while indexing dimensions remain: append phony singleton axes
@inline function _make_axes_match(acc, axs::Tuple{}, nidxs::Tuple)
    singleton = _nextaxistype(acc)(Base.OneTo(1))
    return _make_axes_match((acc..., singleton), (), tail(nidxs))
end

# `Base.front` that tolerates the empty tuple
@inline function maybefront(::Tuple{})
    return ()
end
@inline function maybefront(t::Tuple)
    return Base.front(t)
end
# From here on axes and indices correspond one-to-one, so both tuples can be
# walked in lockstep. Nothing left to process:
@inline function _reaxis(axs::Tuple{}, idxs::Tuple{})
    return ()
end
# A scalar index drops its dimension
const ScalarIndex = @compat Union{Real, AbstractArray{<:Any, 0}}
@inline function _reaxis(axs::Tuple, idxs::Tuple{ScalarIndex, Vararg{Any}})
    return _reaxis(tail(axs), tail(idxs))
end
# A Colon keeps its axis unchanged
@inline function _reaxis(axs::Tuple, idxs::Tuple{Colon, Vararg{Any}})
    head = axs[1]
    remaining = _reaxis(tail(axs), tail(idxs))
    return (head, remaining...)
end
# An array index may add or change dimensions, and with them the axis names
@inline function _reaxis(axs::Tuple, idxs::Tuple{AbstractArray, Vararg{Any}})
    leading = _new_axes(axs[1], idxs[1])
    remaining = _reaxis(tail(axs), tail(idxs))
    return (leading..., remaining...)
end
# A vector index produces one axis with the same name, holding the subset of
# the axis values selected by that index
@inline function _new_axes{name}(ax::Axis{name}, idx::AbstractVector)
    subset = ax.val[idx]
    return (Axis{name}(subset),)
end
# An N-dimensional array index produces N axes named `<name>_1` ... `<name>_N`;
# the values of axis `d` are the index array's own indices along dimension `d`.
@generated function _new_axes{name, N}(ax::Axis{name}, idx::@compat(AbstractArray{<:Any,N}))
    perdim = Expr[:($(Axis{Symbol(name, "_", d)})(indices(idx, $d))) for d in 1:N]
    return Expr(:tuple, perdim...)
end
# An AxisArray used as an index produces one axis per dimension of the index:
# the name is `<name>_<index axis name>` and the values are taken from the
# index's own axes.
@generated function _new_axes{name, N}(ax::Axis{name}, idx::@compat(AxisArray{<:Any, N}))
    inner = axisnames(idx)
    perdim = Expr[:($(Axis{Symbol(name, "_", inner[d])})(idx.axes[$d].val)) for d in 1:N]
    return Expr(:tuple, perdim...)
end
# Non-scalar indexing with standard index types: index the wrapped data and wrap
# the result in a new AxisArray carrying the axes computed by `reaxis`.
@propagate_inbounds function Base.getindex(A::AxisArray, idxs::Idx...)
    picked = A.data[idxs...]
    newaxes = reaxis(A, idxs...)
    return AxisArray(picked, newaxes)
end
# To resolve ambiguities, we need several definitions.
# Every method below does the same thing: take a view of the wrapped data with
# the given indices and wrap it in an AxisArray whose axes come from `reaxis`.
# Only the signatures differ between the two VERSION branches.
if VERSION >= v"0.6.0-dev.672"
# AbstractCartesianIndex is not used by the method in this branch; it is
# referenced by the `Union{ViewIndex,AbstractCartesianIndex}` view method
# defined further down in this file.
using Base.AbstractCartesianIndex
# On this branch a single vararg method covers any number of indices.
@propagate_inbounds Base.view(A::AxisArray, idxs::Idx...) = AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
else
# Older Julia: separate methods for exactly N indices on an N-dimensional
# array, for a single index, and for any other number of indices.
@propagate_inbounds function Base.view{T,N}(A::AxisArray{T,N}, idxs::Vararg{Idx,N})
AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
end
@propagate_inbounds function Base.view(A::AxisArray, idx::Idx)
AxisArray(view(A.data, idx), reaxis(A, idx))
end
@propagate_inbounds function Base.view{N}(A::AxisArray, idxs::Vararg{Idx,N})
# this should eventually be deleted, see julia #14770
AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
end
end
# Assignment needs no axis bookkeeping: write straight through to the data
@propagate_inbounds function Base.setindex!(A::AxisArray, v, idxs::Idx...)
    return (A.data[idxs...] = v)
end
### Fancier indexing capabilities provided only by AxisArrays ###
# Indices of any other type are first translated by `to_index`; the translated
# indices are then used to index `A` again.
@propagate_inbounds function Base.getindex(A::AxisArray, idxs...)
    translated = to_index(A, idxs...)
    return A[translated...]
end
@propagate_inbounds function Base.setindex!(A::AxisArray, v, idxs...)
    translated = to_index(A, idxs...)
    return (A[translated...] = v)
end
# Deal with lots of ambiguities here.
# Every method translates its indices with `to_index` and then calls `view`
# again with the translated indices; only the signatures differ.
if VERSION >= v"0.6.0-dev.672"
    @propagate_inbounds Base.view(A::AxisArray, idxs::ViewIndex...) = view(A, to_index(A,idxs...)...)
    # Cartesian indices are flattened into their components before translation
    @propagate_inbounds Base.view(A::AxisArray, idxs::Union{ViewIndex,AbstractCartesianIndex}...) = view(A, to_index(A,Base.IteratorsMD.flatten(idxs)...)...)
    @propagate_inbounds Base.view(A::AxisArray, idxs...) = view(A, to_index(A,idxs...)...)
else
    # Older Julia: generate the same three signatures once for ViewIndex and
    # once for Any.
    for T in (:ViewIndex, :Any)
        @eval begin
            @propagate_inbounds function Base.view{T,N}(A::AxisArray{T,N}, idxs::Vararg{$T,N})
                view(A, to_index(A,idxs...)...)
            end
            @propagate_inbounds function Base.view(A::AxisArray, idx::$T)
                view(A, to_index(A,idx)...)
            end
            # The argument must be named `idxs` because the body refers to it
            # by that name (it was previously misspelled, which made this
            # method fail with an UndefVarError when called).
            @propagate_inbounds function Base.view{N}(A::AxisArray, idxs::Vararg{$T,N})
                # this should eventually be deleted, see julia #14770
                view(A, to_index(A,idxs...)...)
            end
        end
    end
end
# Indexing by named axis: put each Axis argument's value in the slot of the
# dimension it names, then re-dispatch on the resulting positional indices.
# Dimensions that are not mentioned get a Colon, so their shape is preserved.
# TODO: should we handle multidimensional Axis indexes? It could be interpreted
#       as adding dimensions in the middle of an AxisArray.
# TODO: should we allow repeated axes? As a union of indices of the duplicates?
@generated function to_index{T,N,D,Ax}(A::AxisArray{T,N,D,Ax}, I::Axis...)
    targetdims = Int[axisdim(A, ax) for ax in I]
    slots = Expr[:(Colon()) for dim in 1:N]
    axnames = axisnames(A)
    for (k, d) in enumerate(targetdims)
        # A slot that is no longer a Colon was already claimed by an earlier
        # argument; the generated code then throws instead of indexing.
        if slots[d] != :(Colon())
            return :(throw(ArgumentError(string("multiple indices provided ",
                                                "on axis ", $(string(axnames[d]))))))
        end
        slots[d] = :(I[$k].val)
    end
    inlinemeta = Expr(:meta, :inline)
    return :($inlinemeta; to_index(A, $(slots...)))
end
# Reshape the wrapped data to N dimensions; the axes of the result are those
# obtained from `reaxis` when indexing `A` with N Colons.
# (An alternative that split off the first N axes with Base.IteratorsMD.split
# and Base.fill_to_length was left disabled here.)
function Base.reshape{N}(A::AxisArray, ::Type{Val{N}})
    reshaped = reshape(A.data, Val{N})
    colons = ntuple(d->Colon(), Val{N})
    return AxisArray(reshaped, reaxis(A, colons...))
end
### Indexing along values of the axes ###

# By default, indexing an axis by value throws an error
"""
    axisindexes(ax::Axis, axis_idx) -> array_idx
    axisindexes(::Type{<:AxisTrait}, axis_values, axis_idx) -> array_idx

Convert an index expressed in terms of an axis into the corresponding index
into the underlying array.

To support custom axes or custom index types, add methods to this function.

## Examples

Add a method for indexing into an `Axis{name, SortedSet}`:

```julia
AxisArrays.axisindexes(::Type{Categorical}, ax::SortedSet, idx::AbstractVector) = findin(collect(ax), idx)
```

Add a method for indexing into a `Categorical` axis with a `SortedSet`:

```julia
AxisArrays.axisindexes(::Type{Categorical}, ax::AbstractVector, idx::SortedSet) = findin(ax, idx)
```
"""
function axisindexes(ax, idx)
    # Dispatch on the trait of the axis values rather than on the Axis itself
    vals = ax.val
    return axisindexes(axistrait(vals), vals, idx)
end

# Axes whose trait is Unsupported cannot be indexed by element
function axisindexes(::Type{Unsupported}, ax, idx)
    error("elementwise indexing is not supported for axes of type $(typeof(ax))")
end

# Fallback for every trait / index combination without a more specific method
function axisindexes(t, ax, idx)
    error("cannot index $(typeof(ax)) with $(typeof(idx)); expected $(eltype(ax)) axis value or Integer index")
end
# Dimensional axes may be indexed directly by their elements only when those
# are not Real and are unique; a Real reaching this method is rejected with a
# pointer to `atvalue`.
# Maybe extend error message to all <: Numbers if Base allows it?
function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::Real)
    throw(ArgumentError("invalid index: $idx. Use `atvalue` when indexing by value."))
end
# Index a Dimensional axis by one of its (non-Real) element values.
#
# Returns the position of the single element of `ax` equal to `idx`.
# Throws an error if several elements match, and a BoundsError naming the axis
# and the requested value if none does.
function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx)
    idxs = searchsorted(ax, ClosedInterval(idx,idx))
    length(idxs) > 1 && error("more than one datapoint lies on axis value $idx; use an interval to return all values")
    # With no match `idxs` is empty; indexing it would raise a BoundsError about
    # this internal range rather than about the axis. The value is wrapped in a
    # tuple because an arbitrary index type need not be iterable, which a
    # BoundsError index is expected to be.
    isempty(idxs) && throw(BoundsError(ax, (idx,)))
    return idxs[1]
end
# Wrapped in a Value, any axis value may be used to index a Dimensional axis.
# An exact match wins; otherwise the neighbour just below the insertion point is
# tried first, then the one just above, and either is accepted only if it lies
# strictly within `idx.tol` of `idx.val`.
function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::Value)
    hits = searchsorted(ax, ClosedInterval(idx.val,idx.val))
    nhits = length(hits)
    if nhits > 1
        error("more than one datapoint lies on axis value $idx; use an interval to return all values")
    elseif nhits == 1
        return hits[1]
    end
    # No exact match: `hits` is an empty range whose ends mark the neighbours
    below, above = last(hits), first(hits)
    if below > 0 && abs(ax[below] - idx.val) < idx.tol
        return below
    end
    if above <= length(ax) && abs(ax[above] - idx.val) < idx.tol
        return above
    end
    throw(BoundsError(ax, idx))
end
# An interval index on a Dimensional axis selects a range of positions
function axisindexes{T}(::Type{Dimensional}, ax::AbstractVector{T}, idx::ClosedInterval)
    return searchsorted(ax, idx)
end
# Repeated intervals are supported only when the axis is a range; otherwise each
# repetition could select a different number of indices.
# Two small tricks are used here:
# * The interval axis is computed with unsafe indexing and without any offset.
#   - This is possible because the axis is a range, and it makes the resulting
#     axis useful.
# * The offsets are snapped to the nearest datapoint to avoid fencepost problems.
# The result gains a dimension; rows represent the interval and columns are offsets.
function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::RepeatedInterval)
    error("repeated intervals might select a varying number of elements for non-range axes; use a repeated Range of indices instead")
end
function axisindexes(::Type{Dimensional}, ax::Range, idx::RepeatedInterval)
    noffsets = length(idx.offsets)
    window = unsafe_searchsorted(ax, idx.window)
    centers = [searchsortednearest(ax, idx.offsets[k]) for k in 1:noffsets]
    indexmatrix = RepeatedRangeMatrix(window, centers)
    subaxis = Axis{:sub}(inbounds_getindex(ax, window))
    repaxis = Axis{:rep}(ax[centers])
    return AxisArray(indexmatrix, subaxis, repaxis)
end
# Intervals about an index have their own datatype. For a general axis the
# window is shifted by the axis value at that index and then searched for.
function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::IntervalAtIndex)
    shifted = idx.window + ax[idx.index]
    return searchsorted(ax, shifted)
end
# For a range axis the window is located without an offset and the resulting
# positions are shifted by the index instead.
function axisindexes(::Type{Dimensional}, ax::Range, idx::IntervalAtIndex)
    window = unsafe_searchsorted(ax, idx.window)
    positions = window + idx.index
    subaxis = Axis{:sub}(inbounds_getindex(ax, window))
    return AxisArray(positions, subaxis)
end
# Repeated intervals about indices are rejected for non-range axes
axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::RepeatedIntervalAtIndexes) = error("repeated intervals might select a varying number of elements for non-range axes; use a repeated Range of indices instead")
# For a range axis: locate the window once (without offset) and repeat it at
# each of `idx.indexes`. The result is an AxisArray of indices with a `:sub`
# axis holding the window's axis values and a `:rep` axis holding the axis
# values at the given indexes.
function axisindexes(::Type{Dimensional}, ax::Range, idx::RepeatedIntervalAtIndexes)
    idxs = unsafe_searchsorted(ax, idx.window)
    return AxisArray(RepeatedRangeMatrix(idxs, idx.indexes), Axis{:sub}(inbounds_getindex(ax, idxs)), Axis{:rep}(ax[idx.indexes]))
end
# A Categorical axis may be indexed by one of its elements: the position of the
# first occurrence is returned, and a missing element is an ArgumentError.
function axisindexes(::Type{Categorical}, ax::AbstractVector, idx)
    pos = findfirst(ax, idx)
    if pos == 0
        throw(ArgumentError("index $idx not found"))
    end
    return pos
end
# A Categorical axis may also be indexed by a vector of its elements. The
# lookup is rejected unless it finds exactly as many positions as there are
# requested elements.
function axisindexes(::Type{Categorical}, ax::AbstractVector, idx::AbstractVector)
    found = findin(ax, idx)
    if length(found) != length(idx)
        throw(ArgumentError("index $(setdiff(idx,ax)) not found"))
    end
    return found
end
# Catch-all translation of non-standard, axis-specific index types into their
# integer or integer-range equivalents by way of `axisindexes`.
# It lives outside `Base.getindex` so that get- and set-index can share it.
#
# The generated code is a tuple with one entry per resulting index, padded with
# Colons up to the N dimensions of the array.
@generated function to_index{T,N,D,Ax}(A::AxisArray{T,N,D,Ax}, I...)
    naxes = length(Ax.parameters)
    out = Expr(:tuple)
    consumed = 0   # number of index positions produced so far
    for (i, Ti) in enumerate(I)
        hasaxis = i <= naxes
        # Translation of argument `i` through the axis at position `i`; an Axis
        # argument contributes its wrapped value.
        lookup = Ti <: Axis ? :(axisindexes(A.axes[$i], I[$i].val)) :
                              :(axisindexes(A.axes[$i], I[$i]))
        if axistrait(Ti) <: Categorical && hasaxis
            # Checked before the standard index types below
            push!(out.args, lookup)
            consumed += 1
        elseif Ti <: Idx
            # Standard indices pass through untouched
            push!(out.args, :(I[$i]))
            consumed += 1
        elseif Ti <: AbstractArray{Bool}
            # Logical masks become the positions of their true entries
            push!(out.args, :(find(I[$i])))
            consumed += 1
        elseif Ti <: CartesianIndex
            # A CartesianIndex is spread over one position per component
            for j in 1:length(Ti)
                push!(out.args, :(I[$i][$j]))
            end
            consumed += length(Ti)
        elseif hasaxis
            push!(out.args, lookup)
            consumed += 1
        else
            push!(out.args, :(error("dimension ", $i, " does not have an axis to index")))
        end
    end
    # Any dimension that did not receive an index is taken whole
    for _ in (consumed + 1):N
        push!(out.args, :(Colon()))
    end
    inlinemeta = Expr(:meta, :inline)
    return :($inlinemeta; $out)
end
## Extracting the full axis (name + values) from the Axis{:name} type
@inline function Base.getindex{Ax<:Axis}(A::AxisArray, ::Type{Ax})
    return getaxis(Ax, axes(A)...)
end
# Walk the axes from left to right: return the first one that is an instance of
# the requested type, skip any other axis, and fail once none are left.
@inline function getaxis{Ax<:Axis}(::Type{Ax}, ax::Ax, axs...)
    return ax
end
@inline function getaxis{Ax<:Axis}(::Type{Ax}, ax::Axis, axs...)
    return getaxis(Ax, axs...)
end
@noinline function getaxis{Ax<:Axis}(::Type{Ax})
    throw(ArgumentError("no axis of type $Ax was found"))
end
# Bounds checking with an AxisArray used as an index: check its wrapped data.
# Unwrapping AxisArrays inside Base's to_index/to_indices whenever they index
# other arrays would also be possible, but that requires a bigger refactor to
# merge our to_index method with Base's.
@inline function Base.checkindex(::Type{Bool}, inds::AbstractUnitRange, A::AxisArray)
    return Base.checkindex(Bool, inds, A.data)
end