Commit db741c7
[MPS] fix compiling of SDPA producing nan results (#178009)
[MPS] fix compiling of SDPA producing nan results (#175481)
Fixes #171764
It took a while to track down the root cause: in a Metal shader, `(uint / 65536) % n` gives a wrong result when `n` is not a power of 2.
Mini reproducer:
```python
import torch
# Reproducer: `(uint / 65536) % n` gives a wrong result when n is not a power of 2.
lib = torch.mps.compile_shader('''
kernel void func(device int* out, uint idx [[thread_position_in_grid]]) {
out[idx] = (idx / 65536) % 6;
}
''')
out = torch.empty(128, device='mps', dtype=torch.int32)
lib.func(out)
# Every value should be 0, since idx / 65536 == 0 for every idx in [0, 127].
for i in [0, 5, 6, 7, 63, 64]:
    # Spot-check a few indices on either side of multiples of 6 and of 64.
    print(f"{i=} got {out[i].item()}")
```
The same reproducer in pure Swift:
```swift
import Metal

// Standalone reproducer: every thread stores (idx / 65536) % 6 into out[idx].
// Only 128 threads are dispatched, so idx / 65536 is 0 for each of them and
// every printed value is expected to be 0.
let gpu = MTLCreateSystemDefaultDevice()!
let commandQueue = gpu.makeCommandQueue()!

let kernelSource = """
kernel void func(device int* out [[buffer(0)]], uint idx [[thread_position_in_grid]]) {
out[idx] = (idx / 65536) % 6;
}
"""

// Compile the kernel at runtime and build a compute pipeline for it.
let shaderLibrary = try gpu.makeLibrary(source: kernelSource, options: nil)
let kernel = shaderLibrary.makeFunction(name: "func")!
let pipelineState = try gpu.makeComputePipelineState(function: kernel)

// One Int32 output slot per thread, in storage shared between CPU and GPU.
let elementCount = 128
let byteLength = elementCount * MemoryLayout<Int32>.stride
let outputBuffer = gpu.makeBuffer(length: byteLength, options: .storageModeShared)!

// Launch one thread per element; cap the threadgroup width at the pipeline's limit.
let gridSize = MTLSizeMake(elementCount, 1, 1)
let groupWidth = min(elementCount, pipelineState.maxTotalThreadsPerThreadgroup)
let groupSize = MTLSizeMake(groupWidth, 1, 1)

let commandBuffer = commandQueue.makeCommandBuffer()!
let computeEncoder = commandBuffer.makeComputeCommandEncoder()!
computeEncoder.setComputePipelineState(pipelineState)
computeEncoder.setBuffer(outputBuffer, offset: 0, index: 0)
computeEncoder.dispatchThreads(gridSize, threadsPerThreadgroup: groupSize)
computeEncoder.endEncoding()

// Submit the work and block until the GPU has finished before reading results.
commandBuffer.commit()
commandBuffer.waitUntilCompleted()

// Print a handful of sample indices from the output buffer.
let ptr = outputBuffer.contents().bindMemory(to: Int32.self, capacity: elementCount)
let sampleIndices = [0, 5, 6, 7, 63, 64]
for i in sampleIndices {
    print("i=\(i) got \(ptr[i])")
}
```
Pull Request resolved: #175481
Approved by: https://github.com/malfet
(cherry picked from commit 3a9554c)
Co-authored-by: Isalia20 <irakli.salia854@gmail.com>
1 parent 483b55d commit db741c7
4 files changed
Lines changed: 52 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
189 | 189 | | |
190 | 190 | | |
191 | 191 | | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
192 | 204 | | |
193 | 205 | | |
194 | 206 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
191 | 191 | | |
192 | 192 | | |
193 | 193 | | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
194 | 213 | | |
195 | 214 | | |
196 | 215 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
13313 | 13313 | | |
13314 | 13314 | | |
13315 | 13315 | | |
| 13316 | + | |
| 13317 | + | |
| 13318 | + | |
| 13319 | + | |
| 13320 | + | |
| 13321 | + | |
| 13322 | + | |
| 13323 | + | |
| 13324 | + | |
| 13325 | + | |
| 13326 | + | |
| 13327 | + | |
| 13328 | + | |
| 13329 | + | |
| 13330 | + | |
| 13331 | + | |
13316 | 13332 | | |
13317 | 13333 | | |
13318 | 13334 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
80 | 80 | | |
81 | 81 | | |
82 | 82 | | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
83 | 86 | | |
84 | 87 | | |
85 | 88 | | |
| |||
88 | 91 | | |
89 | 92 | | |
90 | 93 | | |
| 94 | + | |
| 95 | + | |
91 | 96 | | |
92 | 97 | | |
93 | 98 | | |
| |||
0 commit comments