Skip to content

Commit 9254fa4

Browse files
authored
Merge pull request #119 from ACEsuit/splines
Splinification
2 parents e99fa5c + 6e5aec6 commit 9254fa4

4 files changed

Lines changed: 316 additions & 2 deletions

File tree

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1313
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1414
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1515
HyperDualNumbers = "50ceba7f-c3ee-5a84-a6e8-3ad40456ec97"
16+
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
1617
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1819
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
@@ -25,7 +26,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2526
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2627
WithAlloc = "fb1aa66a-603c-4c1d-9bc4-66947c7b08dd"
2728

28-
2929
[compat]
3030
ACEbase = "0.4.5"
3131
BenchmarkTools = "1"
@@ -36,6 +36,7 @@ DiffResults = "1.1.0"
3636
ForwardDiff = "0.10, 1.0"
3737
GPUArraysCore = "0.2.0"
3838
HyperDualNumbers = "4.0.10"
39+
Interpolations = "0.16.2"
3940
KernelAbstractions = "0.9.34"
4041
LinearAlgebra = "1.10"
4142
LuxCore = "1.2.0"
@@ -56,4 +57,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5657
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5758

5859
[targets]
59-
test = ["Test", "Printf", "Optimisers", "Zygote", ]
60+
test = ["Test", "Printf", "Optimisers", "Zygote"]

src/Polynomials4ML.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ include("monomials.jl")
9494
include("chebbasis.jl")
9595
include("bernstein.jl")
9696

97+
# splines
98+
include("splinify.jl")
99+
97100
# 2d harmonics / trigonometric polynomials
98101
include("ctrig.jl")
99102
include("rtrig.jl")

src/splinify.jl

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
2+
import Interpolations
3+
import ForwardDiff
4+
5+
# TODO:
6+
# - consider allowing a coordinate transformation before
7+
# applying the splines and a post-multiplication with
8+
# an envelope. That's a bit specific for a general purpose
9+
# library but might be very useful for performance
10+
# and simplicity in ACE applications.
11+
# But it might be better to implement these things as wrappers
12+
# around an arbitrary basis.
13+
#
14+
# - the splines should inherit the specs from the original basis
15+
# they interpolate.
16+
17+
"""
    CubicSplines{NX, NU, T}

Statically sized cubic spline basis that plugs into the P4ML batched
evaluation interface. It can stand in for any P4ML basis taking a
univariate input.

Type parameters: `NX` is the number of nodes, `NU` the number of basis
functions stored per node, and `T` the scalar type of the stored data.
"""
struct CubicSplines{NX, NU, T} <: AbstractP4MLBasis
    F::SVector{NX, SVector{NU, T}}   # values of all NU functions at each node
    G::SVector{NX, SVector{NU, T}}   # derivatives of all NU functions at each node
    x0::T                            # left end of the spline interval
    x1::T                            # right end of the spline interval
end
30+
31+
32+
# Compact one-line display: number of nodes and number of basis functions.
function Base.show(io::IO, basis::CubicSplines{NX, NU}) where {NX, NU}
    return print(io, "CubicSplines(nx = $NX, len = $NU)")
end
35+
36+
# The basis length is the number of spline functions, read off the type.
function Base.length(::CubicSplines{NX, NU}) where {NX, NU}
    return NU
end
37+
38+
# Number of spline nodes, read off the type parameter.
function __NX(::CubicSplines{NX}) where {NX}
    return NX
end
39+
40+
# TODO: this is wrong, instead should inherit from the original basis
# For now every spline function is labelled by its position `n` in the basis.
function natural_indices(basis::CubicSplines)
    return map(n -> (n = n,), 1:length(basis))
end
42+
43+
# Output scalar type: promotion of the stored spline scalar type with the
# input scalar type.
function _valtype(basis::CubicSplines{NX, NU, T1}, T2::Type{<:Real}
                  ) where {NX, NU, T1}
    return promote_type(T1, T2)
end
45+
46+
# Output scalar type when explicit parameters/state are supplied: the scalar
# type is taken from the node values held in the state `st`, not from `basis`.
function _valtype(basis::CubicSplines{NX, NU, T1}, T2::Type{<:Real},
                  ps, st) where {NX, NU, T1}
    Tst = eltype(eltype(st.F[1]))
    return promote_type(T2, Tst)
end
48+
49+
# Draw a random input from the spline interval between `x0` and `x1`.
function _generate_input(basis::CubicSplines)
    width = basis.x1 - basis.x0
    return rand() * width + basis.x0
end
51+
52+
# The Lux state carries all spline data: node values, node derivatives and
# the two interval endpoints.
function _init_luxstate(l::CubicSplines)
    return (F = l.F, G = l.G, x0 = l.x0, x1 = l.x1)
end
54+
55+
56+
# ----------------- constructor of spline basis
57+
"""
    splinify(f, x0, x1, NX; bspline=true)

Take a P4ML basis `f` with univariate input and construct a cubic spline basis
that interpolates the basis functions on a uniform grid with `NX` nodes.
If `bspline=true` (default) the function values are first interpolated onto
a B-spline representation to obtain C2,2 regularity of the splines.
If `bspline=false` the node derivatives are obtained by differentiating `f`
with ForwardDiff instead.

`x0`, `x1` are the left and right endpoints of the spline interval. They are
converted to the scalar type of the sampled basis values, so e.g. integer
endpoints are accepted.

Throws an `ArgumentError` if `NX < 2` (at least one interval is required).

This is currently not exported and not part of the public interface. The
interface can change in future releases.
"""
function splinify(f, x0, x1, NX; bspline=true)
    NX >= 2 || throw(ArgumentError("splinify requires NX >= 2 nodes, got NX = $NX"))

    xx = range(x0, x1; length=NX)
    # probe the basis at the first grid node (same value as x0, but already
    # in the floating point type of the grid) to find the number of functions
    NU = length(f(first(xx)))
    F = [ SVector{NU}(f(x)) for x in xx ]

    if bspline
        # CubicSplines represents the splines in terms of piecewise cubics
        # specified through values and gradients at the nodes; this only gives
        # C1,1 regularity. By first interpolating onto a B-spline, and then the
        # B-spline onto the CubicSplines representation we get C2,2 regularity.
        itp = Interpolations.cubic_spline_interpolation(xx, F;
                        extrapolation_bc=Interpolations.Flat())
        G = [ Interpolations.gradient(itp, x)[1] for x in xx ]
    else
        # use the module name imported by this file rather than an alias
        G = [ ForwardDiff.derivative(f, x) for x in xx ]
    end

    T = eltype(eltype(F))
    stF = SVector{NX}(F)
    stG = SVector{NX}((SVector{NU, T}).(G))
    # the struct requires the endpoints to have the same scalar type as F, G
    return CubicSplines(stF, stG, convert(T, x0), convert(T, x1))
end
92+
93+
94+
95+
# ----------------- shared evaluation code
96+
"""
    _eval_cubic(t, fl, fr, gl, gr)

Evaluate the cubic Hermite interpolant at the relative position `t` in `[0,1]`.
`fl`, `fr` are the function values and `gl`, `gr` the derivatives with respect
to `t` at the left (`t = 0`) and right (`t = 1`) endpoints.
"""
@inline function _eval_cubic(t, fl, fr, gl, gr)
    # Monomial coefficients of the Hermite form
    #   (2t³ - 3t² + 1)*fl + (t³ - 2t² + t)*gl +
    #   (-2t³ + 3t²)*fr + (t³ - t²)*gr
    # The constant and linear coefficients are simply fl and gl.
    c2 = -3fl + 3fr - 2gl - gr
    c3 = 2fl - 2fr + gl + gr
    # Horner evaluation
    return ((c3*t + c2)*t + gl)*t + fl
end
112+
"""
    _eval_cubspl(x, F, G, x0, x1, NX)

Auxiliary function to evaluate the cubic spline basis given the spline data:
node values `F`, node derivatives `G`, interval endpoints `x0`, `x1` and the
number of nodes `NX` (uniform grid). Inputs outside `[x0, x1]` are projected
onto the interval.
"""
@inline function _eval_cubspl(x, F, G, x0, x1, NX)
    x = clamp(x, x0, x1)          # project to [x0, x1] (corresponds to Flat bc)
    h = (x1 - x0) / (NX-1)        # uniform grid spacing
    s = (x - x0) / h              # position in units of the grid spacing
    # Index of the left node. It is capped at NX-2 so that x == x1 (or a value
    # of s rounded up to NX-1) is evaluated at t = 1 on the last interval
    # instead of reading node NX+1, which does not exist.
    il = min(floor(Int, s), NX - 2)
    # TODO: is this numerically stable?
    t = s - il                    # relative coordinate of x in [il, il+1]
    # derivatives are rescaled by h because _eval_cubic works in t, not x
    @inbounds _eval_cubic(t, F[il+1], F[il+2], h*G[il+1], h*G[il+2])
end
127+
128+
# Evaluate the spline basis and its derivative with respect to `x`.
# Returns the tuple `(f, g)`. For `x` outside `[x0, x1]` the value is that of
# the clamped input and the derivative is zero. Inside the interval the
# derivative is obtained by pushing a ForwardDiff dual number through the
# local cubic.
@inline function _cubspl_widthgrad(x, F, G, x0, x1, NX)
    if x < x0 || x > x1
        f = _eval_cubspl(x, F, G, x0, x1, NX)
        return f, zero(f)
    end
    h = (x1 - x0) / (NX-1)        # uniform grid spacing
    s = (x - x0) / h              # position in units of the grid spacing
    # Index of the left node, capped at NX-2 so that x == x1 is handled as
    # t = 1 on the last interval rather than indexing past the last node.
    il = min(floor(Int, s), NX - 2)
    t = s - il                    # relative coordinate within the interval
    td = ForwardDiff.Dual(t, one(t))
    fd = _eval_cubic(td, F[il+1], F[il+2], h*G[il+1], h*G[il+2])
    f = ForwardDiff.value.(fd)
    g = ForwardDiff.partials.(fd, 1)
    # chain rule: d/dx = (1/h) d/dt
    return f, g / h
end
142+
143+
144+
# ----------------- CPU evaluation code
145+
146+
# Convenience method without explicit parameters/state: the state is built
# from the basis itself and `nothing` is passed for the parameters.
function _evaluate!(P, dP, basis::CubicSplines, X)
    st = _init_luxstate(basis)
    return _evaluate!(P, dP, basis, X, nothing, st)
end
148+
149+
150+
# Values only: row `i` of `P` receives the spline basis evaluated at the
# i-th input. The spline data is read from the state `st`.
function _evaluate!(P::AbstractMatrix, dP::Nothing, basis::CubicSplines, X::BATCH, ps, st)
    @assert size(P, 1) >= length(X)
    nx = __NX(basis)
    @inbounds for (row, xi) in enumerate(X)
        vals = _eval_cubspl(xi, st.F, st.G, st.x0, st.x1, nx)
        P[row, :] = vals
    end
    return nothing
end
157+
158+
159+
# Values and derivatives: row `i` of `P` receives the basis values and row `i`
# of `dP` the derivatives at the i-th input. The spline data is read from `st`.
function _evaluate!(P::AbstractMatrix, dP::AbstractMatrix,
                    basis::CubicSplines, X::BATCH, ps, st)
    @assert size(P, 1) >= length(X)
    @assert size(dP, 1) >= length(X)
    nx = __NX(basis)
    @inbounds for (row, xi) in enumerate(X)
        vals, grads = _cubspl_widthgrad(xi, st.F, st.G, st.x0, st.x1, nx)
        P[row, :] = vals
        dP[row, :] = grads
    end
    return nothing
end
170+
171+
172+
# ----------------- KernelAbstractions evaluation code
173+
174+
175+
# KernelAbstractions kernel: each work item handles one input `x[i]` and
# writes row `i` of `P` (and of `dP`, unless `dP` is `nothing`). The spline
# data is read directly from the fields of `basis`.
@kernel function _ka_evaluate!(P, dP, basis::CubicSplines, x::AbstractVector{T}
                ) where {T}

    i = @index(Global)

    if !isnothing(dP)
        f, g = _cubspl_widthgrad(x[i], basis.F, basis.G, basis.x0, basis.x1, __NX(basis))
        P[i, :] = f
        dP[i, :] = g
    else
        P[i, :] = _eval_cubspl(x[i], basis.F, basis.G, basis.x0, basis.x1, __NX(basis))
    end

    nothing
end

test/test_splines.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
2+
3+
4+
using Polynomials4ML, Test, LuxCore, Random
5+
using Polynomials4ML: _generate_input
6+
using Polynomials4ML.Testing: println_slim, print_tf, test_all
7+
using LinearAlgebra: I, norm, dot
8+
using QuadGK
9+
import Polynomials4ML as P4ML
10+
11+
12+
##
13+
14+
# N = rand(5:15)
15+
N = 10
16+
# basis = OrthPolyBasis1D3T(randn(N), randn(N), randn(N))
17+
basis = chebyshev_basis(N)
18+
spec = Polynomials4ML.natural_indices(basis)
19+
20+
##
21+
22+
using StaticArrays
23+
import Interpolations as INT
24+
25+
Nnod = 30
26+
xx = range(-1.0, 1.0; length=Nnod)
27+
PP = [ SVector{length(basis)}(basis(x)) for x in xx ]
28+
29+
spl = INT.cubic_spline_interpolation(xx, PP)
30+
31+
st_spl1 = P4ML.splinify(basis, -1.0, 1.0, Nnod; bspline=true)
32+
st_spl2 = P4ML.splinify(basis, -1.0, 1.0, Nnod; bspline=false)
33+
34+
##
35+
36+
Random.seed!(1234)
37+
38+
# At the interpolation nodes (including both endpoints) the original basis,
# the Interpolations.jl spline and the splinified basis must all agree.
for x in xx
    P1 = basis(x)
    P2 = spl(x)
    P3 = st_spl1(x)
    print_tf(@test P1 ≈ P2 ≈ P3)
end
44+
println()
45+
46+
47+
# At random points between the nodes the two spline representations must
# agree, and the spline must approximate the original basis up to a scaled
# error tolerance.
for x in xx[1:end-1]
    y = x + rand() * 2/(Nnod-1)
    P1 = basis(y)
    P2 = spl(y)
    P3 = st_spl1(y)
    print_tf(@test P2 ≈ P3)
    scalerr = norm( (P1 - P2) * (1-y^2)^2 ./ (1:N).^3 , Inf)
    print_tf(@test scalerr < 1e-5)
end
56+
println()
57+
58+
59+
##
60+
61+
P4ML.Testing.test_all(st_spl1; ka = true)
62+
63+
64+
##
65+
66+
# import ForwardDiff as FD
67+
68+
# st_spl1 = SPL.splinify(basis, -1.0, 1.0, 30)
69+
# st_spl2 = SPL.splinify(spl, -1.0, 0.99999, 30)
70+
71+
# x = 1 - rand()/(Nnod-1)
72+
# P1 = basis(x)
73+
# P2 = spl(x)
74+
# P3 = st_spl1(x)
75+
# P4 = st_spl2(x)
76+
77+
# @show norm(P1 - P2, Inf)
78+
# @show norm(P1 - P3, Inf)
79+
# @show norm(P2 - P3, Inf)
80+
# @show norm(P2 - P4, Inf)
81+
# @show norm(P3 - P4, Inf)
82+
83+
84+
##
85+
86+
# TODO: add benchmarks
87+
88+
# using BenchmarkTools
89+
90+
# function _benchrun(basis, Ntest = 1000)
91+
# s = 0.0
92+
# for ntest = 1:Ntest
93+
# x = rand() * 2 - 1
94+
# P1 = basis(x)
95+
# s += P1[2]
96+
# end
97+
# return s
98+
# end
99+
100+
# function _benchrun_psst(basis, ps, st, Ntest = 1000)
101+
# s = 0.0
102+
# for ntest = 1:Ntest
103+
# x = rand() * 2 - 1
104+
# P1, st = basis(x, ps, st)
105+
# s += P1[2]
106+
# end
107+
# return s
108+
# end
109+
110+
111+
# @btime _benchrun($basis)
112+
# @btime _benchrun($spl)
113+
# @btime _benchrun($st_spl1)
114+
115+
# ps, st = LuxCore.setup(MersenneTwister(1234), st_spl1)
116+
# @btime _benchrun_psst($st_spl1, $ps, $st)
117+
118+
# st_spl4 = P4ML.splinify(basis, -1.0, 1.0, Nnod; bspline=true)
119+
# ps4, st4 = LuxCore.setup(MersenneTwister(1234), st_spl4)
120+
# @btime _benchrun_psst($st_spl4, $ps4, $st4)
121+

0 commit comments

Comments
 (0)