@@ -151,6 +151,58 @@ def atomic_cas_kernel(
151151 return x
152152
153153
# 2D kernels for tensor descriptor atomic tests (tensor descriptors require ndim >= 2 and static_shapes=True)
155+
156+
@helion.kernel(static_shapes=True)
def atomic_add_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Apply ``hl.atomic_add`` of ``y`` into ``x`` over 2D tiles.

    The tiles span ``x.size(0)`` by ``x.size(1)`` and ``y`` is read with the
    same tile indices. Returns ``x``, the tensor the atomic op targeted.
    """
    rows, cols = x.size(0), x.size(1)
    for tile_r, tile_c in hl.tile([rows, cols]):
        hl.atomic_add(x, [tile_r, tile_c], y[tile_r, tile_c])
    return x
162+
163+
@helion.kernel(static_shapes=True)
def atomic_and_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Apply ``hl.atomic_and`` of ``y`` into ``x`` over 2D tiles.

    The tiles span ``x.size(0)`` by ``x.size(1)`` and ``y`` is read with the
    same tile indices. Returns ``x``, the tensor the atomic op targeted.
    """
    rows, cols = x.size(0), x.size(1)
    for tile_r, tile_c in hl.tile([rows, cols]):
        hl.atomic_and(x, [tile_r, tile_c], y[tile_r, tile_c])
    return x
169+
170+
@helion.kernel(static_shapes=True)
def atomic_or_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Apply ``hl.atomic_or`` of ``y`` into ``x`` over 2D tiles.

    The tiles span ``x.size(0)`` by ``x.size(1)`` and ``y`` is read with the
    same tile indices. Returns ``x``, the tensor the atomic op targeted.
    """
    rows, cols = x.size(0), x.size(1)
    for tile_r, tile_c in hl.tile([rows, cols]):
        hl.atomic_or(x, [tile_r, tile_c], y[tile_r, tile_c])
    return x
176+
177+
@helion.kernel(static_shapes=True)
def atomic_xor_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Apply ``hl.atomic_xor`` of ``y`` into ``x`` over 2D tiles.

    The tiles span ``x.size(0)`` by ``x.size(1)`` and ``y`` is read with the
    same tile indices. Returns ``x``, the tensor the atomic op targeted.
    """
    rows, cols = x.size(0), x.size(1)
    for tile_r, tile_c in hl.tile([rows, cols]):
        hl.atomic_xor(x, [tile_r, tile_c], y[tile_r, tile_c])
    return x
183+
184+
@helion.kernel(static_shapes=True)
def atomic_max_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Apply ``hl.atomic_max`` of ``y`` into ``x`` over 2D tiles.

    The tiles span ``x.size(0)`` by ``x.size(1)`` and ``y`` is read with the
    same tile indices. Returns ``x``, the tensor the atomic op targeted.
    """
    rows, cols = x.size(0), x.size(1)
    for tile_r, tile_c in hl.tile([rows, cols]):
        hl.atomic_max(x, [tile_r, tile_c], y[tile_r, tile_c])
    return x
190+
191+
@helion.kernel(static_shapes=True)
def atomic_min_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Apply ``hl.atomic_min`` of ``y`` into ``x`` over 2D tiles.

    The tiles span ``x.size(0)`` by ``x.size(1)`` and ``y`` is read with the
    same tile indices. Returns ``x``, the tensor the atomic op targeted.
    """
    rows, cols = x.size(0), x.size(1)
    for tile_r, tile_c in hl.tile([rows, cols]):
        hl.atomic_min(x, [tile_r, tile_c], y[tile_r, tile_c])
    return x
197+
198+
@helion.kernel(static_shapes=True)
def atomic_xchg_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Apply ``hl.atomic_xchg`` of ``y`` into ``x`` over 2D tiles.

    The tiles span ``x.size(0)`` by ``x.size(1)`` and ``y`` is read with the
    same tile indices. Returns ``x``, the tensor the atomic op targeted.
    """
    rows, cols = x.size(0), x.size(1)
    for tile_r, tile_c in hl.tile([rows, cols]):
        hl.atomic_xchg(x, [tile_r, tile_c], y[tile_r, tile_c])
    return x
204+
205+
154206@onlyBackends (["triton" , "cute" , "pallas" ])
155207class TestAtomicOperations (RefEagerTestBase , TestCase ):
156208 def test_basic_atomic_add (self ):
@@ -425,30 +477,8 @@ def test_atomic_cas(self):
425477
426478 @onlyBackends ("triton" )
427479 @skipIfRocm ("Tensor descriptor not supported on ROCm" )
428- def test_atomic_add_tensor_descriptor (self ):
429- """Test that atomic_add with tensor_descriptor indexing generates desc.atomic_add."""
430-
431- @helion .kernel (
432- config = helion .Config (
433- block_sizes = [64 , 64 ],
434- indexing = "tensor_descriptor" ,
435- atomic_indexing = "tensor_descriptor" ,
436- ),
437- static_shapes = True ,
438- )
439- def atomic_add_td_kernel (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
440- for i , j in hl .tile ([x .size (0 ), x .size (1 )]):
441- hl .atomic_add (x , [i , j ], y [i , j ])
442- return x
443-
444- M , N = 128 , 64
445- x = torch .zeros (M , N , device = DEVICE , dtype = torch .float32 )
446- y = torch .ones (M , N , device = DEVICE , dtype = torch .float32 )
447- code , result = code_and_output (atomic_add_td_kernel , (x , y ))
448- expected = torch .ones (M , N , device = DEVICE , dtype = torch .float32 )
449- torch .testing .assert_close (result , expected )
450- self .assertIn ("desc.atomic_add(" , code )
451- self .assertNotIn ("tl.atomic_add" , code )
480+ def test_atomic_td_fallbacks (self ):
481+ """Test that tensor_descriptor atomics fall back to pointer when needed."""
452482
453483 # Return value consumed: should fall back to pointer
454484 @helion .kernel (
@@ -466,14 +496,14 @@ def atomic_add_td_prev_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
466496 out [i , j ] = prev
467497 return out
468498
469- x2 = torch . zeros ( M , N , device = DEVICE , dtype = torch . float32 )
470- y2 = torch .ones (M , N , device = DEVICE , dtype = torch .float32 )
471- code2 , result2 = code_and_output ( atomic_add_td_prev_kernel , ( x2 , y2 ) )
472- # prev should be zeros (the old values before adding ones )
473- expected2 = torch .zeros (M , N , device = DEVICE , dtype = torch .float32 )
474- torch .testing .assert_close (result2 , expected2 )
475- self .assertIn ("tl.atomic_add" , code2 )
476- self .assertNotIn ("desc.atomic_add(" , code2 )
499+ M , N = 128 , 64
500+ x = torch .zeros (M , N , device = DEVICE , dtype = torch .float32 )
501+ y = torch . ones ( M , N , device = DEVICE , dtype = torch . float32 )
502+ code , result = code_and_output ( atomic_add_td_prev_kernel , ( x , y ) )
503+ expected = torch .zeros (M , N , device = DEVICE , dtype = torch .float32 )
504+ torch .testing .assert_close (result , expected )
505+ self .assertIn ("tl.atomic_add" , code )
506+ self .assertNotIn ("desc.atomic_add(" , code )
477507
478508 # Non-relaxed sem: should fall back to pointer
479509 @helion .kernel (
@@ -491,13 +521,13 @@ def atomic_add_td_release_kernel(
491521 hl .atomic_add (x , [i , j ], y [i , j ], sem = "release" )
492522 return x
493523
494- x3 = torch .zeros (M , N , device = DEVICE , dtype = torch .float32 )
495- y3 = torch .ones (M , N , device = DEVICE , dtype = torch .float32 )
496- code3 , result3 = code_and_output (atomic_add_td_release_kernel , (x3 , y3 ))
497- expected3 = torch .ones (M , N , device = DEVICE , dtype = torch .float32 )
498- torch .testing .assert_close (result3 , expected3 )
499- self .assertIn ("tl.atomic_add" , code3 )
500- self .assertNotIn ("desc.atomic_add(" , code3 )
524+ x2 = torch .zeros (M , N , device = DEVICE , dtype = torch .float32 )
525+ y2 = torch .ones (M , N , device = DEVICE , dtype = torch .float32 )
526+ code2 , result2 = code_and_output (atomic_add_td_release_kernel , (x2 , y2 ))
527+ expected2 = torch .ones (M , N , device = DEVICE , dtype = torch .float32 )
528+ torch .testing .assert_close (result2 , expected2 )
529+ self .assertIn ("tl.atomic_add" , code2 )
530+ self .assertNotIn ("desc.atomic_add(" , code2 )
501531
502532 @onlyBackends ("triton" )
503533 @skipIfRocm ("Tensor descriptor not supported on ROCm" )
@@ -536,6 +566,81 @@ def two_atomic_adds(
536566 self .assertNotIn ("out1_desc" , code )
537567 self .assertNotIn ("tl.atomic_add(out2" , code )
538568
@onlyBackends("triton")
@skipIfRocm("Tensor descriptor not supported on ROCm")
def test_atomic_ops_tensor_descriptor(self):
    """Check tensor-descriptor codegen for each 2D atomic kernel.

    add/and/or/xor/max/min must produce ``desc.atomic_<op>(`` in the
    generated code and no ``tl.atomic_<op>``; xchg must do the opposite,
    i.e. fall back to the pointer form.
    """
    rows, cols = 128, 64
    config_kwargs = dict(
        block_sizes=[64, 64],
        indexing="tensor_descriptor",
        atomic_indexing="tensor_descriptor",
    )

    def filled(value, dtype=torch.int32):
        # Constant (rows, cols) tensor on the test device.
        return torch.full((rows, cols), value, device=DEVICE, dtype=dtype)

    # op suffix -> (kernel, dtype, initial x, operand y, expected result)
    scenarios = {
        "add": (atomic_add_2d_td_kernel, torch.float32, 0, 1, 1),
        "and": (atomic_and_2d_td_kernel, torch.int32, 0b1111, 0b1010, 0b1010),
        "or": (atomic_or_2d_td_kernel, torch.int32, 0, 0b1010, 0b1010),
        "xor": (atomic_xor_2d_td_kernel, torch.int32, 0b1010, 0b1100, 0b0110),
        "max": (atomic_max_2d_td_kernel, torch.int32, 1, 5, 5),
        "min": (atomic_min_2d_td_kernel, torch.int32, 10, 3, 3),
    }
    for op_name, (kernel, dtype, start, operand, want) in scenarios.items():
        with self.subTest(op=op_name):
            inputs = (filled(start, dtype), filled(operand, dtype))
            code, result = code_and_output(kernel, inputs, **config_kwargs)
            torch.testing.assert_close(result, filled(want, dtype))
            self.assertIn(f"desc.atomic_{op_name}(", code)
            self.assertNotIn(f"tl.atomic_{op_name}", code)

    # xchg is NOT a TMA reduction op, so the pointer form is expected here.
    with self.subTest(op="xchg_fallback"):
        inputs = (filled(0), filled(1))
        code, result = code_and_output(
            atomic_xchg_2d_td_kernel, inputs, **config_kwargs
        )
        torch.testing.assert_close(result, filled(1))
        self.assertIn("tl.atomic_xchg", code)
        self.assertNotIn("desc.atomic_xchg", code)
643+
539644
540645if __name__ == "__main__" :
541646 unittest .main ()
0 commit comments