1+
2+ import Interpolations
3+ import ForwardDiff
4+
5+ # TODO :
6+ # - consider allowing a coordinate transformation before
7+ # applying the splines and a post-multiplication with
8+ # an envelope. That's a bit specific for a general purpose
9+ # libarary but might be very useful for performance and
10+ # and simplicity in ACE applications.
11+ # But it might be better to implement these things as wrappers
12+ # around an arbitrary basis.
13+ #
14+ # - the splines should inherit the specs from the original basis
15+ # they interpolated.
16+
17+
18+ """
19+ `struct CubicSplines`:
20+
21+ Statically typed cubic splines, compatible with P4ML type batched evaluation.
22+ For any P4ML basis with univariate input.
23+ """
24+ struct CubicSplines{NX, NU, T} <: AbstractP4MLBasis
25+ F:: SVector{NX, SVector{NU, T}} # function values at nodes
26+ G:: SVector{NX, SVector{NU, T}} # gradient values at nodes
27+ x0:: T # left endpoint
28+ x1:: T # right endpoint
29+ end
30+
31+
32+ function Base. show (io:: IO , l:: CubicSplines{NX, NU, T} ) where {NX, NU, T}
33+ print (io, " CubicSplines(nx = $NX , len = $NU )" )
34+ end
35+
36+ Base. length (basis:: CubicSplines{NX, NU} ) where {NX, NU} = NU
37+
38+ __NX (basis:: CubicSplines{NX} ) where {NX} = NX
39+
40+ # TODO : this is wrong, instead should inherit from the original basis
41+ natural_indices (basis:: CubicSplines ) = [ (n = n,) for n = 1 : length (basis) ]
42+
43+ _valtype (basis:: CubicSplines{NX, NU, T1} , T2:: Type{<:Real}
44+ ) where {NX, NU, T1} = promote_type (T1, T2)
45+
46+ _valtype (basis:: CubicSplines{NX, NU, T1} , T2:: Type{<:Real} ,
47+ ps, st) where {NX, NU, T1} = promote_type (T2, eltype (eltype (st. F[1 ])))
48+
49+ _generate_input (basis:: CubicSplines ) =
50+ rand () * (basis. x1 - basis. x0) + basis. x0
51+
52+ _init_luxstate (l:: CubicSplines ) =
53+ (F = l. F, G = l. G, x0 = l. x0, x1 = l. x1)
54+
55+
56+ # ----------------- constructor of spline basis
57+
58+ """
59+ splinify(basis, x0, x1, NX; bspline=true)
60+
61+ Takes a P4ML basis with univariate input and constructs a cubic spline basis
62+ that interpolates the basis functions on a uniform grid with `NX` nodes.
63+ If `bspline=true` (default) the function values are first interpolated onto
64+ a B-spline representation to obtain C2,2 regularity of the splines.
65+
66+ `x0`, `x1` are the left and right endpoints of the spline interval.
67+
68+ This is currently not exported and not part of the public interface. The
69+ interface can change in future releases.
70+ """
71+ function splinify (f, x0, x1, NX; bspline= true )
72+ NU = length (f (x0))
73+ xx = range (x0, x1; length= NX)
74+ F = [ SVector {NU} (f (x)) for x in xx ]
75+
76+ if bspline
77+ # CubicSplines represents the splines in terms of piecewise cubics
78+ # specified through values and gradients at the nodes this only gives C1,1
79+ # regularity. By first interpolating onto a B-spline, and then the B-spline
80+ # onto the CubicSplines representation we get C2,2 regularity.
81+ itp = Interpolations. cubic_spline_interpolation (xx, F;
82+ extrapolation_bc= Interpolations. Flat ())
83+ G = [ Interpolations. gradient (itp, x)[1 ] for x in xx ]
84+ else
85+ G = [ FD. derivative (f, x) for x in xx ]
86+ end
87+ T = eltype (eltype (F))
88+ stF = SVector {NX} (F)
89+ stG = SVector {NX} ((SVector{NU, T}). (G))
90+ return CubicSplines (stF, stG, x0, x1)
91+ end
92+
93+
94+
95+ # ----------------- shared evaluation code
96+
97+ """
98+ _eval_cubic(t, fl, fr, gl, gr, h)
99+
100+ Evaluate cubic spline at position `t` in `[0,1]`, given function values `fl`, `fr`
101+ and gradients `gl`, `gr` at the left and right endpoints.
102+ """
103+ @inline function _eval_cubic (t, fl, fr, gl, gr)
104+ # (2t³ - 3t² + 1)*fl + (t³ - 2t² + t)*gl +
105+ # (-2t³ + 3t²)*fr + (t³ - t²)*gr
106+ a0 = fl
107+ a1 = gl
108+ a2 = - 3 fl + 3 fr - 2 gl - gr
109+ a3 = 2 fl - 2 fr + gl + gr
110+ return ((a3* t + a2)* t + a1)* t + a0
111+ end
112+
113+ """
114+ _eval_cubspl(x, F, G, x0, x1, NX)
115+
116+ auxiliary function to the evaluate the cubic spline basis given
117+ the spline data arrays
118+ """
119+ @inline function _eval_cubspl (x, F, G, x0, x1, NX)
120+ x = clamp (x, x0, x1) # project to [x0, x1] (corresponds to Flat bc)
121+ h = (x1 - x0) / (NX- 1 ) # uniform grid spacing
122+ il = floor (Int, (x - x0) / h) # index of left node
123+ # TODO : is this numerically stable?
124+ t = (x - x0) / h - il # relative coordinate of x in [il, il+1]
125+ @inbounds _eval_cubic (t, F[il+ 1 ], F[il+ 2 ], h* G[il+ 1 ], h* G[il+ 2 ])
126+ end
127+
128+ @inline function _cubspl_widthgrad (x, F, G, x0, x1, NX)
129+ if x < x0 || x > x1
130+ f = _eval_cubspl (x, F, G, x0, x1, NX)
131+ return f, zero (f)
132+ end
133+ h = (x1 - x0) / (NX- 1 ) # uniform grid spacing
134+ t, _il = modf ((x - x0) / h)
135+ il = Int (_il)
136+ td = Dual (t, one (t))
137+ fd = _eval_cubic (td, F[il+ 1 ], F[il+ 2 ], h* G[il+ 1 ], h* G[il+ 2 ])
138+ f = ForwardDiff. value .(fd)
139+ g = ForwardDiff. partials .(fd, 1 )
140+ return f, g / h
141+ end
142+
143+
144+ # ----------------- CPU evaluation code
145+
146+ _evaluate! (P, dP, basis:: CubicSplines , X) =
147+ _evaluate! (P, dP, basis, X, nothing , _init_luxstate (basis))
148+
149+
150+ function _evaluate! (P:: AbstractMatrix , dP:: Nothing , basis:: CubicSplines , X:: BATCH , ps, st)
151+ @assert size (P, 1 ) >= length (X)
152+ @inbounds for (i, x) in enumerate (X)
153+ P[i, :] = _eval_cubspl (x, st. F, st. G, st. x0, st. x1, __NX (basis))
154+ end
155+ return nothing
156+ end
157+
158+
159+ function _evaluate! (P:: AbstractMatrix , dP:: AbstractMatrix ,
160+ basis:: CubicSplines , X:: BATCH , ps, st)
161+ @assert size (P, 1 ) >= length (X)
162+ @assert size (dP, 1 ) >= length (X)
163+ @inbounds for (i, x) in enumerate (X)
164+ f, g = _cubspl_widthgrad (x, st. F, st. G, st. x0, st. x1, __NX (basis))
165+ P[i, :] = f
166+ dP[i, :] = g
167+ end
168+ return nothing
169+ end
170+
171+
172+ # ----------------- KernelAbstractions evaluation code
173+
174+
175+ @kernel function _ka_evaluate! (P, dP, basis:: CubicSplines , x:: AbstractVector{T}
176+ ) where {T}
177+
178+ i = @index (Global)
179+
180+ if isnothing (dP)
181+ P[i, :] = _eval_cubspl (x[i], basis. F, basis. G, basis. x0, basis. x1, __NX (basis))
182+ else
183+ f, g = _cubspl_widthgrad (x[i], basis. F, basis. G, basis. x0, basis. x1, __NX (basis))
184+ P[i, :] = f
185+ dP[i, :] = g
186+ end
187+
188+ nothing
189+ end
0 commit comments