Skip to content

Commit a117728

Browse files
authored
Merge pull request #111 from ACEsuit/fix_hyperdual
Add _valtype support for RadialDecay with HyperDual input
2 parents 026de7a + 413350e commit a117728

5 files changed

Lines changed: 84 additions & 72 deletions

File tree

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1212
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1313
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1414
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
15+
HyperDualNumbers = "50ceba7f-c3ee-5a84-a6e8-3ad40456ec97"
1516
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1617
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1718
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
@@ -25,6 +26,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2526
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2627
WithAlloc = "fb1aa66a-603c-4c1d-9bc4-66947c7b08dd"
2728

29+
2830
[compat]
2931
ACEbase = "0.4.5"
3032
BenchmarkTools = "1"
@@ -34,6 +36,7 @@ Combinatorics = "1"
3436
DiffResults = "1.1.0"
3537
ForwardDiff = "0.10, 1.0"
3638
GPUArraysCore = "0.2.0"
39+
HyperDualNumbers = "4.0.10"
3740
KernelAbstractions = "0.9.34"
3841
LinearAlgebra = "1.10"
3942
LuxCore = "1.2.0"
@@ -49,10 +52,11 @@ WithAlloc = "0.1.0"
4952
julia = "1.10"
5053

5154
[extras]
55+
HyperDualNumbers = "50ceba7f-c3ee-5a84-a6e8-3ad40456ec97"
5256
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
5357
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
5458
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5559
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5660

5761
[targets]
58-
test = ["Test", "Printf", "Optimisers", "Zygote"]
62+
test = ["Test", "Printf", "Optimisers", "Zygote", "HyperDualNumbers"]

src/atomicorbitals/radialdecay.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,12 @@ end
3232

3333
Base.length(basis::RadialDecay) = size(basis.ζ, 1)
3434

35-
_valtype(::RadialDecay, T::Type{<: Real}) = T
36-
37-
_valtype(::RadialDecay, T::Type{<: Real},
38-
ps::Union{Nothing, @NamedTuple{}}, st) = T
35+
_valtype(::RadialDecay, T::Type{<: Number}, args...) = T
3936

4037
_valtype(::RadialDecay, T::Type{<: Real},
4138
ps, st) = promote_type(T, eltype(ps.ζ), eltype(ps.D))
4239

40+
4341
_static_params(basis::RadialDecay) == basis.ζ, D = basis.D)
4442

4543
_init_luxparams(basis::RadialDecay) =

src/testing.jl

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ import Polynomials4ML: _generate_input, _generate_batch
1111
using ChainRulesCore: rrule
1212

1313
using Test, ForwardDiff, Bumper, WithAlloc
14-
14+
using HyperDualNumbers
15+
using HyperDualNumbers: Hyper
1516
using StaticArrays
16-
using LinearAlgebra: norm, dot
17+
using Random, LuxCore
18+
using LinearAlgebra: norm, dot, normalize
1719

1820
using ACEbase.Testing: print_tf, println_slim, fdtest
1921

@@ -213,4 +215,49 @@ function test_withalloc(basis::AbstractP4MLBasis;
213215
end
214216

215217

218+
# Construct a static vector of HyperDual numbers from
219+
# x : point coordinates
220+
# v : first direction
221+
# w : second direction
222+
make_hyper(x::SVector{N,T}, v::SVector{N,T}, w::SVector{N,T}) where {N,T} =
223+
SVector{N}(Hyper.(x, v, w, zero(T)))
224+
225+
# Extract components from an array of HyperDual numbers
226+
_val(A) = HyperDualNumbers.value.(A) # function value
227+
_e1(A) = HyperDualNumbers.eps1.(A) # first directional derivative
228+
_e2(A) = HyperDualNumbers.eps2.(A) # second directional derivative
229+
230+
# Project each gradient vector g in G onto direction v
231+
# G is a vector of static vectors (gradients for each basis function)
232+
_dir(G::AbstractVector{<:StaticVector{D,T}}, v::StaticVector{D,T}) where {D,T} =
233+
map(g -> dot(g, v), G)
234+
235+
function test_hyperdual_consistency(basis; rng=Random.default_rng(), rtol=1e-10, atol=1e-12)
236+
@info("HyperDual test")
237+
ps, st = LuxCore.setup(rng, basis)
238+
239+
x = _generate_input(basis)
240+
v = normalize(rand(rng, SVector{length(x)}))
241+
w = normalize(rand(rng, SVector{length(x)}))
242+
xh = make_hyper(x, v, w)
243+
244+
P = evaluate(basis, x)
245+
P_ed, G = evaluate_ed(basis, x)
246+
247+
Ph = evaluate(basis, xh)
248+
@test isapprox(_val(Ph), P; rtol=rtol, atol=atol)
249+
@test isapprox(_e1(Ph), _dir(G, v); rtol=rtol, atol=atol)
250+
@test isapprox(_e2(Ph), _dir(G, w); rtol=rtol, atol=atol)
251+
252+
Pp = evaluate(basis, x, ps, st)
253+
Pp_ed, Gp = evaluate_ed(basis, x, ps, st)
254+
255+
Php = evaluate(basis, xh, ps, st)
256+
@test isapprox(_val(Php), Pp; rtol=rtol, atol=atol)
257+
@test isapprox(_e1(Php), _dir(Gp, v); rtol=rtol, atol=atol)
258+
@test isapprox(_e2(Php), _dir(Gp, w); rtol=rtol, atol=atol)
259+
260+
@test isapprox(P_ed, P; rtol=rtol, atol=atol)
261+
@test isapprox(Pp_ed, Pp; rtol=rtol, atol=atol)
262+
end
216263
end

src/utils/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module Utils
22

33
include("sparse.jl")
4-
# include("hyper.jl")
4+
5+
include("hyper.jl")
56

67
include("linl.jl")
78

test/test_atorbrad.jl

Lines changed: 26 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,34 @@
1-
using LinearAlgebra, StaticArrays, Test, Printf, Polynomials4ML
2-
using Polynomials4ML: evaluate, evaluate_ed
3-
using Polynomials4ML.Testing: print_tf, println_slim
4-
using ForwardDiff
5-
using ACEbase.Testing: fdtest
6-
1+
using Test
2+
using StaticArrays
3+
using LinearAlgebra
4+
using LuxCore
75
import Polynomials4ML as P4ML
6+
using Polynomials4ML: evaluate, evaluate_ed
87

9-
##
10-
11-
@info("Testing GaussianBasis")
12-
basis = P4ML._rand_gaussian_basis()
13-
14-
@info(" correctness of evaluation")
15-
x = P4ML._generate_input(basis)
16-
P = evaluate(basis, x)
17-
P1, dP1 = evaluate_ed(basis, x)
18-
19-
P4ML.Testing.test_evaluate_xx(basis)
20-
P4ML.Testing.test_chainrules(basis)
21-
22-
# Test is broken - reshape is causing this, hence single-input test is turned off
23-
P4ML.Testing.test_withalloc(basis; allowed_allocs = 0, single=false)
24-
25-
# Still fails with single=true, but btime doesn't show the allocation
26-
# even stranger, re-evaluating `_reshape` into P4ML also makes the allocation
27-
# disappear. Probably best to drop this for now, and revisit in a few months.
28-
# P4ML.Testing.test_withalloc(basis; allowed_allocs = 0, single=true)
29-
30-
##
31-
# these are scripts to replicate and check this allocation problem.
32-
# strangely it doesn't occur for the other bases. Only for AtomicOrbtials.
33-
#
34-
# using BenchmarkTools
35-
36-
# P, dP = evaluate_ed(basis, x)
37-
# @btime P4ML.evaluate_ed!($P, $dP, $basis, $x)
38-
39-
# @profview let basis=basis, X=x, P=P, dP=dP
40-
# for _ = 1:1_000_000
41-
# P4ML.evaluate_ed!(P, dP, basis, X)
42-
# end
43-
# end
44-
45-
##
46-
47-
@info("Testing SlaterBasis")
48-
basis = P4ML._rand_slater_basis()
49-
50-
@info(" correctness of evaluation")
51-
x = P4ML._generate_input(basis)
52-
P = evaluate(basis, x)
53-
P1, dP1 = evaluate_ed(basis, x)
548

55-
P4ML.Testing.test_evaluate_xx(basis)
56-
P4ML.Testing.test_chainrules(basis)
57-
P4ML.Testing.test_withalloc(basis; allowed_allocs = 0, single=false)
9+
@testset "GaussianBasis + HyperDual matches evaluate/evaluate_ed" begin
10+
basis = P4ML._rand_gaussian_basis()
11+
P4ML.Testing.test_hyperdual_consistency(basis)
12+
P4ML.Testing.test_evaluate_xx(basis)
13+
P4ML.Testing.test_chainrules(basis)
14+
P4ML.Testing.test_withalloc(basis; allowed_allocs = 0, single=false)
15+
end
5816

59-
##
17+
@testset "SlaterBasis + HyperDual matches evaluate/evaluate_ed" begin
18+
basis = P4ML._rand_slater_basis()
19+
P4ML.Testing.test_hyperdual_consistency(basis)
20+
P4ML.Testing.test_evaluate_xx(basis)
21+
P4ML.Testing.test_chainrules(basis)
22+
P4ML.Testing.test_withalloc(basis; allowed_allocs = 0, single=false)
23+
end
6024

61-
@info("Testing STOBasis")
62-
basis = P4ML._rand_sto_basis()
25+
@testset "STOBasis + HyperDual matches evaluate/evaluate_ed" begin
26+
basis = P4ML._rand_sto_basis()
27+
P4ML.Testing.test_hyperdual_consistency(basis)
28+
P4ML.Testing.test_evaluate_xx(basis)
29+
P4ML.Testing.test_chainrules(basis)
30+
P4ML.Testing.test_withalloc(basis; allowed_allocs = 0, single=false)
31+
end
6332

64-
@info(" correctness of evaluation")
65-
x = P4ML._generate_input(basis)
66-
P = evaluate(basis, x)
67-
P1, dP1 = evaluate_ed(basis, x)
6833

6934

70-
P4ML.Testing.test_evaluate_xx(basis)
71-
P4ML.Testing.test_chainrules(basis)
72-
P4ML.Testing.test_withalloc(basis; allowed_allocs = 0, single=false)

0 commit comments

Comments
 (0)