11import dataclasses
22from functools import lru_cache
33import logging
4+ import os
45import re
56import subprocess
67from typing import Optional
@@ -83,10 +84,21 @@ def get_rocm_gpu_arch() -> str:
8384 logger = logging .getLogger (__name__ )
8485 try :
8586 if torch .version .hip :
86- result = subprocess .run (["rocminfo" ], capture_output = True , text = True )
87- match = re .search (r"Name:\s+gfx([a-zA-Z\d]+)" , result .stdout )
87+ # On Windows, use hipinfo.exe; on Linux, use rocminfo
88+ if os .name == "nt" :
89+ cmd = ["hipinfo.exe" ]
90+ arch_pattern = r"gcnArchName:\s+(gfx[a-zA-Z\d]+)"
91+ else :
92+ cmd = ["rocminfo" ]
93+ arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)"
94+
95+ result = subprocess .run (cmd , capture_output = True , text = True )
96+ match = re .search (arch_pattern , result .stdout )
8897 if match :
89- return "gfx" + match .group (1 )
98+ if os .name == "nt" :
99+ return match .group (1 )
100+ else :
101+ return "gfx" + match .group (1 )
90102 else :
91103 return "unknown"
92104 else :
@@ -107,8 +119,17 @@ def get_rocm_warpsize() -> int:
107119 logger = logging .getLogger (__name__ )
108120 try :
109121 if torch .version .hip :
110- result = subprocess .run (["rocminfo" ], capture_output = True , text = True )
111- match = re .search (r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)" , result .stdout )
122+ # On Windows, use hipinfo.exe; on Linux, use rocminfo
123+ if os .name == "nt" :
124+ cmd = ["hipinfo.exe" ]
125+ # hipinfo.exe output format: "warpSize: 32" or "warpSize: 64"
126+ warp_pattern = r"warpSize:\s+(\d+)"
127+ else :
128+ cmd = ["rocminfo" ]
129+ warp_pattern = r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)"
130+
131+ result = subprocess .run (cmd , capture_output = True , text = True )
132+ match = re .search (warp_pattern , result .stdout )
112133 if match :
113134 return int (match .group (1 ))
114135 else :
0 commit comments