Skip to content

Commit db741c7

Browse files
pytorchbot and Isalia20
authored
[MPS] fix compiling of SDPA producing nan results (#178009)
[MPS] fix compiling of SDPA producing nan results (#175481) Fixes #171764 Took me a while to figure out wth was going wrong. Mini reproducer: ```python import torch # (uint / 65536) % non_power of 2, gives wrong result lib = torch.mps.compile_shader(''' kernel void func(device int* out, uint idx [[thread_position_in_grid]]) { out[idx] = (idx / 65536) % 6; } ''') out = torch.empty(128, device='mps', dtype=torch.int32) lib.func(out) # Every value should be 0 since xindex/65536 == 0 for xindex in [0,127] for i in [0, 5, 6, 7, 63, 64]: print(f"{i=} got {out[i].item()}") ``` Same purely in swift ```swift import Metal let device = MTLCreateSystemDefaultDevice()! let queue = device.makeCommandQueue()! let shaderSource = """ kernel void func(device int* out [[buffer(0)]], uint idx [[thread_position_in_grid]]) { out[idx] = (idx / 65536) % 6; } """ let library = try device.makeLibrary(source: shaderSource, options: nil) let function = library.makeFunction(name: "func")! let pipeline = try device.makeComputePipelineState(function: function) let count = 128 let buffer = device.makeBuffer(length: count * MemoryLayout<Int32>.stride, options: .storageModeShared)! let cmdBuf = queue.makeCommandBuffer()! let encoder = cmdBuf.makeComputeCommandEncoder()! encoder.setComputePipelineState(pipeline) encoder.setBuffer(buffer, offset: 0, index: 0) encoder.dispatchThreads( MTLSizeMake(count, 1, 1), threadsPerThreadgroup: MTLSizeMake(min(count, pipeline.maxTotalThreadsPerThreadgroup), 1, 1) ) encoder.endEncoding() cmdBuf.commit() cmdBuf.waitUntilCompleted() let ptr = buffer.contents().bindMemory(to: Int32.self, capacity: count) for i in [0, 5, 6, 7, 63, 64] { print("i=\(i) got \(ptr[i])") } ``` Pull Request resolved: #175481 Approved by: https://github.com/malfet (cherry picked from commit 3a9554c) Co-authored-by: Isalia20 <irakli.salia854@gmail.com>
1 parent 483b55d commit db741c7

4 files changed

Lines changed: 52 additions & 0 deletions

File tree

c10/metal/utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,18 @@ inline common_dtype<T, U> floor_divide(T x, U y) {
189189
return ::metal::floor(x / y);
190190
}
191191

192+
// Workaround for Metal compiler bug: the compiler produces wrong results
// when optimizing fused (x / A) % B expressions for integral types.
// NOTE(review): the `volatile` qualifier on the by-value parameter appears
// to be the actual workaround -- presumably it forces `x` to be
// materialized so the optimizer cannot fuse the caller's division with
// this modulo. Confirm before "simplifying" it away.
template <
    typename T,
    typename U,
    ::metal::enable_if_t<
        is_scalar_integral_v<T> && is_scalar_integral_v<U>,
        bool> = true>
// Returns x % y in the common dtype of T and U; participates in overload
// resolution only when both operands are scalar integral types.
inline common_dtype<T, U> safe_mod(volatile T x, U y) {
  return x % y;
}
203+
192204
// fmod
193205
template <
194206
typename T,

test/inductor/test_mps_basic.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,25 @@ def fn(a, b):
191191
),
192192
)
193193

194+
def test_sdpa_split_qkv(self):
    # Regression test for a Metal compiler bug where a fused (x / A) % B
    # expression produces wrong results, causing incorrect reads from
    # non-contiguous tensors in the compiled SDPA kernel.
    n_head, n_embd, seq_len = 6, 384, 1024
    head_dim = n_embd // n_head
    x = torch.randn(16, seq_len, n_embd, device="mps")
    c_attn = torch.nn.Linear(n_embd, 3 * n_embd).to("mps").eval()
    # Split the fused QKV projection and reshape each piece into
    # (batch, heads, seq, head_dim); the transpose makes them non-contiguous.
    q, k, v = (
        t.view(16, seq_len, n_head, head_dim).transpose(1, 2)
        for t in c_attn(x).split(n_embd, dim=2)
    )

    def fn(q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=True
        )

    self.common(fn, (q, k, v), atol=1e-4, rtol=1e-4, check_lowp=False)
212+
194213

195214
class MPSBasicTestsAOTI(TestCase):
196215
def check_model(self, m, inp, dynamic_shapes=None):

test/test_mps.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13313,6 +13313,22 @@ def test_metal_error_buffer(self):
1331313313
with self.assertRaisesRegex(RuntimeError, "Index .* exceeds limit"):
1331413314
torch.mps.synchronize()
1331513315

13316+
def test_metal_compiler_bug_workaround(self):
    # Regression test: (uint / 65536) % <non-power-of-2> gives a wrong
    # result when the Metal compiler fuses the two ops; safe_mod is the
    # workaround being exercised here.
    compiled = torch.mps.compile_shader('''
#include <c10/metal/utils.h>

kernel void func(device int* out, uint idx [[thread_position_in_grid]]) {
out[idx] = c10::metal::safe_mod((idx / 65536), 6);
}
''')
    result = torch.empty(128, device='mps', dtype=torch.int32)
    compiled.func(result)
    # idx / 65536 == 0 for every idx in [0, 128), so each element must be 0.
    for probe in (0, 5, 6, 7, 63, 64):
        self.assertEqual(result[probe], 0)
13330+
13331+
1331613332

1331713333
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
1331813334
# This requires mps to be properly registered in the device generic test framework which is not the

torch/_inductor/codegen/mps.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ def _print_FloorDiv(self, expr: sympy.Expr) -> str:
8080

8181
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
    """Emit Metal source for ModularIndexing(x, div, mod) == (x // div) % mod.

    Every integral modulo is routed through ``c10::metal::safe_mod`` to work
    around a Metal compiler bug that mis-optimizes fused ``(x / A) % B``
    expressions (PR 175481). The miscompilation is NOT specific to
    ``A == 65536`` or to non-power-of-two ``B`` -- those were merely the
    values in the original reproducer -- so the workaround must not be gated
    on them. Gating on ``mod & (mod - 1)`` also broke for symbolic ``mod``
    expressions, where bitwise ops are undefined.
    """
    x, div, mod = expr.args
    x = self.doprint(x)
    if div != 1:
        div = self.doprint(div)
        # NOTE(review): this integer/float split is reconstructed from
        # context hidden between the diff hunks -- confirm against the file.
        if expr.is_integer:
            x = f"({x}) / ({div})"
        else:
            x = f"metal::floor({x}) / ({div})"
    mod = self.doprint(mod)
    if expr.is_integer:
        # Integral case: always use the workaround helper instead of a
        # bare `%` so the compiler cannot fuse the division and modulo.
        return f"c10::metal::safe_mod({x}, {mod})"
    return f"({x}) % ({mod})"
9297

9398
def _print_Min(self, expr: sympy.Expr) -> str:

0 commit comments

Comments
 (0)