@@ -2,30 +2,35 @@ module Polynomials4ML
22
33# -------------- import ACEbase, Bumper, WithAlloc, Lux and related
44
5- import ACEbase
5+ import ACEbase
6+
67import ACEbase: evaluate, evaluate_d, evaluate_ed,
7- evaluate!, evaluate_d!, evaluate_ed!
8+ evaluate!, evaluate_d!, evaluate_ed!,
9+ pullback, pullback!, pushforward, pushforward!
10+
11+ import ChainRulesCore: rrule, frule, NoTangent, ZeroTangent
12+ import LuxCore: AbstractLuxLayer, initialparameters, initialstates
813
914using Bumper, WithAlloc
1015import WithAlloc: whatalloc
1116
1217using KernelAbstractions, GPUArraysCore
1318
14- using LuxCore, Random, StaticArrays
15- import ChainRulesCore: rrule, frule, NoTangent, ZeroTangent
19+ using LuxCore, Random, StaticArrays, ChainRulesCore
1620using ForwardDiff: Dual, extract_derivative
1721using StaticArrays
18- import LuxCore : AbstractLuxLayer, initialparameters, initialstates
22+
1923
2024using Random: AbstractRNG
2125
2226"""
2327`abstract type AbstractP4MLBasis end`
2428
2529Annotates types that map a low-dimensional input, scalar or `SVector`,
26- to a vector of scalars (feature vector, embedding, basis...).
30+ to a vector of scalars (feature vector, embedding, basis...). Can be used
31+ as a `Lux` layer.
2732"""
28- abstract type AbstractP4MLBasis end
33+ abstract type AbstractP4MLBasis <: AbstractLuxLayer end
2934
3035
3136"""
@@ -64,16 +69,6 @@ Chebyshev polynomials, `index(basis, n)` returns `n+1`.
6469function index end
6570
6671function orthpolybasis end
67- function degree end
68-
69- function pullback! end
70- function pullback end
71- function pushforward end
72- function pushforward! end
73-
74- # some stuff to allow bases to overload some lux functionality ...
75- # how much of this should go into ACEbase?
76- function lux end
7772
7873export orthpolybasis
7974
@@ -115,9 +110,6 @@ include("atomicorbitals/atomicorbitals.jl")
115110# RETIRE - to be discussed?
116111# include("linear.jl")
117112
118- # generic machinery for wrapping poly4ml bases into lux layers
119- include (" lux.jl" )
120-
121113# some nice utility functions to generate basis sets and other things
122114include (" utils/utils.jl" )
123115
0 commit comments