-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathindexing.jl
More file actions
480 lines (420 loc) · 18.7 KB
/
indexing.jl
File metadata and controls
480 lines (420 loc) · 18.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
# known_lastindex(::Type{T}) / known_lastindex(x)
#
# Return the last index of `T` (or of `typeof(x)`) when it can be determined from the type
# alone, i.e. when neither `known_offset1(T)` nor `known_length(T)` is `nothing`.
# Otherwise return `nothing`.
function known_lastindex(::Type{T}) where {T}
    first_index = known_offset1(T)
    len = known_length(T)
    if first_index === nothing || len === nothing
        return nothing
    else
        # `len` consecutive indices that start at `first_index` end at
        # `first_index + len - 1` (e.g. offset 0 and length 4 give a last index of 3).
        return first_index + len - 1
    end
end
known_lastindex(@nospecialize x) = known_lastindex(typeof(x))

# `lastindex(x)`, returned through `Static.maybe_static` so that it is presumably static
# whenever `known_lastindex(x)` is not `nothing` — see `Static.maybe_static`.
@inline static_lastindex(x) = Static.maybe_static(known_lastindex, lastindex, x)
# `first(x, n)` for a static count `n`: the leading `n` elements of `x`, or all of `x`
# when it holds fewer than `n` elements. With bounds checking enabled, a negative `n`
# throws an `ArgumentError`.
function Base.first(x::AbstractVector, n::StaticInt)
    @boundscheck if n < 0
        throw(ArgumentError("Number of elements must be nonnegative"))
    end
    lo = offset1(x)
    hi = min((lo - one(lo)) + n, static_lastindex(x))
    return @inbounds x[lo:hi]
end
# `last(x, n)` for a static count `n`: the trailing `n` elements of `x`, or all of `x`
# when it holds fewer than `n` elements. With bounds checking enabled, a negative `n`
# throws an `ArgumentError`.
function Base.last(x::AbstractVector, n::StaticInt)
    @boundscheck if n < 0
        throw(ArgumentError("Number of elements must be nonnegative"))
    end
    hi = static_lastindex(x)
    lo = max(offset1(x), (hi + one(hi)) - n)
    return @inbounds x[lo:hi]
end
"""
    ArrayInterface.to_indices(A, I::Tuple) -> Tuple

Convert the tuple of indexing arguments, `I`, into an appropriate form for indexing into `A`.
Typically, each resulting index is an `Int`, a `StaticInt`, a collection of `Int`s, or a
collection of `CartesianIndex`es.

Each argument in `I` is paired with the dimension(s) of `A` that it indexes, converted with
`ArrayInterface.to_index`, and the converted results are flattened into a single tuple.

# Extended help

This implementation differs from that of `Base.to_indices` in the following ways:

* `to_indices(A, I)` never results in recursive processing of `I` through
  `to_indices(A, axes(A), I)`. Instead, calls to `to_index` are aligned with the dimensions
  of `A` based on the values returned by `ndims_index`. This is beneficial because the
  compiler currently does not optimize away the time spent recursing through each
  additional argument that needs converting. For example:

```julia
julia> x = rand(4,4,4,4,4,4,4,4,4,4);

julia> inds1 = (1, 2, 1, 2, 1, 2, 1, 2, 1, 2);

julia> inds2 = (1, CartesianIndex(1, 2), 1, CartesianIndex(1, 2), 1, CartesianIndex(1, 2), 1);

julia> inds3 = (fill(true, 4, 4), 2, fill(true, 4, 4), 2, 1, fill(true, 4, 4), 1);

julia> @btime Base.to_indices(\$x, \$inds2)
  1.105 μs (12 allocations: 672 bytes)
(1, 1, 2, 1, 1, 2, 1, 1, 2, 1)

julia> @btime ArrayInterface.to_indices(\$x, \$inds2)
  0.041 ns (0 allocations: 0 bytes)
(1, 1, 2, 1, 1, 2, 1, 1, 2, 1)

julia> @btime Base.to_indices(\$x, \$inds3);
  340.629 ns (14 allocations: 768 bytes)

julia> @btime ArrayInterface.to_indices(\$x, \$inds3);
  11.614 ns (0 allocations: 0 bytes)
```

* Recursing through `to_indices(A, axes, I::Tuple{I1,Vararg{Any}})` is intended to provide
  context for processing `I1`. However, this doesn't tell us how many dimensions are
  consumed by what is in `Vararg{Any}`. Using `ndims_index` to directly align the axes of
  `A` with each value in `I` ensures that a `CartesianIndex{3}` at the tail of `I` isn't
  incorrectly assumed to consume only one dimension.
* `Base.to_indices` may fail to infer the returned type. This is the case for `inds2` and
  `inds3` in the first bullet on Julia 1.6.4.
* Specializing by dispatch through method definitions like
  `to_indices(::ArrayType, ::Tuple{AxisType,Vararg{Any}}, ::Tuple{IndexType,Vararg{Any}})`
  requires an excessive number of hand-written methods to avoid ambiguities. Furthermore,
  if `AxisType` wraps another axis that should have unique behavior, then unique parametric
  types need to be explicitly defined as well.
* `to_index(axes(A, dim), index)` is called, as opposed to `Base.to_index(A, index)`. The
  `IndexStyle` of the resulting axis is used to allow indirect dispatch on nested axis types
  within `to_index`.
"""
to_indices(A, ::Tuple{}) = ()
@inline function to_indices(a::A, inds::I) where {A,I}
    # For each argument, the dimension spec of `a` that `IndexedMappedArray` converts it against.
    dimsmap = getfield(_init_dimsmap(IndicesInfo{ndims(A)}(I)), 1)
    converted = map(IndexedMappedArray(a), inds, dimsmap)
    return flatten_tuples(converted)
end
# Callable that converts one indexing argument for the array `a`. `to_indices` maps it over
# the indexing arguments together with a dimension spec for each one; the spec selects the
# axis (or axes) of `a` that is passed to `to_index`.
struct IndexedMappedArray{A}
    a::A
end
# dimension spec `static(0)`: convert against the placeholder axis `static(1):static(1)`
@inline (m::IndexedMappedArray{A})(idx::I, ::StaticInt{0}) where {A,I} =
    to_index(StaticInt(1):StaticInt(1), idx)
# dimension spec `:`: convert against `lazy_axes(a, :)`
@inline (m::IndexedMappedArray{A})(idx::I, ::Colon) where {A,I} =
    to_index(lazy_axes(m.a, :), idx)
# a single static dimension `d`: convert against that axis of `a`
@inline (m::IndexedMappedArray{A})(idx::I, d::StaticInt{D}) where {A,I,D} =
    to_index(lazy_axes(m.a, d), idx)
# Boolean mask over several dimensions: when the mask reaches the last dimension of an
# `IndexLinear` array, build a `LogicalIndex` with `Int` elements.
@inline function (m::IndexedMappedArray{A})(idx::AbstractArray{Bool}, dims::Tuple) where {A}
    use_linear = last(dims) == ndims(A) && IndexStyle(A) isa IndexLinear
    return use_linear ? LogicalIndex{Int}(idx) : LogicalIndex(idx)
end
# a `CartesianIndex` over several dimensions is unpacked into its tuple of integers
@inline (m::IndexedMappedArray{A})(idx::CartesianIndex, ::Tuple) where {A} = getfield(idx, 1)
# any other index over several dimensions is converted against the `CartesianIndices`
# built from those axes
@inline (m::IndexedMappedArray{A})(idx::I, dims::Tuple) where {A,I} =
    to_index(CartesianIndices(lazy_axes(m.a, dims)), idx)
"""
    ArrayInterface.to_index([::IndexStyle, ]axis, arg) -> index

Convert the argument `arg` that was originally passed to `ArrayInterface.getindex` for the
dimension corresponding to `axis` into a form for native indexing (`Int`, `Vector{Int}`, etc.).

`ArrayInterface.to_index` supports passing a function as an index. This function-index is
transformed into a proper index.

```julia
julia> using ArrayInterface, Static

julia> ArrayInterface.to_index(static(1):static(10), 5)
5

julia> ArrayInterface.to_index(static(1):static(10), <(5))
static(1):4

julia> ArrayInterface.to_index(static(1):static(10), <=(5))
static(1):5

julia> ArrayInterface.to_index(static(1):static(10), >(5))
6:static(10)

julia> ArrayInterface.to_index(static(1):static(10), >=(5))
5:static(10)
```

Using a function-index helps ensure that indices are in bounds, because the resulting range
is clamped to the first and last elements of `axis`:

```julia
julia> ArrayInterface.to_index(static(1):static(10), <(12))
static(1):10

julia> ArrayInterface.to_index(static(1):static(10), >(-1))
1:static(10)
```

New axis types with unique behavior should use an `IndexStyle` trait:

```julia
to_index(axis::MyAxisType, arg) = to_index(IndexStyle(axis), axis, arg)
to_index(::MyIndexStyle, axis, arg) = ...
```
"""
# a `Slice` is already resolved and is passed through unchanged
to_index(x, i::Slice) = i
# `:` is resolved to `indices(x)`
to_index(x, ::Colon) = indices(x)
# zero-dimensional linear/cartesian indices: `:` becomes the single-element slice 1:1
to_index(::LinearIndices{0,Tuple{}}, ::Colon) = Slice(static(1):static(1))
to_index(::CartesianIndices{0,Tuple{}}, ::Colon) = Slice(static(1):static(1))
# logical indexing
# (masks are wrapped in `LogicalIndex`; against `LinearIndices` the element type is `Int`)
to_index(x, i::AbstractArray{Bool}) = LogicalIndex(i)
to_index(::LinearIndices, i::AbstractArray{Bool}) = LogicalIndex{Int}(i)
# cartesian indexing
# (`CartesianIndices` become their tuple of ranges; `CartesianIndex`/`NDIndex` become their
# tuple of integers; arrays of cartesian indices pass through)
@inline to_index(x, i::CartesianIndices{0}) = i
@inline to_index(x, i::CartesianIndices) = getfield(i, :indices)
@inline to_index(x, i::CartesianIndex) = getfield(i, 1)
@inline to_index(x, i::NDIndex) = getfield(i, 1)
@inline to_index(x, i::AbstractArray{<:AbstractCartesianIndex}) = i
# function-indices with an integer threshold `n`; every result is clamped to the axis
# `<(n)` / `isless(n)`: from the first index up to `n - 1`
@inline function to_index(x, i::Base.Fix2{<:Union{typeof(<),typeof(isless)},<:Union{Base.BitInteger,StaticInt}})
    static_first(x):min(_sub1(canonicalize(i.x)), static_last(x))
end
# `<=(n)`: from the first index up to `n`
@inline function to_index(x, i::Base.Fix2{typeof(<=),<:Union{Base.BitInteger,StaticInt}})
    static_first(x):min(canonicalize(i.x), static_last(x))
end
# `>=(n)`: from `n` up to the last index
@inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}})
    max(canonicalize(i.x), static_first(x)):static_last(x)
end
# comparison against an `IndexLabel`: all positions in the first set of `index_labels(x)`
# whose label satisfies the comparison (e.g. `<(label)` selects labels less than `label`)
@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel})
    findall(i.f(i.x.label), first(index_labels(x)))
end
# `>(n)`: from `n + 1` up to the last index
@inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}})
    max(_add1(canonicalize(i.x)), static_first(x)):static_last(x)
end
# integer indexing: integer collections and `StaticInt` pass through, other integers
# are normalized to `Int`
to_index(x, i::AbstractArray{<:Union{Base.BitInteger,StaticInt}}) = i
to_index(x, @nospecialize(i::StaticInt)) = i
to_index(x, i::Integer) = Int(i)
# anything else is dispatched on the `IndexStyle` of the axis
@inline to_index(x, i) = to_index(IndexStyle(x), x, i)
# key indexing
# A label that is not found maps to `offset1(x) - 1`, one before the first index, so that
# the error is raised later by the bounds check rather than here.
function to_index(x, i::IndexLabel)
    label = getfield(i, :label)
    position = findfirst(==(label), first(index_labels(x)))
    if position === nothing
        return offset1(x) - 1
    end
    return position
end
function to_index(x, i::Union{Symbol,AbstractString,AbstractChar,Number})
    position = findfirst(==(i), getfield(index_labels(x), 1))
    if position === nothing
        return offset1(x) - 1
    end
    return position
end
# Collections of labels are converted element by element.
# TODO there's probably a more efficient way of doing this
function to_index(x, ks::AbstractArray{<:IndexLabel})
    return [to_index(x, key) for key in ks]
end
to_index(x, ks::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar,Number}}) =
    [to_index(x, key) for key in ks]
# Fallback reached through `to_index(IndexStyle(x), x, i)`: no conversion exists for this
# index style / axis / index combination.
function to_index(S::IndexStyle, x, i)
    msg = "invalid index: $S does not support indices of type $(typeof(i)) for instances of type $(typeof(x))."
    throw(ArgumentError(msg))
end
"""
    unsafe_reconstruct(A, data; kwargs...)

Reconstruct `A` given the values in `data`. New methods using `unsafe_reconstruct`
should only dispatch on `A`.

The range methods defined here ignore `kwargs`. For `OneTo`, `UnitRange`, and
`OptionallyStaticUnitRange`, `A` itself is returned when it is identical (`===`) to `data`;
otherwise a range of the same type is built from `data`. Any other `AbstractUnitRange` is
rebuilt as `static_first(data):static_last(data)`.
"""
unsafe_reconstruct(axis::OneTo, data; kwargs...) = axis === data ? axis : OneTo(data)
unsafe_reconstruct(axis::UnitRange, data; kwargs...) =
    axis === data ? axis : UnitRange(first(data), last(data))
unsafe_reconstruct(axis::OptionallyStaticUnitRange, data; kwargs...) =
    axis === data ? axis : OptionallyStaticUnitRange(static_first(data), static_last(data))
unsafe_reconstruct(A::AbstractUnitRange, data; kwargs...) = static_first(data):static_last(data)
"""
    to_axes(A, inds) -> Tuple

Construct new axes given the corresponding `inds` constructed after
`to_indices(A, args) -> inds`. This method iterates through each pair of axes and
indices calling [`to_axis`](@ref).

When more than one index is given for a multidimensional `A`, scalar (`CanonicalInt`)
indices drop their dimension, so the result may have fewer axes than `A`.
"""
@inline function to_axes(A, inds::Tuple)
    if ndims(A) === 1
        # vector: pair its only axis with the first index
        return (to_axis(axes(A, 1), first(inds)),)
    elseif Base.length(inds) === 1
        # a single index into a multidimensional array is resolved against the
        # linear indices of `A`
        return (to_axis(eachindex(IndexLinear(), A), first(inds)),)
    else
        return to_axes(A, axes(A), inds)
    end
end
# drop this dimension
# (a scalar `CanonicalInt` index produces no axis; the matching axis of `A` is skipped)
to_axes(A, a::Tuple, i::Tuple{<:CanonicalInt,Vararg{Any}}) = to_axes(A, _maybe_tail(a), tail(i))
# any other index consumes `ndims_index(I)` axes of `A` and produces one new axis
to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(StaticInt(ndims_index(I)), A, a, i)
# index consuming a single axis
function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple)
    return (to_axis(_maybe_first(axs), first(inds)), to_axes(A, _maybe_tail(axs), tail(inds))...)
end
# index consuming `N` axes: those axes are combined into `LinearIndices` or
# `CartesianIndices` (chosen by `IndexStyle(A)`) before calling `to_axis`
@propagate_inbounds function _to_axes(::StaticInt{N}, A, axs::Tuple, inds::Tuple) where {N}
    axes_front, axes_tail = Base.IteratorsMD.split(axs, Val(N))
    if IndexStyle(A) === IndexLinear()
        axis = to_axis(LinearIndices(axes_front), getfield(inds, 1))
    else
        axis = to_axis(CartesianIndices(axes_front), getfield(inds, 1))
    end
    return (axis, to_axes(A, axes_tail, tail(inds))...)
end
# all indices consumed: stop (any remaining axes of `A` contribute nothing)
to_axes(A, ::Tuple{Ax,Vararg{Any}}, ::Tuple{}) where {Ax} = ()
to_axes(A, ::Tuple{}, ::Tuple{}) = ()
# Once the axes of `A` are exhausted, any trailing indices are paired with the
# placeholder axis `static(1):static(1)`.
_maybe_first(::Tuple{}) = static(1):static(1)
_maybe_first(t::Tuple) = first(t)
_maybe_tail(::Tuple{}) = ()
_maybe_tail(t::Tuple) = tail(t)
"""
    to_axis(old_axis, index) -> new_axis

Construct a `new_axis` for a newly constructed array that corresponds to the
previously executed `to_index(old_axis, arg) -> index`. `to_axis` assumes that
`index` has already been confirmed to be in bounds.

If `old_axis` cannot change size and its statically known length equals that of `index`,
`old_axis` is reused unchanged. Otherwise the call is forwarded to
`to_axis(IndexStyle(old_axis), old_axis, index)`; the `IndexLinear` method produces the
one-based axis `static(1):length(index)`.
"""
@inline function to_axis(axis, inds)
    if can_change_size(axis)
        return to_axis(IndexStyle(axis), axis, inds)
    end
    inds_len = known_length(inds)
    reusable = inds_len !== nothing && known_length(axis) === inds_len
    return reusable ? axis : to_axis(IndexStyle(axis), axis, inds)
end
# A `Slice` spans the whole axis, so no length check is needed; resizable axes are
# copied rather than shared.
@inline to_axis(axis, inds::Slice) = can_change_size(axis) ? copy(axis) : axis
to_axis(::IndexLinear, axis, inds) = StaticInt(1):length(inds)
"""
    ArrayInterface.getindex(A, args...)
    ArrayInterface.getindex(A; kwargs...)

Retrieve the value(s) stored at the given key or index within a collection. Creating
another instance of `ArrayInterface.getindex` should only be done by overloading `A`.
Changing indexing based on a given argument from `args` should be done through
[`to_index`](@ref) or [`to_axis`](@ref).

The arguments are converted with `ArrayInterface.to_indices`, bounds-checked (the check can
be elided by `@inbounds` at the call site), and passed to `unsafe_getindex`. The keyword
form selects dimensions by name using `dimnames(A)`.
"""
# `@propagate_inbounds` lets a caller's `@inbounds` elide the `@boundscheck` below,
# consistent with the keyword method and with `setindex!`.
@propagate_inbounds function getindex(A, args...)
    inds = to_indices(A, args)
    @boundscheck checkbounds(A, inds...)
    return unsafe_getindex(A, inds...)
end
@propagate_inbounds function getindex(A; kwargs...)
    # NOTE(review): dimensions not named in `kwargs` are presumably filled with `:`
    # (the last argument of `find_all_dimnames`) — confirm.
    inds = to_indices(A, find_all_dimnames(dimnames(A), static(keys(kwargs)), Tuple(values(kwargs)), :))
    @boundscheck checkbounds(A, inds...)
    return unsafe_getindex(A, inds...)
end
# tuples are indexed positionally via `getfield`
@propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i)
@propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i)
## unsafe_getindex ##
# Generic scalar access. Types that are not handled by a more specific method must be
# forwarding wrappers (`is_forwarding_wrapper`), whose parent then receives the call;
# otherwise a `MethodError` reporting the actual call arguments is thrown.

# no indices: delegate to the parent
function unsafe_getindex(a::A) where {A}
    is_forwarding_wrapper(A) || throw(MethodError(unsafe_getindex, (a,)))
    return unsafe_getindex(parent(a))
end
# TODO Need to manage index transformations between nested layers of arrays
# single integer: `IndexLinear` types forward it to the parent; other types convert it to
# Cartesian indices with `_to_cartesian`
function unsafe_getindex(a::A, i::CanonicalInt) where {A}
    if IndexStyle(A) === IndexLinear()
        is_forwarding_wrapper(A) || throw(MethodError(unsafe_getindex, (a, i)))
        return unsafe_getindex(parent(a), i)
    else
        return unsafe_getindex(a, _to_cartesian(a, i)...)
    end
end
# several integers: `IndexLinear` types convert them to one linear index with `_to_linear`;
# other types forward them to the parent
function unsafe_getindex(a::A, i::CanonicalInt, ii::Vararg{CanonicalInt}) where {A}
    if IndexStyle(A) === IndexLinear()
        return unsafe_getindex(a, _to_linear(a, (i, ii...)))
    else
        is_forwarding_wrapper(A) || throw(MethodError(unsafe_getindex, (a, i, ii...)))
        return unsafe_getindex(parent(a), i, ii...)
    end
end
# Fallback for indices that are not all `CanonicalInt`: gather a collection.
function unsafe_getindex(a, i::Vararg{Any})
    return unsafe_get_collection(a, i)
end
# `Array`: read directly without bounds checks (`Base.arrayref(false, ...)`).
function unsafe_getindex(A::Array)
    return Base.arrayref(false, A, 1)
end
function unsafe_getindex(A::Array, i::CanonicalInt)
    return Base.arrayref(false, A, Int(i))
end
@inline unsafe_getindex(A::Array, i::CanonicalInt, ii::Vararg{CanonicalInt}) =
    unsafe_getindex(A, _to_linear(A, (i, ii...)))
# `LinearIndices`: a linear index is its own value.
function unsafe_getindex(A::LinearIndices, i::CanonicalInt)
    return Int(i)
end
# `CartesianIndices{N}` with exactly `N` integers: build the `CartesianIndex` directly.
function unsafe_getindex(A::CartesianIndices{N}, ii::Vararg{CanonicalInt,N}) where {N}
    return CartesianIndex(ii...)
end
# Any other number of integers: drop the trailing one (`Base.front`) and retry.
function unsafe_getindex(A::CartesianIndices, ii::Vararg{CanonicalInt})
    return unsafe_getindex(A, Base.front(ii)...)
end
unsafe_getindex(A::CartesianIndices, i::CanonicalInt) = @inbounds A[i]
# `ReshapedArray`: index the parent with a linear index.
unsafe_getindex(A::ReshapedArray, i::CanonicalInt) = @inbounds parent(A)[i]
unsafe_getindex(A::ReshapedArray, i::CanonicalInt, ii::Vararg{CanonicalInt}) =
    @inbounds parent(A)[_to_linear(A, (i, ii...))]
# `SubArray`: defer to Base indexing with bounds checks disabled.
unsafe_getindex(A::SubArray, i::CanonicalInt) = @inbounds A[i]
unsafe_getindex(A::SubArray, i::CanonicalInt, ii::Vararg{CanonicalInt}) = @inbounds A[i, ii...]
# This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755.
#=
    unsafe_get_collection(A, inds)

Return a new collection holding the elements of `A` selected by `inds`. `inds` is assumed
to have been bounds-checked already.
=#
function unsafe_get_collection(A, inds)
    new_axes = to_axes(A, inds)
    dest = similar(A, new_axes)
    # `similar` must honor the requested lengths; otherwise `Base.throw_checksize_error` raises.
    map(length, axes(dest)) == map(length, new_axes) || Base.throw_checksize_error(dest, new_axes)
    Base._unsafe_getindex!(dest, A, inds...)
    return dest
end
# Normalize a scalar index to a length-1 range; ranges are returned unchanged.
_ints2range(x::CanonicalInt) = x:x
_ints2range(x::AbstractRange) = x
# apply _ints2range to front N elements
# (elements after the first `N` are dropped)
_ints2range_front(::Val{N}, ind, inds...) where {N} =
    (_ints2range(ind), _ints2range_front(Val(N - 1), inds...)...)
_ints2range_front(::Val{0}, ind, inds...) = ()
_ints2range_front(::Val{0}) = ()
# get output shape with given indices
# (scalar indices contribute no dimension; each range contributes its length)
_output_shape(::CanonicalInt, inds...) = _output_shape(inds...)
_output_shape(ind::AbstractRange, inds...) = (Base.length(ind), _output_shape(inds...)...)
_output_shape(::CanonicalInt) = ()
_output_shape(x::AbstractRange) = (Base.length(x),)
# `CartesianIndices`: when possible, return a reshaped `CartesianIndices` built from the
# index ranges instead of materializing the selected elements.
@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N}
    # Defer to Base for a single index into a multidimensional `CartesianIndices`, or when
    # `stride_preserving_index` reports `False()` for the index types.
    defer_to_base = (Base.length(inds) === 1 && N > 1) ||
        stride_preserving_index(typeof(inds)) === False()
    defer_to_base && return Base._getindex(IndexStyle(A), A, inds...)
    ranges = _ints2range_front(Val(N), inds...)
    return reshape(CartesianIndices(ranges), _output_shape(inds...))
end
# `true` only when the first element of `ind` is statically known to equal one.
function _known_first_isone(ind)
    first_known = known_first(ind)
    return first_known !== nothing && isone(first_known)
end
# `LinearIndices`: when possible, return a reshaped `LinearIndices` built from the index
# ranges instead of materializing the selected elements.
@inline function unsafe_get_collection(A::LinearIndices{N}, inds) where {N}
    if Base.length(inds) === 1 && ndims_index(typeof(first(inds))) === 1
        # one single-dimension index: select straight from the linear indices
        return @inbounds(eachindex(A)[first(inds)])
    end
    # A `LinearIndices` is only built when every index is statically known to start at one.
    all_start_at_one = reduce_tup(&, map(_known_first_isone, inds))
    if stride_preserving_index(typeof(inds)) === True() && all_start_at_one
        ranges = _ints2range_front(Val(N), inds...)
        return reshape(LinearIndices(ranges), _output_shape(inds...))
    end
    return Base._getindex(IndexStyle(A), A, inds...)
end
"""
    ArrayInterface.setindex!(A, val, args...)
    ArrayInterface.setindex!(A, val; kwargs...)

Store `val` at the given key or index within the collection `A`.

An error is thrown when `can_setindex(A)` is false. Otherwise the arguments are converted
with `ArrayInterface.to_indices`, bounds-checked (the check can be elided by `@inbounds` at
the call site), and passed to `unsafe_setindex!`. The keyword form selects dimensions by
name using `dimnames(A)`.
"""
@propagate_inbounds function setindex!(A, val, args...)
    if !can_setindex(A)
        error("Instance of type $(typeof(A)) are not mutable and cannot change elements after construction.")
    end
    inds = to_indices(A, args)
    @boundscheck checkbounds(A, inds...)
    return unsafe_setindex!(A, val, inds...)
end
@propagate_inbounds function setindex!(A, val; kwargs...)
    if !can_setindex(A)
        error("Instance of type $(typeof(A)) are not mutable and cannot change elements after construction.")
    end
    names = static(keys(kwargs))
    inds = to_indices(A, find_all_dimnames(dimnames(A), names, Tuple(values(kwargs)), :))
    @boundscheck checkbounds(A, inds...)
    return unsafe_setindex!(A, val, inds...)
end
## unsafe_setindex! ##
# Generic scalar assignment. Types that are not handled by a more specific method must be
# forwarding wrappers (`is_forwarding_wrapper`), whose parent then receives the call;
# otherwise a `MethodError` reporting the actual call arguments is thrown.

# no indices: delegate to the parent
function unsafe_setindex!(a::A, v) where {A}
    is_forwarding_wrapper(A) || throw(MethodError(unsafe_setindex!, (a, v)))
    return unsafe_setindex!(parent(a), v)
end
# TODO Need to manage index transformations between nested layers of arrays
# single integer: `IndexLinear` types forward it to the parent; other types convert it to
# Cartesian indices with `_to_cartesian`
function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A}
    if IndexStyle(A) === IndexLinear()
        is_forwarding_wrapper(A) || throw(MethodError(unsafe_setindex!, (a, v, i)))
        return unsafe_setindex!(parent(a), v, i)
    else
        return unsafe_setindex!(a, v, _to_cartesian(a, i)...)
    end
end
# several integers: `IndexLinear` types convert them to one linear index with `_to_linear`;
# other types forward them to the parent
function unsafe_setindex!(a::A, v, i::CanonicalInt, ii::Vararg{CanonicalInt}) where {A}
    if IndexStyle(A) === IndexLinear()
        return unsafe_setindex!(a, v, _to_linear(a, (i, ii...)))
    else
        is_forwarding_wrapper(A) || throw(MethodError(unsafe_setindex!, (a, v, i, ii...)))
        return unsafe_setindex!(parent(a), v, i, ii...)
    end
end
# `Array`: convert `v` to the element type `T` and store it without bounds checks
# (`Base.arrayset(false, ...)`).
unsafe_setindex!(A::Array{T}, v) where {T} = Base.arrayset(false, A, convert(T, v)::T, 1)
unsafe_setindex!(A::Array{T}, v, i::CanonicalInt) where {T} =
    Base.arrayset(false, A, convert(T, v)::T, Int(i))
# Fallback for indices that are not all `CanonicalInt`: assign into a collection.
function unsafe_setindex!(a, v, i::Vararg{Any})
    return unsafe_set_collection!(a, v, i)
end
# This is based on Base._unsafe_setindex!.
#=
    unsafe_set_collection!(A, val, inds)

Assign `val` to the elements of `A` selected by the tuple `inds`. `inds` is assumed to have
been bounds-checked already.
=#
function unsafe_set_collection!(A, v, i)
    style = IndexStyle(A)
    return Base._unsafe_setindex!(style, A, v, i...)
end