-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathdimensions.jl
More file actions
128 lines (116 loc) · 7.01 KB
/
dimensions.jl
File metadata and controls
128 lines (116 loc) · 7.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
###
### named-dimension wrappers implementing ArrayInterface.dimnames
###
@testset "order_named_inds" begin
    # `find_all_dimnames(all_names, queried_names, values, default)`:
    # - an empty query yields `()`;
    # - otherwise each value is placed at the position of its queried name within
    #   `all_names`, and positions with no queried name receive `default` (`:` here);
    # - querying a name absent from `all_names` throws an `ErrorException`.
    sx, sy, sz = static(:x), static(:y), static(:z)
    names_x = (sx,)
    names_xy = (sx, sy)
    names_xyz = (sx, sy, sz)
    find_names = ArrayInterface.find_all_dimnames

    @test @inferred(find_names(names_x, (), (), :)) == ()
    @test @inferred(find_names(names_x, (sx,), (2,), :)) == (2,)
    @test @inferred(find_names(names_xy, (sx,), (2,), :)) == (2, :)
    @test @inferred(find_names(names_xy, (sy,), (2,), :)) == (:, 2)
    # query order does not matter: values follow the order of `all_names`
    @test @inferred(find_names(names_xy, (sy, sx), (20, 30), :)) == (30, 20)
    @test @inferred(find_names(names_xy, (sx, sy), (30, 20), :)) == (30, 20)
    @test @inferred(find_names(names_xyz, (sx, sy), (30, 20), :)) == (30, 20, :)
    @test_throws ErrorException find_names(names_xy, (sx, sy, sz), (30, 20, 40), :)
end
@testset "ArrayInterface.dimnames" begin
    d = (static(:x), static(:y))
    # `x`: matrix with two static names; `y`: vector with one static name
    x = NamedDimsWrapper(d, ones(Int64, 2, 2))
    y = NamedDimsWrapper((static(:x),), ones(2))
    # `z`: one dynamic (`:x`) and one static (`static(:y)`) name, so its parent
    # needs two dimensions (one name per dimension)
    z = NamedDimsWrapper((:x, static(:y)), ones(2, 2))
    # reinterpretations of the Int64-valued `x`
    r1 = reinterpret(Int8, x)
    r2 = reinterpret(reshape, Int8, x)           # gains a leading unnamed dimension
    r3 = reinterpret(reshape, Complex{Int}, x)   # first dimension is consumed
    r4 = reinterpret(reshape, Float64, x)        # same element size: dims unchanged
    w = Wrapper(x)
    lz2 = ArrayInterface.lazy_axes(x)[2]
    lzslice = ArrayInterface.LazyAxis{:}(x)

    # has_dimnames: works on both instances and types
    @test @inferred(ArrayInterface.has_dimnames(x)) == true
    @test @inferred(ArrayInterface.has_dimnames(z)) == true
    @test @inferred(ArrayInterface.has_dimnames(ones(2, 2))) == false
    @test @inferred(ArrayInterface.has_dimnames(Array{Int,2})) == false
    @test @inferred(ArrayInterface.has_dimnames(typeof(x))) == true
    @test @inferred(ArrayInterface.has_dimnames(typeof(view(x, :, 1, :)))) == true

    # dimnames: propagated through wrappers, views, reshapes and permutations;
    # unnamed dimensions report `static(:_)`
    @test @inferred(ArrayInterface.dimnames(x)) === d
    @test @inferred(ArrayInterface.dimnames(lz2)) === (static(:y),)
    @test @inferred(ArrayInterface.dimnames(lzslice)) === (static(:x),)
    @test @inferred(ArrayInterface.dimnames(w)) === d
    @test @inferred(ArrayInterface.dimnames(r1)) === d
    @test @inferred(ArrayInterface.dimnames(r2)) === (static(:_), d...)
    @test @inferred(ArrayInterface.dimnames(r3)) === Base.tail(d)
    @test @inferred(ArrayInterface.dimnames(r4)) === d
    @test @inferred(ArrayInterface.dimnames(z)) === (:x, static(:y))
    @test @inferred(ArrayInterface.dimnames(parent(x))) === (static(:_), static(:_))
    @test @inferred(ArrayInterface.dimnames(reshape(x, (1, 4)))) === d
    @test @inferred(ArrayInterface.dimnames(reshape(x, :))) === (static(:_),)
    @test @inferred(ArrayInterface.dimnames(x')) === reverse(d)
    @test @inferred(ArrayInterface.dimnames(y')) === (static(:_), static(:x))
    @test @inferred(ArrayInterface.dimnames(PermutedDimsArray(x, (2, 1)))) === reverse(d)
    @test @inferred(ArrayInterface.dimnames(PermutedDimsArray(x', (2, 1)))) === d
    @test @inferred(ArrayInterface.dimnames(view(x, :, 1))) === (static(:x),)
    @test @inferred(ArrayInterface.dimnames(view(x, :, 1)')) === (static(:_), static(:x))
    @test @inferred(ArrayInterface.dimnames(view(x, :, :, :))) === (static(:x), static(:y), static(:_))
    @test @inferred(ArrayInterface.dimnames(view(x, :, 1, :))) === (static(:x), static(:_))
    # multidimensional indices
    @test @inferred(ArrayInterface.dimnames(view(x, ones(Int, 2, 2), 1))) === (static(:_), static(:_))
    @test @inferred(ArrayInterface.dimnames(view(x, [CartesianIndex(1,1), CartesianIndex(1,1)]))) === (static(:_),)
    # single-dimension form
    @test @inferred(ArrayInterface.dimnames(x, ArrayInterface.One())) === static(:x)
    @test @inferred(ArrayInterface.dimnames(parent(x), ArrayInterface.One())) === static(:_)

    # known_dimnames: plain Symbols known at compile time; dynamic names
    # (such as `z`'s `:x`) report `nothing`
    @test @inferred(ArrayInterface.known_dimnames(Iterators.flatten(1:10))) === (:_,)
    @test @inferred(ArrayInterface.known_dimnames(Iterators.flatten(1:10), static(1))) === :_
    # multidimensional indices
    @test @inferred(ArrayInterface.known_dimnames(view(x, ones(Int, 2, 2), 1))) === (:_, :_)
    @test @inferred(ArrayInterface.known_dimnames(view(x, [CartesianIndex(1,1), CartesianIndex(1,1)]))) === (:_,)
    @test @inferred(ArrayInterface.known_dimnames(lz2)) === (:y,)
    @test @inferred(ArrayInterface.known_dimnames(lzslice)) === (:x,)
    @test @inferred(ArrayInterface.known_dimnames(z)) === (nothing, :y)
    @test @inferred(ArrayInterface.known_dimnames(reshape(x, (1, 4)))) === (:x, :y)
    @test @inferred(ArrayInterface.known_dimnames(r1)) === (:x, :y)
    @test @inferred(ArrayInterface.known_dimnames(r2)) === (:_, :x, :y)
    @test @inferred(ArrayInterface.known_dimnames(r3)) === (:y,)
    @test @inferred(ArrayInterface.known_dimnames(r4)) === (:x, :y)
    @test @inferred(ArrayInterface.known_dimnames(w)) === (:x, :y)
    @test @inferred(ArrayInterface.known_dimnames(reshape(x, :))) === (:_,)
    @test @inferred(ArrayInterface.known_dimnames(view(x, :, 1)')) === (:_, :x)
end
@testset "to_dims" begin
    # Each wrapper carries exactly one name per dimension of its parent.
    # `x`: two static names (small case); `y`: six static names (large case),
    # backed by a single-element 6-dimensional array to keep the fixture tiny.
    x = NamedDimsWrapper(static((:x, :y)), ones(2, 2))
    y = NamedDimsWrapper(static((:x, :y, :a, :b, :c, :d)), ones(1, 1, 1, 1, 1, 1))

    # non-name dimension arguments pass through unchanged
    @test @inferred(ArrayInterface.to_dims(x, :)) == Colon()
    @test @inferred(ArrayInterface.to_dims(x, 1)) == 1

    @testset "small case" begin
        @test @inferred(ArrayInterface.to_dims(x, (:x, :y))) == (1, 2)
        @test @inferred(ArrayInterface.to_dims(x, (:y, :x))) == (2, 1)
        @test @inferred(ArrayInterface.to_dims(x, :x)) == 1
        @test @inferred(ArrayInterface.to_dims(x, :y)) == 2
        # unknown names, static or dynamic, are rejected
        @test_throws DimensionMismatch ArrayInterface.to_dims(x, static(:z)) # not found
        @test_throws DimensionMismatch ArrayInterface.to_dims(x, :z) # not found
    end

    @testset "large case" begin
        @test @inferred(ArrayInterface.to_dims(y, :x)) == 1
        @test @inferred(ArrayInterface.to_dims(y, :a)) == 3
        @test @inferred(ArrayInterface.to_dims(y, :d)) == 6
        @test_throws DimensionMismatch ArrayInterface.to_dims(y, :z) # not found
    end
end
@testset "methods accepting ArrayInterface.dimnames" begin
    d = (static(:x), static(:y))
    x = NamedDimsWrapper(d, ones(2, 2))
    y = NamedDimsWrapper((static(:x),), ones(2))

    # a dimension name selects the same dimension as its position in the parent
    @test @inferred(size(x, first(d))) == size(parent(x), 1)
    # the adjoint of the named vector `y` is a 1×n row of `y`'s length
    @test @inferred(ArrayInterface.size(y')) == (1, size(parent(y), 1))
    @test @inferred(axes(x, first(d))) == axes(parent(x), 1)
    @test strides(x, :x) == ArrayInterface.strides(parent(x))[1]
    @test @inferred(ArrayInterface.axes_types(x, static(:x))) <: Base.OneTo{Int}
    @test ArrayInterface.axes_types(x, :x) <: Base.OneTo{Int}
    @test @inferred(ArrayInterface.axes_types(LinearIndices{2,NTuple{2,Base.OneTo{Int}}})) <: NTuple{2,Base.OneTo{Int}}
    CI = CartesianIndices{2,Tuple{Base.OneTo{Int},UnitRange{Int}}}
    @test @inferred(ArrayInterface.axes_types(CI, static(1))) <: Base.OneTo{Int}

    # keyword (name-based) indexing: write index 1 along `:x`, then read it back
    x[x=1] = [2, 3]
    @test @inferred(getindex(x, x=1)) == [2, 3]

    mixed = NamedDimsWrapper((:x, static(:y)), ones(2, 2))
    # FIXME: `getindex` doesn't infer its output here, presumably because `:x` is a
    # dynamic (non-static) name, so the selected dimension isn't known at compile
    # time; hence no `@inferred` on this call.
    @test getindex(mixed, x=1) == [1, 1]
end