Skip to content

Commit a78c5fc

Browse files
committed
Require unsafe_wrap and reinterpret to have proper alignment
1 parent 239a2c8 commit a78c5fc

File tree

3 files changed

+66
-9
lines changed

3 files changed

+66
-9
lines changed

base/docs/helpdb/Base.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,14 @@ For example,
10631063
`reinterpret(Float32, UInt32(7))` interprets the 4 bytes corresponding to `UInt32(7)` as a
10641064
`Float32`.
10651065
1066+
!!! warning
1067+
1068+
It is not allowed to `reinterpret` an array to a element type with a larger alignment then
1069+
the alignment of the array. For a normal `Array`, this is the alignment of its element type.
1070+
For a reinterpreted array, this is the alignment of the `Array` it was reinterpreted from.
1071+
For example, `reinterpret(UInt32, UInt8[0, 0, 0, 0])` is not allowed but
1072+
`reinterpret(UInt32, reinterpret(UInt8, Float32[1.0]))` is allowed.
1073+
10661074
```jldoctest
10671075
julia> reinterpret(Float32, UInt32(7))
10681076
1.0f-44

src/array.c

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,19 @@ JL_DLLEXPORT jl_array_t *jl_reshape_array(jl_value_t *atype, jl_array_t *data,
188188
a->offset = 0;
189189
a->data = NULL;
190190
a->flags.isaligned = data->flags.isaligned;
191+
jl_array_t *owner = (jl_array_t*)jl_array_owner(data);
191192
jl_value_t *el_type = jl_tparam0(atype);
192193
assert(store_unboxed(el_type) == !data->flags.ptrarray);
193194
if (!data->flags.ptrarray) {
194195
a->elsize = jl_datatype_size(el_type);
196+
unsigned align = ((jl_datatype_t*)el_type)->layout->alignment;
197+
jl_value_t *ownerty = jl_typeof(owner);
198+
unsigned oldalign = (ownerty == (jl_value_t*)jl_string_type ? 1 :
199+
((jl_datatype_t*)jl_tparam0(ownerty))->layout->alignment);
200+
if (oldalign < align)
201+
jl_exceptionf(jl_argumenterror_type,
202+
"reinterpret from alignment %u to alignment %u not allowed",
203+
oldalign, align);
195204
a->flags.ptrarray = 0;
196205
}
197206
else {
@@ -201,7 +210,7 @@ JL_DLLEXPORT jl_array_t *jl_reshape_array(jl_value_t *atype, jl_array_t *data,
201210

202211
// if data is itself a shared wrapper,
203212
// owner should point back to the original array
204-
jl_array_data_owner(a) = jl_array_owner(data);
213+
jl_array_data_owner(a) = (jl_value_t*)owner;
205214

206215
a->flags.how = 3;
207216
a->data = data->data;
@@ -266,15 +275,22 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array_1d(jl_value_t *atype, void *data,
266275
size_t nel, int own_buffer)
267276
{
268277
jl_ptls_t ptls = jl_get_ptls_states();
269-
size_t elsz;
270278
jl_array_t *a;
271279
jl_value_t *el_type = jl_tparam0(atype);
272280

273281
int isunboxed = store_unboxed(el_type);
274-
if (isunboxed)
282+
size_t elsz;
283+
unsigned align;
284+
if (isunboxed) {
275285
elsz = jl_datatype_size(el_type);
276-
else
277-
elsz = sizeof(void*);
286+
align = ((jl_datatype_t*)el_type)->layout->alignment;
287+
}
288+
else {
289+
align = elsz = sizeof(void*);
290+
}
291+
if (((uintptr_t)data) & (align - 1))
292+
jl_exceptionf(jl_argumenterror_type,
293+
"unsafe_wrap: pointer %p is not properly aligned to %u", data, align);
278294

279295
int ndimwords = jl_array_ndimwords(1);
280296
int tsz = JL_ARRAY_ALIGN(sizeof(jl_array_t) + ndimwords*sizeof(size_t), JL_CACHE_BYTE_ALIGNMENT);
@@ -309,7 +325,7 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array(jl_value_t *atype, void *data,
309325
jl_value_t *_dims, int own_buffer)
310326
{
311327
jl_ptls_t ptls = jl_get_ptls_states();
312-
size_t elsz, nel = 1;
328+
size_t nel = 1;
313329
jl_array_t *a;
314330
size_t ndims = jl_nfields(_dims);
315331
wideint_t prod;
@@ -326,10 +342,18 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array(jl_value_t *atype, void *data,
326342
jl_value_t *el_type = jl_tparam0(atype);
327343

328344
int isunboxed = store_unboxed(el_type);
329-
if (isunboxed)
345+
size_t elsz;
346+
unsigned align;
347+
if (isunboxed) {
330348
elsz = jl_datatype_size(el_type);
331-
else
332-
elsz = sizeof(void*);
349+
align = ((jl_datatype_t*)el_type)->layout->alignment;
350+
}
351+
else {
352+
align = elsz = sizeof(void*);
353+
}
354+
if (((uintptr_t)data) & (align - 1))
355+
jl_exceptionf(jl_argumenterror_type,
356+
"unsafe_wrap: pointer %p is not properly aligned to %u", data, align);
333357

334358
int ndimwords = jl_array_ndimwords(ndims);
335359
int tsz = JL_ARRAY_ALIGN(sizeof(jl_array_t) + ndimwords*sizeof(size_t), JL_CACHE_BYTE_ALIGNMENT);

test/core.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,7 +952,15 @@ let
952952
@test aa == a
953953
aa = unsafe_wrap(Array, pointer(a), UInt16(length(a)))
954954
@test aa == a
955+
aaa = unsafe_wrap(Array, pointer(a), (1, 1))
956+
@test size(aaa) == (1, 1)
957+
@test aaa[1] == a[1]
955958
@test_throws InexactError unsafe_wrap(Array, pointer(a), -3)
959+
# Misaligned pointer
960+
res = @test_throws ArgumentError unsafe_wrap(Array, pointer(a) + 1, length(a))
961+
@test contains(res.value.msg, "is not properly aligned to $(sizeof(Int))")
962+
res = @test_throws ArgumentError unsafe_wrap(Array, pointer(a) + 1, (1, 1))
963+
@test contains(res.value.msg, "is not properly aligned to $(sizeof(Int))")
956964
end
957965

958966
struct FooBar2515
@@ -4906,3 +4914,20 @@ type T21719{V}
49064914
end
49074915
g21719(f, goal; tol = 1e-6) = T21719(f, tol, goal)
49084916
@test isa(g21719(identity, 1.0; tol=0.1), T21719)
4917+
4918+
# reinterpret alignment requirement
4919+
let arr8 = zeros(UInt8, 16),
4920+
arr64 = zeros(UInt64, 2),
4921+
arr64_8 = reinterpret(UInt8, arr64),
4922+
arr64_i
4923+
4924+
# Not allowed to reinterpret arrays allocated as UInt8 array to a Int32 array
4925+
res = @test_throws ArgumentError reinterpret(Int32, arr8)
4926+
@test res.value.msg == "reinterpret from alignment 1 to alignment 4 not allowed"
4927+
# OK to reinterpret arrays allocated as UInt64 array to a Int64 array even though
4928+
# it is passed as a UInt8 array
4929+
arr64_i = reinterpret(Int64, arr64_8)
4930+
@test arr8 == arr64_8
4931+
arr64_i[2] = 1234
4932+
@test arr64[2] == 1234
4933+
end

0 commit comments

Comments
 (0)