Commit 12234f8
authored
[torchlib] Add missing dtype parameter to aten_mean_dim (#2885)
Fixes #2884
`aten_mean_dim` and `aten_mean_dim_complex` are missing the `dtype`
keyword argument from their signatures, even though the ATen schema
documents it (`ScalarType? dtype=None`). This causes a `TypeError` when
PyTorch lowers `aten::mean.dim` with an explicit `dtype` - which happens
for any model using `GlobalAveragePooling2D` (Keras/PyTorch).
- Add `dtype: int = -1` to `aten_mean_dim`, with `op.Cast` when dtype is
specified
- Add `dtype: int = -1` to `aten_mean_dim_complex`, raising
`NotImplementedError` for complex tensors
Follows the same pattern used by `aten_sum_dim_IntList` and
`aten_sum_dim_IntList_complex`.1 parent 1ef0ec9 commit 12234f8
1 file changed
+14
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6231 | 6231 | | |
6232 | 6232 | | |
6233 | 6233 | | |
6234 | | - | |
| 6234 | + | |
6235 | 6235 | | |
6236 | 6236 | | |
6237 | 6237 | | |
6238 | 6238 | | |
6239 | 6239 | | |
6240 | 6240 | | |
6241 | 6241 | | |
| 6242 | + | |
| 6243 | + | |
| 6244 | + | |
| 6245 | + | |
6242 | 6246 | | |
6243 | 6247 | | |
6244 | 6248 | | |
6245 | 6249 | | |
6246 | | - | |
| 6250 | + | |
| 6251 | + | |
| 6252 | + | |
6247 | 6253 | | |
6248 | 6254 | | |
6249 | 6255 | | |
| |||
6254 | 6260 | | |
6255 | 6261 | | |
6256 | 6262 | | |
| 6263 | + | |
| 6264 | + | |
| 6265 | + | |
| 6266 | + | |
| 6267 | + | |
| 6268 | + | |
6257 | 6269 | | |
6258 | 6270 | | |
6259 | 6271 | | |
| |||
0 commit comments