@@ -924,47 +924,191 @@ def node_to_edge_wigner_permute_bwd_dw_kernel(
924924 # M_TO_L_GATHER_IDX = [0, 5, 1, 3, 8, 6, 2, 4, 7]
925925 # Each dy has 2C channels: first C = src, last C = tgt
926926 # Load src part (first C channels)
927- dy0s = tl .load (grad_out_ptr + grad_base + 0 * sphere_channels * 2 + c_range , mask = c_mask , other = 0.0 )
928- dy1s = tl .load (grad_out_ptr + grad_base + 5 * sphere_channels * 2 + c_range , mask = c_mask , other = 0.0 )
929- dy2s = tl .load (grad_out_ptr + grad_base + 1 * sphere_channels * 2 + c_range , mask = c_mask , other = 0.0 )
930- dy3s = tl .load (grad_out_ptr + grad_base + 3 * sphere_channels * 2 + c_range , mask = c_mask , other = 0.0 )
931- dy4s = tl .load (grad_out_ptr + grad_base + 8 * sphere_channels * 2 + c_range , mask = c_mask , other = 0.0 )
932- dy5s = tl .load (grad_out_ptr + grad_base + 6 * sphere_channels * 2 + c_range , mask = c_mask , other = 0.0 )
933- dy6s = tl .load (grad_out_ptr + grad_base + 2 * sphere_channels * 2 + c_range , mask = c_mask , other = 0.0 )
934- dy7s = tl .load (grad_out_ptr + grad_base + 4 * sphere_channels * 2 + c_range , mask = c_mask , other = 0.0 )
935- dy8s = tl .load (grad_out_ptr + grad_base + 7 * sphere_channels * 2 + c_range , mask = c_mask , other = 0.0 )
927+ dy0s = tl .load (
928+ grad_out_ptr + grad_base + 0 * sphere_channels * 2 + c_range ,
929+ mask = c_mask ,
930+ other = 0.0 ,
931+ )
932+ dy1s = tl .load (
933+ grad_out_ptr + grad_base + 5 * sphere_channels * 2 + c_range ,
934+ mask = c_mask ,
935+ other = 0.0 ,
936+ )
937+ dy2s = tl .load (
938+ grad_out_ptr + grad_base + 1 * sphere_channels * 2 + c_range ,
939+ mask = c_mask ,
940+ other = 0.0 ,
941+ )
942+ dy3s = tl .load (
943+ grad_out_ptr + grad_base + 3 * sphere_channels * 2 + c_range ,
944+ mask = c_mask ,
945+ other = 0.0 ,
946+ )
947+ dy4s = tl .load (
948+ grad_out_ptr + grad_base + 8 * sphere_channels * 2 + c_range ,
949+ mask = c_mask ,
950+ other = 0.0 ,
951+ )
952+ dy5s = tl .load (
953+ grad_out_ptr + grad_base + 6 * sphere_channels * 2 + c_range ,
954+ mask = c_mask ,
955+ other = 0.0 ,
956+ )
957+ dy6s = tl .load (
958+ grad_out_ptr + grad_base + 2 * sphere_channels * 2 + c_range ,
959+ mask = c_mask ,
960+ other = 0.0 ,
961+ )
962+ dy7s = tl .load (
963+ grad_out_ptr + grad_base + 4 * sphere_channels * 2 + c_range ,
964+ mask = c_mask ,
965+ other = 0.0 ,
966+ )
967+ dy8s = tl .load (
968+ grad_out_ptr + grad_base + 7 * sphere_channels * 2 + c_range ,
969+ mask = c_mask ,
970+ other = 0.0 ,
971+ )
936972
937973 # Load tgt part (second C channels, offset by sphere_channels)
938- dy0t = tl .load (grad_out_ptr + grad_base + 0 * sphere_channels * 2 + sphere_channels + c_range , mask = c_mask , other = 0.0 )
939- dy1t = tl .load (grad_out_ptr + grad_base + 5 * sphere_channels * 2 + sphere_channels + c_range , mask = c_mask , other = 0.0 )
940- dy2t = tl .load (grad_out_ptr + grad_base + 1 * sphere_channels * 2 + sphere_channels + c_range , mask = c_mask , other = 0.0 )
941- dy3t = tl .load (grad_out_ptr + grad_base + 3 * sphere_channels * 2 + sphere_channels + c_range , mask = c_mask , other = 0.0 )
942- dy4t = tl .load (grad_out_ptr + grad_base + 8 * sphere_channels * 2 + sphere_channels + c_range , mask = c_mask , other = 0.0 )
943- dy5t = tl .load (grad_out_ptr + grad_base + 6 * sphere_channels * 2 + sphere_channels + c_range , mask = c_mask , other = 0.0 )
944- dy6t = tl .load (grad_out_ptr + grad_base + 2 * sphere_channels * 2 + sphere_channels + c_range , mask = c_mask , other = 0.0 )
945- dy7t = tl .load (grad_out_ptr + grad_base + 4 * sphere_channels * 2 + sphere_channels + c_range , mask = c_mask , other = 0.0 )
946- dy8t = tl .load (grad_out_ptr + grad_base + 7 * sphere_channels * 2 + sphere_channels + c_range , mask = c_mask , other = 0.0 )
974+ dy0t = tl .load (
975+ grad_out_ptr
976+ + grad_base
977+ + 0 * sphere_channels * 2
978+ + sphere_channels
979+ + c_range ,
980+ mask = c_mask ,
981+ other = 0.0 ,
982+ )
983+ dy1t = tl .load (
984+ grad_out_ptr
985+ + grad_base
986+ + 5 * sphere_channels * 2
987+ + sphere_channels
988+ + c_range ,
989+ mask = c_mask ,
990+ other = 0.0 ,
991+ )
992+ dy2t = tl .load (
993+ grad_out_ptr
994+ + grad_base
995+ + 1 * sphere_channels * 2
996+ + sphere_channels
997+ + c_range ,
998+ mask = c_mask ,
999+ other = 0.0 ,
1000+ )
1001+ dy3t = tl .load (
1002+ grad_out_ptr
1003+ + grad_base
1004+ + 3 * sphere_channels * 2
1005+ + sphere_channels
1006+ + c_range ,
1007+ mask = c_mask ,
1008+ other = 0.0 ,
1009+ )
1010+ dy4t = tl .load (
1011+ grad_out_ptr
1012+ + grad_base
1013+ + 8 * sphere_channels * 2
1014+ + sphere_channels
1015+ + c_range ,
1016+ mask = c_mask ,
1017+ other = 0.0 ,
1018+ )
1019+ dy5t = tl .load (
1020+ grad_out_ptr
1021+ + grad_base
1022+ + 6 * sphere_channels * 2
1023+ + sphere_channels
1024+ + c_range ,
1025+ mask = c_mask ,
1026+ other = 0.0 ,
1027+ )
1028+ dy6t = tl .load (
1029+ grad_out_ptr
1030+ + grad_base
1031+ + 2 * sphere_channels * 2
1032+ + sphere_channels
1033+ + c_range ,
1034+ mask = c_mask ,
1035+ other = 0.0 ,
1036+ )
1037+ dy7t = tl .load (
1038+ grad_out_ptr
1039+ + grad_base
1040+ + 4 * sphere_channels * 2
1041+ + sphere_channels
1042+ + c_range ,
1043+ mask = c_mask ,
1044+ other = 0.0 ,
1045+ )
1046+ dy8t = tl .load (
1047+ grad_out_ptr
1048+ + grad_base
1049+ + 7 * sphere_channels * 2
1050+ + sphere_channels
1051+ + c_range ,
1052+ mask = c_mask ,
1053+ other = 0.0 ,
1054+ )
9471055
9481056 # Load node features (L-major order)
949- xs0 = tl .load (x_ptr + idx0 * x_stride_n + 0 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
950- xs1 = tl .load (x_ptr + idx0 * x_stride_n + 1 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
951- xs2 = tl .load (x_ptr + idx0 * x_stride_n + 2 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
952- xs3 = tl .load (x_ptr + idx0 * x_stride_n + 3 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
953- xs4 = tl .load (x_ptr + idx0 * x_stride_n + 4 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
954- xs5 = tl .load (x_ptr + idx0 * x_stride_n + 5 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
955- xs6 = tl .load (x_ptr + idx0 * x_stride_n + 6 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
956- xs7 = tl .load (x_ptr + idx0 * x_stride_n + 7 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
957- xs8 = tl .load (x_ptr + idx0 * x_stride_n + 8 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
958-
959- xt0 = tl .load (x_ptr + idx1 * x_stride_n + 0 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
960- xt1 = tl .load (x_ptr + idx1 * x_stride_n + 1 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
961- xt2 = tl .load (x_ptr + idx1 * x_stride_n + 2 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
962- xt3 = tl .load (x_ptr + idx1 * x_stride_n + 3 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
963- xt4 = tl .load (x_ptr + idx1 * x_stride_n + 4 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
964- xt5 = tl .load (x_ptr + idx1 * x_stride_n + 5 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
965- xt6 = tl .load (x_ptr + idx1 * x_stride_n + 6 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
966- xt7 = tl .load (x_ptr + idx1 * x_stride_n + 7 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
967- xt8 = tl .load (x_ptr + idx1 * x_stride_n + 8 * x_stride_m + c_range , mask = c_mask , other = 0.0 )
1057+ xs0 = tl .load (
1058+ x_ptr + idx0 * x_stride_n + 0 * x_stride_m + c_range , mask = c_mask , other = 0.0
1059+ )
1060+ xs1 = tl .load (
1061+ x_ptr + idx0 * x_stride_n + 1 * x_stride_m + c_range , mask = c_mask , other = 0.0
1062+ )
1063+ xs2 = tl .load (
1064+ x_ptr + idx0 * x_stride_n + 2 * x_stride_m + c_range , mask = c_mask , other = 0.0
1065+ )
1066+ xs3 = tl .load (
1067+ x_ptr + idx0 * x_stride_n + 3 * x_stride_m + c_range , mask = c_mask , other = 0.0
1068+ )
1069+ xs4 = tl .load (
1070+ x_ptr + idx0 * x_stride_n + 4 * x_stride_m + c_range , mask = c_mask , other = 0.0
1071+ )
1072+ xs5 = tl .load (
1073+ x_ptr + idx0 * x_stride_n + 5 * x_stride_m + c_range , mask = c_mask , other = 0.0
1074+ )
1075+ xs6 = tl .load (
1076+ x_ptr + idx0 * x_stride_n + 6 * x_stride_m + c_range , mask = c_mask , other = 0.0
1077+ )
1078+ xs7 = tl .load (
1079+ x_ptr + idx0 * x_stride_n + 7 * x_stride_m + c_range , mask = c_mask , other = 0.0
1080+ )
1081+ xs8 = tl .load (
1082+ x_ptr + idx0 * x_stride_n + 8 * x_stride_m + c_range , mask = c_mask , other = 0.0
1083+ )
1084+
1085+ xt0 = tl .load (
1086+ x_ptr + idx1 * x_stride_n + 0 * x_stride_m + c_range , mask = c_mask , other = 0.0
1087+ )
1088+ xt1 = tl .load (
1089+ x_ptr + idx1 * x_stride_n + 1 * x_stride_m + c_range , mask = c_mask , other = 0.0
1090+ )
1091+ xt2 = tl .load (
1092+ x_ptr + idx1 * x_stride_n + 2 * x_stride_m + c_range , mask = c_mask , other = 0.0
1093+ )
1094+ xt3 = tl .load (
1095+ x_ptr + idx1 * x_stride_n + 3 * x_stride_m + c_range , mask = c_mask , other = 0.0
1096+ )
1097+ xt4 = tl .load (
1098+ x_ptr + idx1 * x_stride_n + 4 * x_stride_m + c_range , mask = c_mask , other = 0.0
1099+ )
1100+ xt5 = tl .load (
1101+ x_ptr + idx1 * x_stride_n + 5 * x_stride_m + c_range , mask = c_mask , other = 0.0
1102+ )
1103+ xt6 = tl .load (
1104+ x_ptr + idx1 * x_stride_n + 6 * x_stride_m + c_range , mask = c_mask , other = 0.0
1105+ )
1106+ xt7 = tl .load (
1107+ x_ptr + idx1 * x_stride_n + 7 * x_stride_m + c_range , mask = c_mask , other = 0.0
1108+ )
1109+ xt8 = tl .load (
1110+ x_ptr + idx1 * x_stride_n + 8 * x_stride_m + c_range , mask = c_mask , other = 0.0
1111+ )
9681112
9691113 # dW[i,j] = sum_c (dy_src[i]*x_src[j] + dy_tgt[i]*x_tgt[j])
9701114 # L=0 block (1x1)
@@ -1355,15 +1499,33 @@ def permute_wigner_inv_edge_to_node_bwd_dw_kernel(
13551499 dy8 = tl .load (DY_ptr + dy_base + 8 * C + c_range , mask = c_mask , other = 0.0 )
13561500
13571501 # Load x from M-major positions using M_TO_L_GATHER_IDX = [0,5,1,3,8,6,2,4,7]
1358- x0 = tl .load (X_ptr + x_base + 0 * C + c_range , mask = c_mask , other = 0.0 ) # L=0 <- M=0
1359- x1 = tl .load (X_ptr + x_base + 5 * C + c_range , mask = c_mask , other = 0.0 ) # L=1 <- M=5
1360- x2 = tl .load (X_ptr + x_base + 1 * C + c_range , mask = c_mask , other = 0.0 ) # L=2 <- M=1
1361- x3 = tl .load (X_ptr + x_base + 3 * C + c_range , mask = c_mask , other = 0.0 ) # L=3 <- M=3
1362- x4 = tl .load (X_ptr + x_base + 8 * C + c_range , mask = c_mask , other = 0.0 ) # L=4 <- M=8
1363- x5 = tl .load (X_ptr + x_base + 6 * C + c_range , mask = c_mask , other = 0.0 ) # L=5 <- M=6
1364- x6 = tl .load (X_ptr + x_base + 2 * C + c_range , mask = c_mask , other = 0.0 ) # L=6 <- M=2
1365- x7 = tl .load (X_ptr + x_base + 4 * C + c_range , mask = c_mask , other = 0.0 ) # L=7 <- M=4
1366- x8 = tl .load (X_ptr + x_base + 7 * C + c_range , mask = c_mask , other = 0.0 ) # L=8 <- M=7
1502+ x0 = tl .load (
1503+ X_ptr + x_base + 0 * C + c_range , mask = c_mask , other = 0.0
1504+ ) # L=0 <- M=0
1505+ x1 = tl .load (
1506+ X_ptr + x_base + 5 * C + c_range , mask = c_mask , other = 0.0
1507+ ) # L=1 <- M=5
1508+ x2 = tl .load (
1509+ X_ptr + x_base + 1 * C + c_range , mask = c_mask , other = 0.0
1510+ ) # L=2 <- M=1
1511+ x3 = tl .load (
1512+ X_ptr + x_base + 3 * C + c_range , mask = c_mask , other = 0.0
1513+ ) # L=3 <- M=3
1514+ x4 = tl .load (
1515+ X_ptr + x_base + 8 * C + c_range , mask = c_mask , other = 0.0
1516+ ) # L=4 <- M=8
1517+ x5 = tl .load (
1518+ X_ptr + x_base + 6 * C + c_range , mask = c_mask , other = 0.0
1519+ ) # L=5 <- M=6
1520+ x6 = tl .load (
1521+ X_ptr + x_base + 2 * C + c_range , mask = c_mask , other = 0.0
1522+ ) # L=6 <- M=2
1523+ x7 = tl .load (
1524+ X_ptr + x_base + 4 * C + c_range , mask = c_mask , other = 0.0
1525+ ) # L=7 <- M=4
1526+ x8 = tl .load (
1527+ X_ptr + x_base + 7 * C + c_range , mask = c_mask , other = 0.0
1528+ ) # L=8 <- M=7
13671529
13681530 # L=0 block (1x1): dW[0,0] = sum_c dy[0,c] * x[0,c]
13691531 dw_00 = tl .sum (dy0 * x0 )
0 commit comments