Skip to content

Commit 3312c3a

Browse files
authored
Merge pull request #110 from ACEsuit/lux
Make all bases Lux layers
2 parents ed7711a + 44125ff commit 3312c3a

13 files changed

Lines changed: 73 additions & 115 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2525
WithAlloc = "fb1aa66a-603c-4c1d-9bc4-66947c7b08dd"
2626

2727
[compat]
28-
ACEbase = "0.4.2"
28+
ACEbase = "0.4.5"
2929
BenchmarkTools = "1"
3030
Bumper = "0.7.0"
3131
ChainRulesCore = "1"

src/Polynomials4ML.jl

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,35 @@ module Polynomials4ML
22

33
# -------------- import ACEbase, Bumper, WithAlloc, Lux and related
44

5-
import ACEbase
5+
import ACEbase
6+
67
import ACEbase: evaluate, evaluate_d, evaluate_ed,
7-
evaluate!, evaluate_d!, evaluate_ed!
8+
evaluate!, evaluate_d!, evaluate_ed!,
9+
pullback, pullback!, pushforward, pushforward!
10+
11+
import ChainRulesCore: rrule, frule, NoTangent, ZeroTangent
12+
import LuxCore: AbstractLuxLayer, initialparameters, initialstates
813

914
using Bumper, WithAlloc
1015
import WithAlloc: whatalloc
1116

1217
using KernelAbstractions, GPUArraysCore
1318

14-
using LuxCore, Random, StaticArrays
15-
import ChainRulesCore: rrule, frule, NoTangent, ZeroTangent
19+
using LuxCore, Random, StaticArrays, ChainRulesCore
1620
using ForwardDiff: Dual, extract_derivative
1721
using StaticArrays
18-
import LuxCore: AbstractLuxLayer, initialparameters, initialstates
22+
1923

2024
using Random: AbstractRNG
2125

2226
"""
2327
`abstract type AbstractP4MLBasis end`
2428
2529
Annotates types that map a low-dimensional input, scalar or `SVector`,
26-
to a vector of scalars (feature vector, embedding, basis...).
30+
to a vector of scalars (feature vector, embedding, basis...). Can be used
31+
as a `Lux` layer.
2732
"""
28-
abstract type AbstractP4MLBasis end
33+
abstract type AbstractP4MLBasis <: AbstractLuxLayer end
2934

3035

3136
"""
@@ -64,16 +69,6 @@ Chebyshev polynomials, `index(basis, n)` returns `n+1`.
6469
function index end
6570

6671
function orthpolybasis end
67-
function degree end
68-
69-
function pullback! end
70-
function pullback end
71-
function pushforward end
72-
function pushforward! end
73-
74-
# some stuff to allow bases to overload some lux functionality ...
75-
# how much of this should go into ACEbase?
76-
function lux end
7772

7873
export orthpolybasis
7974

@@ -115,9 +110,6 @@ include("atomicorbitals/atomicorbitals.jl")
115110
# RETIRE - to be discussed?
116111
# include("linear.jl")
117112

118-
# generic machinery for wrapping poly4ml bases into lux layers
119-
include("lux.jl")
120-
121113
# some nice utility functions to generate basis sets and other things
122114
include("utils/utils.jl")
123115

src/atomicorbitals/atomicorbitals.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ function pullback_ps(∂Rnl, basis::AtomicOrbitals, X::AbstractVector{<: SVector
139139
map!(norm, R, X)
140140

141141
# Rnl = output of evaluate(basis, X, ...)
142-
Pn = evaluate(basis.Pn, X, ps.Pn, st.Pn)
143-
Dn = evaluate(basis.Dn, X, ps.Dn, st.Dn)
142+
Pn = evaluate(basis.Pn, R, ps.Pn, st.Pn)
143+
Dn = evaluate(basis.Dn, R, ps.Dn, st.Dn)
144144
Ylm = evaluate(basis.Ylm, X, ps.Ylm, st.Ylm)
145145
∂Pn = zeros(T, size(Pn))
146146
∂Dn = zeros(T, size(Dn))
@@ -156,8 +156,8 @@ function pullback_ps(∂Rnl, basis::AtomicOrbitals, X::AbstractVector{<: SVector
156156
end
157157
end
158158

159-
∂p_Pn = pullback_ps(∂Pn, basis.Pn, X, ps.Pn, st.Pn)
160-
∂p_Dn = pullback_ps(∂Dn, basis.Dn, X, ps.Dn, st.Dn)
159+
∂p_Pn = pullback_ps(∂Pn, basis.Pn, R, ps.Pn, st.Pn)
160+
∂p_Dn = pullback_ps(∂Dn, basis.Dn, R, ps.Dn, st.Dn)
161161
∂p_Ylm = pullback_ps(∂Ylm, basis.Ylm, X, ps.Ylm, st.Ylm)
162162
return (Pn = ∂p_Pn, Dn = ∂p_Dn, Ylm = ∂p_Ylm)
163163
end

src/bernstein.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ _generate_input(basis::BernsteinBasis) = rand()
1818
end
1919

2020

21-
function _evaluate!(P, dP, basis::BernsteinBasis{N}, x::AbstractVector{<:Real}) where {N}
21+
function _evaluate!(P, dP, basis::BernsteinBasis{N}, x::AbstractVector{<:Real},
22+
ps, st) where {N}
2223

2324
n = N - 1
2425
WITHGRAD = !isnothing(dP)

src/interface.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,26 @@ evaluate_ed(l::AbstractP4MLBasis, args...) =
217217
evaluate_d(l::AbstractP4MLBasis, args...) =
218218
evaluate_ed(l, args...)[2]
219219

220+
221+
# ------------------------------------------------------------
222+
# Lux interface
223+
224+
225+
"""
226+
a fall-back method for `initalparameters` that all AbstractP4MLBasis
227+
should overload
228+
"""
229+
_init_luxparams(rng::AbstractRNG, l::Any) = _init_luxparams(l)
230+
_init_luxparams(l) = NamedTuple()
231+
232+
_init_luxstate(rng::AbstractRNG, l::Any) = _init_luxstate(l)
233+
_init_luxstate(l) = NamedTuple()
234+
235+
initialparameters(rng::AbstractRNG, l::AbstractP4MLBasis) =
236+
_init_luxparams(rng, l)
237+
238+
initialstates(rng::AbstractRNG, l::AbstractP4MLBasis) =
239+
_init_luxstate(rng, l)
240+
241+
(l::AbstractP4MLBasis)(X, ps::NamedTuple, st::NamedTuple) =
242+
evaluate(l, X, ps, st), st

src/lux.jl

Lines changed: 0 additions & 55 deletions
This file was deleted.

src/transformed.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@ _generate_input(basis::TransformedBasis{Nothing}) =
4848
# ---------------------------------------------------------
4949
# Lux stuff
5050

51-
_init_luxparams(rng::AbstractRNG, l::TransformedBasis) =
51+
initialparameters(rng::AbstractRNG, l::TransformedBasis) =
5252
( trans = _init_luxparams(rng, l.trans),
5353
basis = _init_luxparams(rng, l.basis), )
5454

55-
_init_luxstate(rng::AbstractRNG, l::TransformedBasis) =
55+
initialstates(rng::AbstractRNG, l::TransformedBasis) =
5656
( trans = _init_luxparams(rng, l.trans),
5757
basis = _init_luxparams(rng, l.basis), )
5858

@@ -96,7 +96,7 @@ function _evaluate!(P, dP::Nothing,
9696

9797
@no_escape begin
9898
# [1] Stage 1 - transform the inputs
99-
z1 = evaluate(tbasis.trans, x[1], ps.trans, st.trans)
99+
z1 = evaluate(tbasis.trans, x[1], ps.trans, st.trans)
100100
TZ = typeof(z1)
101101
Z = @alloc(TZ, nX)
102102
@inbounds begin

test/runtests.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Polynomials4ML
22
using Test
33

4+
##
45
@testset "Polynomials4ML.jl" begin
56

67
# 1D Polynomials
@@ -19,13 +20,14 @@ using Test
1920

2021
# Misc
2122
@testset "Static Prod" begin include("test_staticprod.jl"); end
22-
# @testset "Lux" begin include("test_lux.jl"); end
2323

2424
# Transformations
2525
@testset "Transformed Basis" begin include("test_transformed.jl"); end
2626

27+
# Test lux interface
28+
@testset "Lux" begin include("test_lux.jl"); end
29+
2730
# TODO: restructure or move??
2831
# @testset "Sparse Product" begin include("test_sparseproduct.jl"); end
2932
# @testset "Linear layer" begin include("test_linear.jl"); end
30-
3133
end

test/test_ctrig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ basis = CTrigBasis(N)
1414
@info(" correctness")
1515
mm = natural_indices(basis)
1616
for ntest = 1:10
17-
local x
17+
local x, P
1818
x = 2*π * rand()
1919
P = basis(x)
2020
P_ref = [ exp(im * m.n * x) for m in mm ]

test/test_lux.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Polynomials4ML, Test, StaticArrays, LuxCore, Zygote, ForwardDiff
2-
using Polynomials4ML: lux, _generate_input
2+
using Polynomials4ML: _generate_input
33
using Random: default_rng
44
using ACEbase.Testing: println_slim, print_tf
55
using LinearAlgebra: dot, I
@@ -66,16 +66,15 @@ test_bases = [ chebyshev_basis(10),
6666
##
6767

6868
for basis in test_bases
69-
# basis = test_bases[13]
69+
# basis = test_bases[12]
7070
@info("Lux layer test for $(typeof(basis).name.name)")
7171
local B1, B2, x
7272
local ps, st
7373
nX = rand(8:16)
7474
X = [ _generate_input(basis) for _ = 1:nX ]
7575
B1 = basis(X)
76-
l = lux(basis)
77-
ps, st = LuxCore.setup(rng, l)
78-
B2, _ = l(X, ps, st)
76+
ps, st = LuxCore.setup(rng, basis)
77+
B2, _ = basis(X, ps, st)
7978
println_slim(@test B1 == B2)
8079

8180
if !( eltype(eltype(B1)) <: Real)
@@ -84,7 +83,7 @@ for basis in test_bases
8483
end
8584

8685
# evaluate the basis and get the pullback operator
87-
val1, pb1 = Zygote.pullback(LuxCore.apply, l, X, ps, st)
86+
val1, pb1 = Zygote.pullback(LuxCore.apply, basis, X, ps, st)
8887
val2, pb2 = Zygote.pullback(Polynomials4ML.evaluate, basis, X, ps, st)
8988
# evaluate the pullback on a random cotangent
9089
Δ = randn(eltype(val2), size(val2))
@@ -98,7 +97,7 @@ for basis in test_bases
9897
println_slim(@test ∂1[3] == ∂2[3])
9998

10099
# look at gradients with respect to the parameters
101-
_foo = p -> dot(Δ, LuxCore.apply(l, X, p, st)[1])
100+
_foo = p -> dot(Δ, LuxCore.apply(basis, X, p, st)[1])
102101
g1 = Zygote.gradient(_foo, ps)[1]
103102
if sizeof(ps) == 0
104103
println_slim(@test (isnothing(g1) || isempty(g1)))
@@ -110,7 +109,3 @@ for basis in test_bases
110109

111110
end
112111

113-
114-
##
115-
116-
# @info("Test Second-order derivatices with Lux")

0 commit comments

Comments
 (0)