Skip to content

Commit e7e780b

Browse files
committed
lint
1 parent 682de7b commit e7e780b

2 files changed

Lines changed: 215 additions & 53 deletions

File tree

src/fairchem/core/models/uma/nn/execution_backends.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414

1515
import torch
1616

17+
from fairchem.core.models.uma.nn.unified_radial import UnifiedRadialMLP
18+
19+
if TYPE_CHECKING:
20+
from fairchem.core.units.mlip_unit.api.inference import (
21+
InferenceSettings,
22+
)
23+
1724
# Enable expandable segments for the CUDA caching allocator to reduce
1825
# memory fragmentation and eliminate periodic GC stalls during inference.
1926
# Must be set before the first CUDA allocation.
@@ -25,13 +32,6 @@
2532
# Enable aggressive fusion of inductor ops
2633
torch._inductor.config.aggressive_fusion = True
2734

28-
from fairchem.core.models.uma.nn.unified_radial import UnifiedRadialMLP
29-
30-
if TYPE_CHECKING:
31-
from fairchem.core.units.mlip_unit.api.inference import (
32-
InferenceSettings,
33-
)
34-
3535
__all__ = [
3636
"ExecutionMode",
3737
"ExecutionBackend",

src/fairchem/core/models/uma/triton/kernels.py

Lines changed: 208 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -924,47 +924,191 @@ def node_to_edge_wigner_permute_bwd_dw_kernel(
924924
# M_TO_L_GATHER_IDX = [0, 5, 1, 3, 8, 6, 2, 4, 7]
925925
# Each dy has 2C channels: first C = src, last C = tgt
926926
# Load src part (first C channels)
927-
dy0s = tl.load(grad_out_ptr + grad_base + 0 * sphere_channels * 2 + c_range, mask=c_mask, other=0.0)
928-
dy1s = tl.load(grad_out_ptr + grad_base + 5 * sphere_channels * 2 + c_range, mask=c_mask, other=0.0)
929-
dy2s = tl.load(grad_out_ptr + grad_base + 1 * sphere_channels * 2 + c_range, mask=c_mask, other=0.0)
930-
dy3s = tl.load(grad_out_ptr + grad_base + 3 * sphere_channels * 2 + c_range, mask=c_mask, other=0.0)
931-
dy4s = tl.load(grad_out_ptr + grad_base + 8 * sphere_channels * 2 + c_range, mask=c_mask, other=0.0)
932-
dy5s = tl.load(grad_out_ptr + grad_base + 6 * sphere_channels * 2 + c_range, mask=c_mask, other=0.0)
933-
dy6s = tl.load(grad_out_ptr + grad_base + 2 * sphere_channels * 2 + c_range, mask=c_mask, other=0.0)
934-
dy7s = tl.load(grad_out_ptr + grad_base + 4 * sphere_channels * 2 + c_range, mask=c_mask, other=0.0)
935-
dy8s = tl.load(grad_out_ptr + grad_base + 7 * sphere_channels * 2 + c_range, mask=c_mask, other=0.0)
927+
dy0s = tl.load(
928+
grad_out_ptr + grad_base + 0 * sphere_channels * 2 + c_range,
929+
mask=c_mask,
930+
other=0.0,
931+
)
932+
dy1s = tl.load(
933+
grad_out_ptr + grad_base + 5 * sphere_channels * 2 + c_range,
934+
mask=c_mask,
935+
other=0.0,
936+
)
937+
dy2s = tl.load(
938+
grad_out_ptr + grad_base + 1 * sphere_channels * 2 + c_range,
939+
mask=c_mask,
940+
other=0.0,
941+
)
942+
dy3s = tl.load(
943+
grad_out_ptr + grad_base + 3 * sphere_channels * 2 + c_range,
944+
mask=c_mask,
945+
other=0.0,
946+
)
947+
dy4s = tl.load(
948+
grad_out_ptr + grad_base + 8 * sphere_channels * 2 + c_range,
949+
mask=c_mask,
950+
other=0.0,
951+
)
952+
dy5s = tl.load(
953+
grad_out_ptr + grad_base + 6 * sphere_channels * 2 + c_range,
954+
mask=c_mask,
955+
other=0.0,
956+
)
957+
dy6s = tl.load(
958+
grad_out_ptr + grad_base + 2 * sphere_channels * 2 + c_range,
959+
mask=c_mask,
960+
other=0.0,
961+
)
962+
dy7s = tl.load(
963+
grad_out_ptr + grad_base + 4 * sphere_channels * 2 + c_range,
964+
mask=c_mask,
965+
other=0.0,
966+
)
967+
dy8s = tl.load(
968+
grad_out_ptr + grad_base + 7 * sphere_channels * 2 + c_range,
969+
mask=c_mask,
970+
other=0.0,
971+
)
936972

937973
# Load tgt part (second C channels, offset by sphere_channels)
938-
dy0t = tl.load(grad_out_ptr + grad_base + 0 * sphere_channels * 2 + sphere_channels + c_range, mask=c_mask, other=0.0)
939-
dy1t = tl.load(grad_out_ptr + grad_base + 5 * sphere_channels * 2 + sphere_channels + c_range, mask=c_mask, other=0.0)
940-
dy2t = tl.load(grad_out_ptr + grad_base + 1 * sphere_channels * 2 + sphere_channels + c_range, mask=c_mask, other=0.0)
941-
dy3t = tl.load(grad_out_ptr + grad_base + 3 * sphere_channels * 2 + sphere_channels + c_range, mask=c_mask, other=0.0)
942-
dy4t = tl.load(grad_out_ptr + grad_base + 8 * sphere_channels * 2 + sphere_channels + c_range, mask=c_mask, other=0.0)
943-
dy5t = tl.load(grad_out_ptr + grad_base + 6 * sphere_channels * 2 + sphere_channels + c_range, mask=c_mask, other=0.0)
944-
dy6t = tl.load(grad_out_ptr + grad_base + 2 * sphere_channels * 2 + sphere_channels + c_range, mask=c_mask, other=0.0)
945-
dy7t = tl.load(grad_out_ptr + grad_base + 4 * sphere_channels * 2 + sphere_channels + c_range, mask=c_mask, other=0.0)
946-
dy8t = tl.load(grad_out_ptr + grad_base + 7 * sphere_channels * 2 + sphere_channels + c_range, mask=c_mask, other=0.0)
974+
dy0t = tl.load(
975+
grad_out_ptr
976+
+ grad_base
977+
+ 0 * sphere_channels * 2
978+
+ sphere_channels
979+
+ c_range,
980+
mask=c_mask,
981+
other=0.0,
982+
)
983+
dy1t = tl.load(
984+
grad_out_ptr
985+
+ grad_base
986+
+ 5 * sphere_channels * 2
987+
+ sphere_channels
988+
+ c_range,
989+
mask=c_mask,
990+
other=0.0,
991+
)
992+
dy2t = tl.load(
993+
grad_out_ptr
994+
+ grad_base
995+
+ 1 * sphere_channels * 2
996+
+ sphere_channels
997+
+ c_range,
998+
mask=c_mask,
999+
other=0.0,
1000+
)
1001+
dy3t = tl.load(
1002+
grad_out_ptr
1003+
+ grad_base
1004+
+ 3 * sphere_channels * 2
1005+
+ sphere_channels
1006+
+ c_range,
1007+
mask=c_mask,
1008+
other=0.0,
1009+
)
1010+
dy4t = tl.load(
1011+
grad_out_ptr
1012+
+ grad_base
1013+
+ 8 * sphere_channels * 2
1014+
+ sphere_channels
1015+
+ c_range,
1016+
mask=c_mask,
1017+
other=0.0,
1018+
)
1019+
dy5t = tl.load(
1020+
grad_out_ptr
1021+
+ grad_base
1022+
+ 6 * sphere_channels * 2
1023+
+ sphere_channels
1024+
+ c_range,
1025+
mask=c_mask,
1026+
other=0.0,
1027+
)
1028+
dy6t = tl.load(
1029+
grad_out_ptr
1030+
+ grad_base
1031+
+ 2 * sphere_channels * 2
1032+
+ sphere_channels
1033+
+ c_range,
1034+
mask=c_mask,
1035+
other=0.0,
1036+
)
1037+
dy7t = tl.load(
1038+
grad_out_ptr
1039+
+ grad_base
1040+
+ 4 * sphere_channels * 2
1041+
+ sphere_channels
1042+
+ c_range,
1043+
mask=c_mask,
1044+
other=0.0,
1045+
)
1046+
dy8t = tl.load(
1047+
grad_out_ptr
1048+
+ grad_base
1049+
+ 7 * sphere_channels * 2
1050+
+ sphere_channels
1051+
+ c_range,
1052+
mask=c_mask,
1053+
other=0.0,
1054+
)
9471055

9481056
# Load node features (L-major order)
949-
xs0 = tl.load(x_ptr + idx0 * x_stride_n + 0 * x_stride_m + c_range, mask=c_mask, other=0.0)
950-
xs1 = tl.load(x_ptr + idx0 * x_stride_n + 1 * x_stride_m + c_range, mask=c_mask, other=0.0)
951-
xs2 = tl.load(x_ptr + idx0 * x_stride_n + 2 * x_stride_m + c_range, mask=c_mask, other=0.0)
952-
xs3 = tl.load(x_ptr + idx0 * x_stride_n + 3 * x_stride_m + c_range, mask=c_mask, other=0.0)
953-
xs4 = tl.load(x_ptr + idx0 * x_stride_n + 4 * x_stride_m + c_range, mask=c_mask, other=0.0)
954-
xs5 = tl.load(x_ptr + idx0 * x_stride_n + 5 * x_stride_m + c_range, mask=c_mask, other=0.0)
955-
xs6 = tl.load(x_ptr + idx0 * x_stride_n + 6 * x_stride_m + c_range, mask=c_mask, other=0.0)
956-
xs7 = tl.load(x_ptr + idx0 * x_stride_n + 7 * x_stride_m + c_range, mask=c_mask, other=0.0)
957-
xs8 = tl.load(x_ptr + idx0 * x_stride_n + 8 * x_stride_m + c_range, mask=c_mask, other=0.0)
958-
959-
xt0 = tl.load(x_ptr + idx1 * x_stride_n + 0 * x_stride_m + c_range, mask=c_mask, other=0.0)
960-
xt1 = tl.load(x_ptr + idx1 * x_stride_n + 1 * x_stride_m + c_range, mask=c_mask, other=0.0)
961-
xt2 = tl.load(x_ptr + idx1 * x_stride_n + 2 * x_stride_m + c_range, mask=c_mask, other=0.0)
962-
xt3 = tl.load(x_ptr + idx1 * x_stride_n + 3 * x_stride_m + c_range, mask=c_mask, other=0.0)
963-
xt4 = tl.load(x_ptr + idx1 * x_stride_n + 4 * x_stride_m + c_range, mask=c_mask, other=0.0)
964-
xt5 = tl.load(x_ptr + idx1 * x_stride_n + 5 * x_stride_m + c_range, mask=c_mask, other=0.0)
965-
xt6 = tl.load(x_ptr + idx1 * x_stride_n + 6 * x_stride_m + c_range, mask=c_mask, other=0.0)
966-
xt7 = tl.load(x_ptr + idx1 * x_stride_n + 7 * x_stride_m + c_range, mask=c_mask, other=0.0)
967-
xt8 = tl.load(x_ptr + idx1 * x_stride_n + 8 * x_stride_m + c_range, mask=c_mask, other=0.0)
1057+
xs0 = tl.load(
1058+
x_ptr + idx0 * x_stride_n + 0 * x_stride_m + c_range, mask=c_mask, other=0.0
1059+
)
1060+
xs1 = tl.load(
1061+
x_ptr + idx0 * x_stride_n + 1 * x_stride_m + c_range, mask=c_mask, other=0.0
1062+
)
1063+
xs2 = tl.load(
1064+
x_ptr + idx0 * x_stride_n + 2 * x_stride_m + c_range, mask=c_mask, other=0.0
1065+
)
1066+
xs3 = tl.load(
1067+
x_ptr + idx0 * x_stride_n + 3 * x_stride_m + c_range, mask=c_mask, other=0.0
1068+
)
1069+
xs4 = tl.load(
1070+
x_ptr + idx0 * x_stride_n + 4 * x_stride_m + c_range, mask=c_mask, other=0.0
1071+
)
1072+
xs5 = tl.load(
1073+
x_ptr + idx0 * x_stride_n + 5 * x_stride_m + c_range, mask=c_mask, other=0.0
1074+
)
1075+
xs6 = tl.load(
1076+
x_ptr + idx0 * x_stride_n + 6 * x_stride_m + c_range, mask=c_mask, other=0.0
1077+
)
1078+
xs7 = tl.load(
1079+
x_ptr + idx0 * x_stride_n + 7 * x_stride_m + c_range, mask=c_mask, other=0.0
1080+
)
1081+
xs8 = tl.load(
1082+
x_ptr + idx0 * x_stride_n + 8 * x_stride_m + c_range, mask=c_mask, other=0.0
1083+
)
1084+
1085+
xt0 = tl.load(
1086+
x_ptr + idx1 * x_stride_n + 0 * x_stride_m + c_range, mask=c_mask, other=0.0
1087+
)
1088+
xt1 = tl.load(
1089+
x_ptr + idx1 * x_stride_n + 1 * x_stride_m + c_range, mask=c_mask, other=0.0
1090+
)
1091+
xt2 = tl.load(
1092+
x_ptr + idx1 * x_stride_n + 2 * x_stride_m + c_range, mask=c_mask, other=0.0
1093+
)
1094+
xt3 = tl.load(
1095+
x_ptr + idx1 * x_stride_n + 3 * x_stride_m + c_range, mask=c_mask, other=0.0
1096+
)
1097+
xt4 = tl.load(
1098+
x_ptr + idx1 * x_stride_n + 4 * x_stride_m + c_range, mask=c_mask, other=0.0
1099+
)
1100+
xt5 = tl.load(
1101+
x_ptr + idx1 * x_stride_n + 5 * x_stride_m + c_range, mask=c_mask, other=0.0
1102+
)
1103+
xt6 = tl.load(
1104+
x_ptr + idx1 * x_stride_n + 6 * x_stride_m + c_range, mask=c_mask, other=0.0
1105+
)
1106+
xt7 = tl.load(
1107+
x_ptr + idx1 * x_stride_n + 7 * x_stride_m + c_range, mask=c_mask, other=0.0
1108+
)
1109+
xt8 = tl.load(
1110+
x_ptr + idx1 * x_stride_n + 8 * x_stride_m + c_range, mask=c_mask, other=0.0
1111+
)
9681112

9691113
# dW[i,j] = sum_c (dy_src[i]*x_src[j] + dy_tgt[i]*x_tgt[j])
9701114
# L=0 block (1x1)
@@ -1355,15 +1499,33 @@ def permute_wigner_inv_edge_to_node_bwd_dw_kernel(
13551499
dy8 = tl.load(DY_ptr + dy_base + 8 * C + c_range, mask=c_mask, other=0.0)
13561500

13571501
# Load x from M-major positions using M_TO_L_GATHER_IDX = [0,5,1,3,8,6,2,4,7]
1358-
x0 = tl.load(X_ptr + x_base + 0 * C + c_range, mask=c_mask, other=0.0) # L=0 <- M=0
1359-
x1 = tl.load(X_ptr + x_base + 5 * C + c_range, mask=c_mask, other=0.0) # L=1 <- M=5
1360-
x2 = tl.load(X_ptr + x_base + 1 * C + c_range, mask=c_mask, other=0.0) # L=2 <- M=1
1361-
x3 = tl.load(X_ptr + x_base + 3 * C + c_range, mask=c_mask, other=0.0) # L=3 <- M=3
1362-
x4 = tl.load(X_ptr + x_base + 8 * C + c_range, mask=c_mask, other=0.0) # L=4 <- M=8
1363-
x5 = tl.load(X_ptr + x_base + 6 * C + c_range, mask=c_mask, other=0.0) # L=5 <- M=6
1364-
x6 = tl.load(X_ptr + x_base + 2 * C + c_range, mask=c_mask, other=0.0) # L=6 <- M=2
1365-
x7 = tl.load(X_ptr + x_base + 4 * C + c_range, mask=c_mask, other=0.0) # L=7 <- M=4
1366-
x8 = tl.load(X_ptr + x_base + 7 * C + c_range, mask=c_mask, other=0.0) # L=8 <- M=7
1502+
x0 = tl.load(
1503+
X_ptr + x_base + 0 * C + c_range, mask=c_mask, other=0.0
1504+
) # L=0 <- M=0
1505+
x1 = tl.load(
1506+
X_ptr + x_base + 5 * C + c_range, mask=c_mask, other=0.0
1507+
) # L=1 <- M=5
1508+
x2 = tl.load(
1509+
X_ptr + x_base + 1 * C + c_range, mask=c_mask, other=0.0
1510+
) # L=2 <- M=1
1511+
x3 = tl.load(
1512+
X_ptr + x_base + 3 * C + c_range, mask=c_mask, other=0.0
1513+
) # L=3 <- M=3
1514+
x4 = tl.load(
1515+
X_ptr + x_base + 8 * C + c_range, mask=c_mask, other=0.0
1516+
) # L=4 <- M=8
1517+
x5 = tl.load(
1518+
X_ptr + x_base + 6 * C + c_range, mask=c_mask, other=0.0
1519+
) # L=5 <- M=6
1520+
x6 = tl.load(
1521+
X_ptr + x_base + 2 * C + c_range, mask=c_mask, other=0.0
1522+
) # L=6 <- M=2
1523+
x7 = tl.load(
1524+
X_ptr + x_base + 4 * C + c_range, mask=c_mask, other=0.0
1525+
) # L=7 <- M=4
1526+
x8 = tl.load(
1527+
X_ptr + x_base + 7 * C + c_range, mask=c_mask, other=0.0
1528+
) # L=8 <- M=7
13671529

13681530
# L=0 block (1x1): dW[0,0] = sum_c dy[0,c] * x[0,c]
13691531
dw_00 = tl.sum(dy0 * x0)

0 commit comments

Comments
 (0)