Skip to content

Commit 337aacc

Browse files
authored
Merge pull request #112 from ACEsuit/co/bugs
Bugfixes
2 parents 6e5ab4b + 52942d7 commit 337aacc

3 files changed

Lines changed: 15 additions & 8 deletions

File tree

src/interface.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,11 @@ _evaluate!(P, dP, basis::AbstractP4MLBasis, X) =
8080
# But ka_evaluate! can also be called with CPU arrays to enable testing
8181
# KA kernels also on the CPU.
8282

83-
evaluate!(P::AbstractGPUArray, basis::AbstractP4MLBasis, x::BATCH) =
83+
evaluate!(P::AbstractGPUArray, basis::AbstractP4MLBasis, x::BATCH, args...) =
8484
ka_evaluate!(P, basis, x)
8585

86-
evaluate_ed!(P::AbstractGPUArray, dP::AbstractGPUArray, basis::AbstractP4MLBasis, x::BATCH) =
86+
evaluate_ed!(P::AbstractGPUArray, dP::AbstractGPUArray,
87+
basis::AbstractP4MLBasis, x::BATCH, args...) =
8788
ka_evaluate_ed!(P, dP, basis, x)
8889

8990
function ka_evaluate!(P, basis::AbstractP4MLBasis, x::BATCH)
@@ -194,12 +195,12 @@ function _with_safe_alloc(fcall, args...)
194195
return fcall(outputs..., args...)
195196
end
196197

197-
function _with_safe_alloc(fcall, basis::AbstractP4MLBasis, X::BATCH)
198+
function _with_safe_alloc(fcall, basis::AbstractP4MLBasis, X::BATCH, args...)
198199
_alczero(T, args...) = fill!( similar(X, T, args...), zero(T) )
199200

200-
allocinfo = _tup_whatalloc(fcall, basis, X)
201+
allocinfo = _tup_whatalloc(fcall, basis, X, args...)
201202
outputs = ntuple(i -> _alczero(allocinfo[i]...), length(allocinfo))
202-
return fcall(outputs..., basis, X)
203+
return fcall(outputs..., basis, X, args...)
203204
end
204205

205206
# ---------------------------------------

src/transformed.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,19 @@ _valtype(::Nothing, T::Type) = T
8787
# CPU SIMD kernel
8888
#
8989

90+
_evaluate!(P, dP, tbasis::TransformedBasis, x::AbstractVector,
91+
ps::Nothing, st::Nothing) =
92+
_evaluate!(P, dP, tbasis, x,
93+
(trans = NamedTuple(), basis = NamedTuple()),
94+
(trans = NamedTuple(), basis = NamedTuple()), )
9095

91-
function _evaluate!(P, dP::Nothing,
96+
97+
function _evaluate!(P, dP,
9298
tbasis::TransformedBasis,
9399
x::AbstractVector,
94100
ps, st)
95101
nX = length(x)
102+
@assert isnothing(dP) "_evaluate! not implemented for dP != nothing"
96103

97104
@no_escape begin
98105
# [1] Stage 1 - transform the inputs

test/test_lux.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@ test_bases = [ chebyshev_basis(10),
6868
for basis in test_bases
6969
# basis = test_bases[12]
7070
@info("Lux layer test for $(typeof(basis).name.name)")
71-
local B1, B2, x
72-
local ps, st
71+
local B1, B2, x, X, ps, st
7372
nX = rand(8:16)
7473
X = [ _generate_input(basis) for _ = 1:nX ]
7574
B1 = basis(X)

0 commit comments

Comments
 (0)