Skip to content

Commit dd3fea5

Browse files
Merge pull request #479 from vchuravy/add-amdgpu-extension
Add AMDGPU.jl extension
2 parents b1ac878 + 197e7dd commit dd3fea5

File tree

4 files changed

+32
-0
lines changed

4 files changed

+32
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88

99
[weakdeps]
10+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1011
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
1112
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
1213
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -21,6 +22,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2122
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2223

2324
[extensions]
25+
ArrayInterfaceAMDGPUExt = "AMDGPU"
2426
ArrayInterfaceBandedMatricesExt = "BandedMatrices"
2527
ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices"
2628
ArrayInterfaceCUDAExt = "CUDA"
@@ -35,6 +37,7 @@ ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore"
3537
ArrayInterfaceTrackerExt = "Tracker"
3638

3739
[compat]
40+
AMDGPU = "2"
3841
Adapt = "4"
3942
BandedMatrices = "1"
4043
BlockBandedMatrices = "0.13"

ext/ArrayInterfaceAMDGPUExt.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module ArrayInterfaceAMDGPUExt
2+
3+
using ArrayInterface
4+
using AMDGPU
5+
using LinearAlgebra
6+
7+
function ArrayInterface.lu_instance(A::ROCMatrix{T}) where {T}
8+
ipiv = ROCVector{Cint}(undef, 0)
9+
info = zero(Int)
10+
return LinearAlgebra.LU(similar(A, 0, 0), ipiv, info)
11+
end
12+
13+
ArrayInterface.device(::Type{<:AMDGPU.ROCArray}) = ArrayInterface.GPU()
14+
15+
end # module

test/gpu/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
23
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
34
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
45
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

test/gpu/amdgpu.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using AMDGPU
2+
using ArrayInterface
3+
using LinearAlgebra
4+
5+
using Test
6+
7+
A = ROCMatrix(Float32[1 0; 0 1])
8+
9+
# Test that lu_instance works with AMDGPU.jl ROC arrays
10+
@test isa(ArrayInterface.lu_instance(A), LU)
11+
12+
# Test that device returns GPU()
13+
@test ArrayInterface.device(typeof(A)) == ArrayInterface.GPU()

0 commit comments

Comments
 (0)