Skip to content

Commit e468b64

Browse files
levinedsmshuaibiiwood-b
authored
Quaternion-based Wigner D method for fixing y-aligned edges (#1771)
Implement quaternion-based rotations to allow for smooth gradients and Hessians for edge that are y-aligned (or nearly so). The SO2 convolutions method requires first rotating edges to align with the +Y axis. In the previous Euler angles implementation of that code, edges that already lay on the Y-axis or very near it suffered from a gimbal lock singularity at that point, causing gradients (and higher order derivatives) to become numerically unstable. This manifest in a number of ways including breaking equivariance in specific systems and introducing spurious forces. We adopt an approach based on quaternions which does not suffer from the gimbal lock. Due to the Hairy Ball Theorem, we must have a pole or discontinuity somewhere on the unit sphere. We create to charts: one which has a singularity at the -Y axis and the other at the +Y axis on the unit sphere. We use exclusively one chart or the other (which ever doesn't have a pole) in the vicinity of the Y axes (ey = +/- 1). For ey between -0.9 and +0.9, we then linearly interpolate the quaternion that each chart gives which rotates the edge to the Y axis. This linear interpolation also always rotates edges to the +Y-axis and shifts the singularities at the poles to a discontinuity in the Wigner D matrix in the roll angle about Y-axis. However, since we both already randomize the roll angle and the convolution is invariant to roll, this has no effect on the model. The discontinuity is numerical well-behaved so it introduces no numerical issues; autograd in the vicinity of the border region also follows only one path and therefore does not cause an issue. Conventions for the Wigner D are chosen to exactly match the previous Euler angle convention so that these can be used without retraining on previous checkpoints. --------- Co-authored-by: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com> Co-authored-by: wood-b <bmwood@meta.com>
1 parent 08f7b0b commit e468b64

10 files changed

Lines changed: 3106 additions & 39 deletions

File tree

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
"""
2+
Copyright (c) Meta Platforms, Inc. and affiliates.
3+
4+
This source code is licensed under the MIT license found in the
5+
LICENSE file in the root directory of this source tree.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import torch
11+
12+
# =============================================================================
13+
# Constants
14+
# =============================================================================
15+
16+
# Blend region parameters for two-chart quaternion computation
17+
# The blend region is ey in [BLEND_START, BLEND_START + BLEND_WIDTH]
18+
# which corresponds to ey in [-0.9, 0.9]
19+
BLEND_START = -0.9
20+
BLEND_WIDTH = 1.8
21+
22+
23+
# =============================================================================
24+
# Core Helper Functions
25+
# =============================================================================
26+
27+
28+
def _smooth_step_cinf(t: torch.Tensor) -> torch.Tensor:
29+
"""
30+
C-infinity smooth step function based on the classic bump function.
31+
32+
Uses f(x) = exp(-1/x) for x > 0 (0 otherwise), then:
33+
step(t) = f(t) / (f(t) + f(1-t)) = sigmoid((2t-1)/(t*(1-t)))
34+
35+
Properties:
36+
- C-infinity smooth everywhere
37+
- All derivatives are exactly zero at t=0 and t=1
38+
- Values: f(0)=0, f(1)=1
39+
- Symmetric: f(t) + f(1-t) = 1
40+
41+
Args:
42+
t: Input tensor, will be clamped to [0, 1]
43+
44+
Returns:
45+
Smooth step values in [0, 1]
46+
"""
47+
t_clamped = t.clamp(0, 1)
48+
eps = torch.finfo(t.dtype).eps
49+
50+
numerator = 2.0 * t_clamped - 1.0
51+
denominator = t_clamped * (1.0 - t_clamped)
52+
denom_safe = denominator.clamp(min=eps)
53+
arg = numerator / denom_safe
54+
result = torch.sigmoid(arg)
55+
56+
result = torch.where(t_clamped < eps, torch.zeros_like(result), result)
57+
result = torch.where(t_clamped > 1 - eps, torch.ones_like(result), result)
58+
59+
return result
60+
61+
62+
# =============================================================================
63+
# Quaternion Operations
64+
# =============================================================================
65+
66+
67+
def quaternion_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
68+
"""
69+
Multiply two quaternions: q1 * q2.
70+
71+
Uses Hamilton product convention: (w, x, y, z).
72+
73+
Args:
74+
q1: First quaternion of shape (N, 4) or (4,)
75+
q2: Second quaternion of shape (N, 4) or (4,)
76+
77+
Returns:
78+
Product quaternion of shape (N, 4)
79+
"""
80+
w1, x1, y1, z1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3]
81+
w2, x2, y2, z2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3]
82+
83+
w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
84+
x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
85+
y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
86+
z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
87+
88+
return torch.stack([w, x, y, z], dim=-1)
89+
90+
91+
def quaternion_y_rotation(gamma: torch.Tensor) -> torch.Tensor:
92+
"""
93+
Create quaternion for rotation about Y-axis by angle gamma.
94+
95+
Args:
96+
gamma: Rotation angles of shape (N,)
97+
98+
Returns:
99+
Quaternions of shape (N, 4) in (w, x, y, z) convention
100+
"""
101+
half_gamma = gamma / 2
102+
w = torch.cos(half_gamma)
103+
x = torch.zeros_like(gamma)
104+
y = torch.sin(half_gamma)
105+
z = torch.zeros_like(gamma)
106+
return torch.stack([w, x, y, z], dim=-1)
107+
108+
109+
def quaternion_nlerp(
110+
q1: torch.Tensor,
111+
q2: torch.Tensor,
112+
t: torch.Tensor,
113+
) -> torch.Tensor:
114+
"""
115+
Normalized linear interpolation between quaternions.
116+
117+
nlerp(q1, q2, t) = normalize((1-t) * q1 + t * q2)
118+
119+
Args:
120+
q1: First quaternion, shape (..., 4)
121+
q2: Second quaternion, shape (..., 4)
122+
t: Interpolation parameter, shape (...)
123+
124+
Returns:
125+
Interpolated quaternion, shape (..., 4)
126+
"""
127+
dot = (q1 * q2).sum(dim=-1, keepdim=True)
128+
q1_aligned = torch.where(dot < 0, -q1, q1)
129+
130+
t_expanded = t.unsqueeze(-1) if t.dim() < q1.dim() else t
131+
result = torch.nn.functional.normalize(
132+
(1.0 - t_expanded) * q1_aligned + t_expanded * q2, dim=-1
133+
)
134+
135+
return result
136+
137+
138+
# =============================================================================
139+
# Two-Chart Quaternion Edge -> +Y
140+
# =============================================================================
141+
142+
143+
def _quaternion_chart1_standard(
144+
ex: torch.Tensor,
145+
ey: torch.Tensor,
146+
ez: torch.Tensor,
147+
) -> torch.Tensor:
148+
"""
149+
Standard quaternion: edge -> +Y directly. Singular at edge = -Y.
150+
151+
Uses the half-vector formula:
152+
q = normalize(1 + ey, -ez, 0, ex)
153+
154+
Args:
155+
ex, ey, ez: Edge vector components
156+
157+
Returns:
158+
Quaternions of shape (..., 4) in (w, x, y, z) convention
159+
"""
160+
w = 1.0 + ey
161+
x = -ez
162+
y = torch.zeros_like(ex)
163+
z = ex
164+
165+
q = torch.stack([w, x, y, z], dim=-1)
166+
q_sq = torch.sum(q**2, dim=-1, keepdim=True)
167+
eps = torch.finfo(ex.dtype).eps
168+
# q_sq -> 0 at this chart's singularity (ey = -1), but this chart is
169+
# unused there so we don't see the divide by zero. The clamp detaches
170+
# the gradients so that NaNs don't flow through the backward pass.
171+
norm = torch.sqrt(torch.clamp(q_sq, min=eps))
172+
173+
return q / norm
174+
175+
176+
def _quaternion_chart2_via_minus_y(
177+
ex: torch.Tensor,
178+
ey: torch.Tensor,
179+
ez: torch.Tensor,
180+
) -> torch.Tensor:
181+
"""
182+
Alternative quaternion: edge -> +Y via -Y. Singular at edge = +Y.
183+
184+
Path: edge -> -Y -> +Y (compose with 180 deg about X)
185+
186+
Args:
187+
ex, ey, ez: Edge vector components
188+
189+
Returns:
190+
Quaternions of shape (..., 4) in (w, x, y, z) convention
191+
"""
192+
w = -ez
193+
x = 1.0 - ey
194+
y = ex
195+
z = torch.zeros_like(ex)
196+
197+
q = torch.stack([w, x, y, z], dim=-1)
198+
q_sq = torch.sum(q**2, dim=-1, keepdim=True)
199+
eps = torch.finfo(ex.dtype).eps
200+
# q_sq -> 0 at this chart's singularity (ey = +1), but this chart is
201+
# unused there so we don't see the divide by zero. The clamp detaches
202+
# the gradients so that NaNs don't flow through the backward pass.
203+
norm = torch.sqrt(torch.clamp(q_sq, min=eps))
204+
205+
return q / norm
206+
207+
208+
def quaternion_edge_to_y_stable(edge_vec: torch.Tensor) -> torch.Tensor:
209+
"""
210+
Compute quaternion for edge -> +Y using two charts with NLERP blending.
211+
212+
Uses two quaternion charts to avoid singularities:
213+
- Chart 1: q = normalize(1+ey, -ez, 0, ex) - singular at -Y
214+
- Chart 2: q = normalize(-ez, 1-ey, ex, 0) - singular at +Y
215+
216+
NLERP blend in ey in [-0.9, 0.9]:
217+
- Uses Chart 2 when near -Y (stable there)
218+
- Uses Chart 1 when near +Y (stable there)
219+
- Smoothly interpolates in between
220+
221+
Args:
222+
edge_vec: Edge vectors of shape (N, 3), assumed normalized
223+
224+
Returns:
225+
Quaternions of shape (N, 4) in (w, x, y, z) convention
226+
"""
227+
ex = edge_vec[..., 0]
228+
ey = edge_vec[..., 1]
229+
ez = edge_vec[..., 2]
230+
231+
q_chart1 = _quaternion_chart1_standard(ex, ey, ez)
232+
q_chart2 = _quaternion_chart2_via_minus_y(ex, ey, ez)
233+
234+
t = (ey - BLEND_START) / BLEND_WIDTH
235+
t_smooth = _smooth_step_cinf(t)
236+
237+
q = quaternion_nlerp(q_chart2, q_chart1, t_smooth)
238+
239+
return q
240+
241+
242+
# =============================================================================
243+
# Gamma Computation for Euler Matching
244+
# =============================================================================

0 commit comments

Comments
 (0)