-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathaxes.jl
More file actions
359 lines (331 loc) · 14.6 KB
/
axes.jl
File metadata and controls
359 lines (331 loc) · 14.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
"""
    axes_types(::Type{T}) -> Type{Tuple{Vararg{AbstractUnitRange{Int}}}}
    axes_types(::Type{T}, dim) -> Type{AbstractUnitRange{Int}}

Return the type of each axis of `T`, or the type of the axis along dimension `dim`.
Dimensions greater than `ndims(x)` map to `SOneTo{1}`.
"""
@inline axes_types(x, dim) = axes_types(x, to_dims(x, dim))
@inline function axes_types(x, dim::StaticInt{D}) where {D}
    # Static dimension: index the axis-type tuple with the static-aware `field_type`.
    if D > ndims(x)
        return SOneTo{1}
    else
        return field_type(axes_types(x), dim)
    end
end
@inline function axes_types(x, dim::Int)
    # Dynamic dimension: plain `fieldtype` on the axis-type tuple.
    if dim > ndims(x)
        return SOneTo{1}
    else
        return fieldtype(axes_types(x), dim)
    end
end
# Instances defer to the method for their type.
axes_types(x) = axes_types(typeof(x))
# Plain `Array`s always have dynamically sized `OneTo` axes.
function axes_types(::Type{T}) where {T<:Array}
    return NTuple{ndims(T),OneTo{Int}}
end
@inline function axes_types(::Type{T}) where {T}
    # Forwarding wrappers report the axis types of the type they wrap.
    is_forwarding_wrapper(T) && return axes_types(parent_type(T))
    # Fallback: axes starting at a static one with a dynamic last index.
    return NTuple{ndims(T),OptionallyStaticUnitRange{One,Int}}
end
# Linear and Cartesian index collections store their axis-type tuple as parameter `R`.
function axes_types(::Type{<:LinearIndices{N,R}}) where {N,R}
    return R
end
function axes_types(::Type{<:CartesianIndices{N,R}}) where {N,R}
    return R
end
# Wrapped vector: a static length-one leading axis followed by the parent's first axis.
function axes_types(@nospecialize T::Type{<:VecAdjTrans})
    parent_axis = fieldtype(axes_types(parent_type(T)), 1)
    return Tuple{SOneTo{1},parent_axis}
end
# Wrapped matrix: the parent's two axis types in swapped order.
function axes_types(@nospecialize T::Type{<:MatAdjTrans})
    parent_axes = axes_types(parent_type(T))
    first_axis = fieldtype(parent_axes, 2)
    second_axis = fieldtype(parent_axes, 1)
    return Tuple{first_axis,second_axis}
end
# Select the parent's axis types in the order given by the parent-dimension mapping.
function axes_types(::Type{T}) where {T<:PermutedDimsArray}
    perm = to_parent_dims(T)
    parent_axes = axes_types(parent_type(T))
    return eachop_tuple(field_type, perm, parent_axes)
end
# An identity range is its own axis.
function axes_types(T::Type{<:Base.IdentityUnitRange})
    return Tuple{T}
end
# A slice's axis is an identity range wrapping the slice's index range...
function axes_types(::Type{<:Base.Slice{I}}) where {I}
    return Tuple{Base.IdentityUnitRange{I}}
end
# ...unless that index range already is an identity range, which is reused directly.
function axes_types(::Type{<:Base.Slice{I}}) where {I<:Base.IdentityUnitRange}
    return Tuple{I}
end
function axes_types(::Type{T}) where {T<:AbstractRange}
    len = known_length(T)
    # A statically known length gives a static axis; otherwise use a dynamic `OneTo`.
    return len === nothing ? Tuple{OneTo{Int}} : Tuple{SOneTo{len}}
end
# Reshaped arrays report dynamically sized `OneTo` axes for every dimension.
function axes_types(::Type{T}) where {T<:ReshapedArray}
    return NTuple{ndims(T),OneTo{Int}}
end
# Axis type produced by the index at position `dim` of a view's index tuple `I`, given the
# parent's axis-type tuple `PA`.
function _sub_axis_type(::Type{PA}, ::Type{I}, dim::StaticInt{D}) where {I<:Tuple,PA,D}
    index_type = field_type(I, dim)
    # A `Slice{OneTo{Int}}` uses the parent's axis type for that dimension instead; this
    # works around slices over statically sized dimensions.
    source = index_type <: Base.Slice{Base.OneTo{Int}} ? field_type(PA, dim) : index_type
    return axes_types(source, static(1))
end
@inline function axes_types(@nospecialize T::Type{<:SubArray})
    parent_axes = axes_types(parent_type(T))
    index_types = fieldtype(T, :indices)
    return eachop_tuple(_sub_axis_type, to_parent_dims(T), parent_axes, index_types)
end
function axes_types(::Type{T}) where {T<:ReinterpretArray}
    dims = ntuple(static, StaticInt(ndims(T)))
    return eachop_tuple(_non_reshaped_axis_type, dims, T)
end
# Axis type of dimension `d` of a non-reshaped reinterpretation `A`.
function _non_reshaped_axis_type(::Type{A}, d::StaticInt{D}) where {A,D}
    paxis = axes_types(parent_type(A), d)
    # Only the first axis is rescaled, and only when its length is statically known.
    D === 1 || return paxis
    plen = known_length(paxis)
    plen === nothing && return paxis
    return SOneTo{div(plen * sizeof(eltype(parent_type(A))), sizeof(eltype(A)))}
end
function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N,S}}
    P = parent_type(A)
    if sizeof(S) == sizeof(T)
        # Same element size: the axis types are the parent's.
        return axes_types(P)
    elseif sizeof(S) > sizeof(T)
        # Each parent element splits into several: prepend a static leading axis.
        return merge_tuple_type(Tuple{SOneTo{div(sizeof(S), sizeof(T))}}, axes_types(P))
    else
        # Parent elements merge into one: drop the parent's leading axis type.
        trailing = tail(ntuple(static, StaticInt(ndims(P))))
        return eachop_tuple(field_type, trailing, axes_types(P))
    end
end
# FUTURE NOTE: we avoid `SOneTo(1)` when calling `axes(A, dim::Int)`. This is intended to
# reduce breaking changes by limiting this method's adoption to situations that clearly
# benefit from the propagation of static axes. This creates the somewhat awkward situation
# of conditionally typed (but inferrable) axes. It also means we can't depend on constant
# propagation to preserve statically sized axes. This should probably be addressed before
# merging into Base Julia.
"""
    axes(A) -> Tuple{Vararg{AbstractUnitRange{Int}}}
    axes(A, dim) -> AbstractUnitRange{Int}

Return the axis associated with each dimension of `A`, or with dimension `dim`.

`ArrayInterface.axes(::AbstractArray)` behaves nearly identically to `Base.axes`, except
that a handful of types replace `Base.OneTo{Int}` with `ArrayInterface.SOneTo`. For
example, the axis along the first dimension of `Transpose{T,<:AbstractVector{T}}` and
`Adjoint{T,<:AbstractVector{T}}` can be represented by `SOneTo(1)`. Similarly, the first
axis of a `Base.ReinterpretArray` may be statically sized.
"""
@inline axes(A) = Base.axes(A)
# Reshaped arrays use Base's axes unchanged.
function axes(A::ReshapedArray)
    return Base.axes(A)
end
# Wrapped matrices and permuted arrays return the parent's axes, reordered by the
# parent-dimension mapping.
@inline function axes(x::Union{MatAdjTrans,PermutedDimsArray})
    parent_axes = axes(parent(x))
    return map(GetIndex{false}(parent_axes), to_parent_dims(x))
end
# Wrapped vectors gain a static length-one leading axis.
function axes(A::VecAdjTrans)
    only_axis = getfield(axes(parent(A)), 1)
    return (SOneTo{1}(), only_axis)
end
@inline function axes(x::SubArray)
    # Compute the axes contributed by each entry of the view's axis map, then flatten.
    per_entry = map(Base.Fix1(_sub_axes, x), sub_axes_map(typeof(x)))
    return flatten_tuples(per_entry)
end
# Static axes pass through unchanged.
@inline _sub_axes(x::SubArray, axis::SOneTo) = axis
# Otherwise take the axes of the `index`-th index stored in the view.
function _sub_axes(x::SubArray, ::StaticInt{index}) where {index}
    return axes(getfield(x.indices, index))
end
@inline axes(A, dim) = _axes(A, to_dims(A, dim))
# Dimensions past `ndims(A)` are length-one: a dynamic `OneTo(1)` for an `Int` dimension
# and a static `SOneTo{1}()` for a static one.
@inline function _axes(A, dim::Int)
    dim > ndims(A) && return OneTo(1)
    return getfield(axes(A), dim)
end
@inline function _axes(A, ::StaticInt{dim}) where {dim}
    if dim > ndims(A)
        return SOneTo{1}()
    end
    return getfield(axes(A), dim)
end
@inline function axes(A::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
    paxes = axes(parent(A))
    sizeof(S) == sizeof(T) && return paxes
    if sizeof(S) > sizeof(T)
        # Prepend a static axis counting how many `T`s make up one `S`.
        return (SOneTo(div(sizeof(S), sizeof(T))), paxes...)
    end
    # Several parent elements combine into one: drop the parent's leading axis.
    return tail(paxes)
end
# Axis of a reshaped reinterpretation along `dim`, mapped onto the parent's dimensions.
@inline function axes(A::Base.ReshapedReinterpretArray{T,N,S}, dim) where {T,N,S}
    d = to_dims(A, dim)
    if sizeof(S) > sizeof(T)
        # A static leading axis is prepended, so child dimension `d > 1` corresponds to
        # parent dimension `d - 1`.
        if d == 1
            return SOneTo(div(sizeof(S), sizeof(T)))
        else
            return axes(parent(A), d - static(1))
        end
    elseif sizeof(S) < sizeof(T)
        # The parent's leading axis is dropped (several parent elements form one `T`), so
        # child dimension `d` corresponds to parent dimension `d + 1`.
        return axes(parent(A), d + static(1))
    else
        return axes(parent(A), d)
    end
end
"""
    LazyAxis{N}(parent::AbstractArray)

A lazy representation of `axes(parent, N)`. `LazyAxis{:}(parent)` instead stands for the
linear index range of `parent`, and collapses to `LazyAxis{1}` for one-dimensional parents.
"""
struct LazyAxis{N,P} <: AbstractUnitRange{Int}
    parent::P

    function LazyAxis{N}(parent::P) where {N,P}
        # Only positive dimensions are accepted; anything else goes to `throw_dim_error`.
        if N > 0
            return new{N::Int,P}(parent)
        end
        throw_dim_error(parent, N)
    end
    @inline LazyAxis{:}(parent::P) where {P} = new{ifelse(ndims(P) === 1, 1, :),P}(parent)
end
# A dimensional lazy axis materializes as the parent array's axis along `N`.
@inline function Base.parent(x::LazyAxis{N,P}) where {N,P}
    return axes(getfield(x, :parent), static(N))
end
# A linear lazy axis materializes as the parent array's linear index range.
@inline function Base.parent(x::LazyAxis{:,P}) where {P}
    return eachindex(IndexLinear(), getfield(x, :parent))
end
@inline function parent_type(::Type{LazyAxis{N,P}}) where {N,P}
    return axes_types(P, static(N))
end
# TODO: this approach to parent_type(::Type{LazyAxis{:}}) is a bit hacky. Something like
# LabelledArrays has a linear set of symbolic keys, which could be propagated through
# `to_indices` for key-based indexing. However, there currently isn't a good way of handling
# that when the linear indices aren't linearly accessible through a child array (e.g., adjoint).
# For now we just make sure the linear elements are accurate.
# Arrays index linearly with a dynamic `OneTo`.
function parent_type(::Type{LazyAxis{:,P}}) where {P<:Array}
    return OneTo{Int}
end
@inline function parent_type(::Type{LazyAxis{:,P}}) where {P}
    len = known_length(P)
    # Unknown length: a range with a static first element and dynamic end.
    len === nothing && return OptionallyStaticUnitRange{StaticInt{1},Int}
    return SOneTo{len}
end
# Keys, index style and resizability all come from the materialized parent axis.
function Base.keys(x::LazyAxis)
    return keys(parent(x))
end
function Base.IndexStyle(T::Type{<:LazyAxis})
    return IndexStyle(parent_type(T))
end
function ArrayInterfaceCore.can_change_size(@nospecialize T::Type{<:LazyAxis})
    return can_change_size(fieldtype(T, :parent))
end
# A dimensional axis starts at that dimension's known offset (or `nothing`).
ArrayInterfaceCore.known_first(::Type{<:LazyAxis{N,P}}) where {N,P} = known_offsets(P, static(N))
# A linear axis always starts at 1.
ArrayInterfaceCore.known_first(::Type{<:LazyAxis{:,P}}) where {P} = 1
@inline function Base.first(x::LazyAxis{N})::Int where {N}
    static_first = ArrayInterfaceCore.known_first(x)
    static_first === nothing || return Int(known_first(x))
    # Not known from the type: query the parent's offset along dimension `N`.
    return Int(offsets(getfield(x, :parent), StaticInt(N)))
end
@inline function Base.first(x::LazyAxis{:})::Int
    return Int(offset1(getfield(x, :parent)))
end
function ArrayInterfaceCore.known_last(::Type{LazyAxis{N,P}}) where {N,P}
    return known_last(axes_types(P, static(N)))
end
function ArrayInterfaceCore.known_last(::Type{LazyAxis{:,P}}) where {P}
    return known_length(P)
end
# Prefer the statically known last index; otherwise ask the parent array.
Base.last(x::LazyAxis) = _last(known_last(x), x)
_last(::Nothing, axis::LazyAxis{:}) = lastindex(getfield(axis, :parent))
_last(::Nothing, axis::LazyAxis{N}) where {N} = lastindex(getfield(axis, :parent), N)
_last(value::Int, axis) = value
# A linear axis spans all elements of the parent; a dimensional one spans dimension `N`.
function known_length(::Type{<:LazyAxis{:,P}}) where {P}
    return known_length(P)
end
function known_length(::Type{<:LazyAxis{N,P}}) where {N,P}
    return known_size(P, static(N))
end
@inline Base.length(axis::LazyAxis{:}) = Base.length(getfield(axis, :parent))
@inline Base.length(axis::LazyAxis{N}) where {N} = Base.size(getfield(axis, :parent), N)
# A lazy axis is its own (only) axis.
Base.axes(x::LazyAxis) = (Base.axes1(x),)
Base.axes1(x::LazyAxis) = x
Base.axes(x::Slice{<:LazyAxis}) = (Base.axes1(x),)
# Assuming that lazily loaded parameters (like a dynamic length from `size(::Array, dim)`)
# will be used again later with `Slice{LazyAxis}`, load the indices eagerly here.
function Base.axes1(x::Slice{LazyAxis{N,A}}) where {N,A}
    arr = getfield(x.indices, :parent)
    return indices(arr, StaticInt(N))
end
function Base.axes1(x::Slice{LazyAxis{:,A}}) where {A}
    arr = getfield(x.indices, :parent)
    return indices(arr)
end
Base.to_shape(x::LazyAxis) = Base.length(x)
@propagate_inbounds function Base.getindex(x::LazyAxis, i::CanonicalInt)
    # Every in-bounds index of a lazy axis maps to itself.
    @boundscheck if !checkindex(Bool, x, i)
        throw(BoundsError(x, i))
    end
    return Int(i)
end
# Indexing a lazy axis with a step range. A lazy axis is self-indexed (`axes1(x) === x`
# and `x[i] == i`), so the result has exactly the values of `s`; an offset of
# `first(x) - 1` must not be added, which would only agree when `first(x) == 1`.
@propagate_inbounds function Base.getindex(x::LazyAxis, s::StepRange{<:Integer})
    @boundscheck checkbounds(x, s)
    return range(Int(first(s)); step=Int(step(s)), length=Int(length(s)))
end
# Unit ranges index the materialized parent axis.
@propagate_inbounds Base.getindex(x::LazyAxis, i::AbstractUnitRange{<:Integer}) = parent(x)[i]
# Print as `LazyAxis{N}(<materialized parent axis>)` with balanced parentheses.
Base.show(io::IO, x::LazyAxis{N}) where {N} = print(io, "LazyAxis{$N}($(parent(x)))")
"""
    lazy_axes(x)

Produce a tuple of axes where each axis is constructed lazily. If an axis of `x` is already
constructed, it is simply retrieved.
"""
@inline lazy_axes(x) = lazy_axes(x, ntuple(static, StaticInt(ndims(x))))
# Index collections and ranges already hold their axes.
lazy_axes(x::Union{LinearIndices,CartesianIndices,AbstractRange}) = axes(x)
@inline function lazy_axes(x::PermutedDimsArray, ::StaticInt{N}) where {N}
    # Dimensions past `ndims(x)` are static length-one axes; otherwise forward to the
    # corresponding parent dimension.
    N > ndims(x) && return SOneTo{1}()
    return lazy_axes(parent(x), getfield(to_parent_dims(x), N))
end
# Adjoint/transpose swap the parent's first two dimensions.
lazy_axes(x::Union{Adjoint,Transpose}, ::StaticInt{1}) = lazy_axes(parent(x), StaticInt(2))
lazy_axes(x::Union{Adjoint,Transpose}, ::StaticInt{2}) = lazy_axes(parent(x), StaticInt(1))
lazy_axes(x::AbstractRange, ::StaticInt{1}) = Base.axes1(x)
lazy_axes(x, ::Colon) = LazyAxis{:}(x)
function lazy_axes(x, ::StaticInt{dim}) where {dim}
    return dim > ndims(x) ? SOneTo{1}() : LazyAxis{dim}(x)
end
@inline lazy_axes(x, dims::Tuple) = map(Base.Fix1(lazy_axes, x), dims)
"""
    has_index_labels(x) -> Bool

Return `true` if `x` has any index labels. If [`index_labels`](@ref) returns a tuple whose
entries are all `nothing` (or an empty tuple), this is `false`.

See also: [`index_labels`](@ref)
"""
has_index_labels(x) = _any_labels(index_labels(x))
# These wrapper types defer to their parent array.
function has_index_labels(x::Union{Base.NonReshapedReinterpretArray,Transpose,Adjoint,PermutedDimsArray,Symmetric,Hermitian})
    return has_index_labels(parent(x))
end
function has_index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
    has_index_labels(parent(x)) && return true
    # The new leading dimension counts as labelled only when one `S` splits into more than
    # one `T`, with exactly one `T` per field of `S`.
    nsplit = div(sizeof(S), sizeof(T))
    return nsplit > 1 && nsplit === fieldcount(S)
end
# A view is labelled if its parent is, or if any of its indices carries labels.
function has_index_labels(x::SubArray)
    return has_index_labels(parent(x)) || any(has_index_labels, x.indices)
end
# Any label tuple counts as labelled, except one made only of `nothing` (including `()`),
# which dispatches to the more specific method below.
_any_labels(@nospecialize labels::Tuple{Vararg{Any}}) = true
_any_labels(@nospecialize labels::Tuple{Vararg{Nothing}}) = false
"""
    index_labels(x)
    index_labels(x, dim)

Return a tuple of the labels assigned to each axis of `x`, or the collection of labels for
each index along dimension `dim`. Axes without labels are represented by `nothing`, which
is the default.

See also: [`has_index_labels`](@ref)
"""
index_labels(x, dim) = index_labels(x, to_dims(x, dim))
# Numbers have no axes, hence an empty label tuple.
index_labels(@nospecialize x::Number) = ()
@inline function index_labels(x, dim::CanonicalInt)
    # Dimensions past `ndims(x)` carry no labels.
    dim > ndims(x) && return nothing
    return getfield(index_labels(x), Int(dim))
end
@inline function index_labels(x)
    # Forwarding wrappers report the labels of their buffer.
    is_forwarding_wrapper(x) && return index_labels(buffer(x))
    # Default: `nothing` for every dimension.
    return ntuple(Returns(nothing), Val{ndims(x)}())
end
# Wrapped matrices and permuted arrays reorder the parent's labels.
function index_labels(x::Union{MatAdjTrans,PermutedDimsArray})
    parent_labels = index_labels(parent(x))
    return map(GetIndex{false}(parent_labels), to_parent_dims(x))
end
# Wrapped vectors: the new leading dimension has no labels.
function index_labels(x::VecAdjTrans)
    vector_labels = getfield(index_labels(parent(x)), 1)
    return (nothing, vector_labels)
end
# Labels of a `SubArray`: for each index stored in the view, derive the labels of the
# dimension(s) it contributes to the view, then splice the results into one tuple.
function index_labels(x::SubArray)
    labels = index_labels(parent(x))
    inds = x.indices
    # NOTE(review): `IndicesInfo` presumably records, per index, the parent dimension(s)
    # it consumes (`pdims`) and the child dimension(s) it produces (`cdims`); tuple and
    # `0` entries are special-cased below — confirm against `IndicesInfo`.
    info = IndicesInfo(x)
    pdims = parentdims(info)
    cdims = childdims(info)
    # One entry per index. `flatten_tuples` is assumed to splice the per-index results
    # (including the bare `nothing` from the trailing-dimension branch) into a flat tuple.
    flatten_tuples(ntuple(Val{nfields(pdims)}()) do i
        pdim_i = getfield(pdims, i)
        cdim_i = getfield(cdims, i)
        index = getfield(inds, i)
        if pdim_i isa Tuple || cdim_i isa Tuple # no direct mapping to parent axes
            # Fall back to the labels carried by the index itself.
            index_labels(index)
        elseif cdim_i === 0 # integer indexing drops axes
            ()
        elseif pdim_i === 0 # trailing dimension
            nothing
        elseif index isa Base.Slice # index into labels where there is direct mapping to parent axis
            # A slice keeps the parent's labels for that dimension as-is.
            (getfield(labels, pdim_i),)
        else
            labels_i = getfield(labels, pdim_i)
            # Unlabelled parent dimension: use the index's own labels; otherwise select the
            # parent labels at the indexed positions. NOTE(review): `@inbounds` presumably
            # relies on the view's indices having been bounds-checked already — confirm.
            labels_i === nothing ? index_labels(index) : (@inbounds(labels_i[index]),)
        end
    end)
end
# Each range stored in a Linear/Cartesian index collection contributes the first entry
# of its own label tuple.
function index_labels(x::Union{LinearIndices,CartesianIndices})
    return map(r -> first(index_labels(r)), x.indices)
end
index_labels(x::Union{Symmetric,Hermitian}) = index_labels(parent(x))
# A linear lazy axis has no labels.
index_labels(@nospecialize(x::LazyAxis{:})) = (nothing,)
# A dimensional lazy axis carries the parent array's labels for dimension `N`.
function index_labels(x::LazyAxis{N}) where {N}
    parent_labels = index_labels(getfield(x, :parent))
    return (getfield(parent_labels, N),)
end
# Labels of a non-reshaped reinterpretation (the duplicated `@inline` has been removed).
@inline function index_labels(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S}
    if sizeof(T) === sizeof(S)
        # Same element size: dimensions line up one-to-one with the parent.
        return index_labels(parent(x))
    else
        # The first dimension is rescaled, so the parent's labels for it are dropped.
        return (nothing, Base.tail(index_labels(parent(x)))...)
    end
end
function index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
    ratio = div(StaticInt(sizeof(S)), StaticInt(sizeof(T)))
    return _reinterpret_index_labels(ratio, x)
end
# Field names of the parent's element type, or `()` when it is not a struct type.
@inline function _reinterpreted_fieldnames(@nospecialize T::Type{<:Base.ReshapedReinterpretArray})
    S = eltype(parent_type(T))
    return isstructtype(S) ? fieldnames(S) : ()
end
# Ratio > 1: a new leading dimension is prepended to the parent's labels.
function _reinterpret_index_labels(s::StaticInt{N}, x::Base.ReshapedReinterpretArray) where {N}
    names = _reinterpreted_fieldnames(typeof(x))
    return __reinterpret_index_labels(s, names, index_labels(parent(x)))
end
# The new leading dimension is labelled by the field names only when there is exactly one
# field per split element; otherwise it is unlabelled.
@inline function __reinterpret_index_labels(::StaticInt{N}, fields::NTuple{M,Symbol}, ks::Tuple) where {N,M}
    if N === M
        return (fields, ks...)
    end
    return (nothing, ks...)
end
# Ratio of 1: labels are unchanged.
_reinterpret_index_labels(::StaticInt{1}, x::Base.ReshapedReinterpretArray) = index_labels(parent(x))
# Ratio of 0 (larger target element): the parent's leading dimension and its labels go away.
_reinterpret_index_labels(::StaticInt{0}, x::Base.ReshapedReinterpretArray) = Base.tail(index_labels(parent(x)))