Skip to content

Commit 3b233c1

Browse files
DamianSzwichtenberg, dsawczuk-int, and Copilot
authored
Make SFT hardware-agnostic (#749)
Co-authored-by: Sawczuk, Daniel <daniel.sawczuk@intel.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 6c61642 commit 3b233c1

File tree

7 files changed

+187
-116
lines changed

7 files changed

+187
-116
lines changed

apps/sft/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,8 @@ async def train(self) -> None:
444444
# Move tensors to the appropriate device
445445
for k, v in batch.items():
446446
if isinstance(v, torch.Tensor):
447-
batch[k] = v.to("cuda") # TODO: hardcoded for now
447+
# self.device is set up in ForgeEngine
448+
batch[k] = v.to(self.device)
448449

449450
self.train_step(batch)
450451
# self.profiler.step()

src/forge/controller/provisioner.py

Lines changed: 84 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,79 @@
3131
logger.setLevel(logging.DEBUG)
3232

3333

34+
class DeviceProxy:
35+
"""A hardware-agnostic proxy using torch.accelerator.
36+
37+
Handles device counting and environment variable mapping for isolation.
38+
"""
39+
40+
# Mapping of PyTorch backend names to driver isolation variables
41+
_VISIBLE_DEVICES_ENV_MAP: dict[str, str] = {
42+
"cuda": "CUDA_VISIBLE_DEVICES",
43+
"xpu": "ZE_AFFINITY_MASK", # Intel Level Zero
44+
}
45+
46+
@staticmethod
47+
def is_available() -> bool:
48+
"""Check if any accelerator is available."""
49+
return torch.accelerator.is_available()
50+
51+
@staticmethod
52+
def get_device_count() -> int:
53+
"""Returns the number of available accelerator devices."""
54+
if not DeviceProxy.is_available():
55+
return 0
56+
return torch.accelerator.device_count()
57+
58+
@classmethod
59+
def get_visible_devices_env_var(cls) -> str | None:
60+
"""Returns the environment variable name used to mask devices.
61+
62+
Returns None if no accelerator is available or the backend is not supported.
63+
"""
64+
if not cls.is_available():
65+
return None
66+
accelerator = torch.accelerator.current_accelerator()
67+
if accelerator is None:
68+
return None
69+
return cls._VISIBLE_DEVICES_ENV_MAP.get(accelerator.type)
70+
71+
@classmethod
72+
def get_isolation_env_vars(cls, device_ids: list[str]) -> dict[str, str]:
73+
"""Returns environment variables needed to isolate specific device IDs.
74+
75+
Returns an empty dict if no isolation env var is available for this backend.
76+
"""
77+
env_var_name = cls.get_visible_devices_env_var()
78+
if env_var_name is None:
79+
return {}
80+
return {env_var_name: ",".join(device_ids)}
81+
82+
@classmethod
83+
def get_visible_devices_from_env(cls) -> set[int] | None:
84+
"""Parses visible devices from the appropriate environment variable.
85+
86+
Returns None if the variable is not set.
87+
Raises ValueError if the format is invalid.
88+
"""
89+
env_var = cls.get_visible_devices_env_var()
90+
if env_var is None:
91+
return None
92+
93+
env_value = os.environ.get(env_var, None)
94+
if env_value is None or not env_value.strip():
95+
return None
96+
97+
try:
98+
# For Intel Level Zero we support ZE_FLAT_DEVICE_HIERARCHY=flat
99+
return set(int(x.strip()) for x in env_value.split(",") if x.strip())
100+
except ValueError as e:
101+
raise ValueError(
102+
f"Invalid {env_var} format: '{env_value}'. "
103+
f"Expected comma-separated integers (e.g., '0,1,2'). Error: {e}"
104+
) from e
105+
106+
34107
def _get_port() -> str:
35108
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
36109
s.bind(("localhost", 0))
@@ -49,13 +122,8 @@ def get_info(self) -> tuple[str, str]:
49122

50123
@endpoint
def get_gpu_count(self) -> int:
    """Returns the number of accelerator devices available on this host."""
    # Delegate to the hardware-agnostic proxy rather than querying a
    # specific backend directly.
    device_count = DeviceProxy.get_device_count()
    return device_count
59127

60128

61129
class EnvSetter(Actor):
@@ -209,33 +277,15 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
209277
# remove this once this is supported in Monarch.
210278
self._this_host_id = uuid.uuid1()
211279

212-
# For the local host, we may want to set CUDA_VISIBLE_DEVICES
280+
# For the local host, we may want to set device visibility
213281
# for small scale testing. We inherit the environment's
214-
# CUDA_VISIBLE_DEVICES **only for the local host** and not
215-
# for remote hosts.
216-
available_local_devices = None
217-
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
218-
if cuda_visible_devices is not None and cuda_visible_devices.strip():
219-
try:
220-
available_local_devices = set(
221-
int(x.strip()) for x in cuda_visible_devices.split(",") if x.strip()
222-
)
223-
except ValueError as e:
224-
raise ValueError(
225-
f"Invalid CUDA_VISIBLE_DEVICES format: '{cuda_visible_devices}'. "
226-
f"Expected comma-separated integers (e.g., '0,1,2'). Error: {e}"
227-
) from e
228-
229-
# Get the actual GPU count for the local host
230-
try:
231-
local_gpu_count = torch.cuda.device_count()
232-
except Exception:
233-
# If torch is not available or CUDA is not available, assume no GPUs
234-
local_gpu_count = 0
282+
# device visibility setting **only for the local host**.
283+
available_local_devices = DeviceProxy.get_visible_devices_from_env()
284+
local_device_count = DeviceProxy.get_device_count()
235285

236286
self._host_gpu_map = {
237287
self._this_host_id: GpuManager(
238-
available_local_devices, max_device_count=local_gpu_count
288+
available_local_devices, max_device_count=local_device_count
239289
),
240290
}
241291
self._proc_host_map = {}
@@ -298,7 +348,7 @@ async def get_proc_mesh(
298348
mesh_name: Name of the pre-allocated mesh to use.
299349
Must match a mesh name defined in the launcher config.
300350
with_gpus: Whether to include GPU allocations.
301-
This only adds the CUDA_VISIBLE_DEVICES environment variable.
351+
This only adds the hardware isolation environment variable.
302352
num_hosts: The number of hosts to allocate.
303353
If this is set, a remote allocation is created.
304354
If this is None, it uses the local host.
@@ -356,7 +406,9 @@ async def get_proc_mesh(
356406
# Set the PTD world size
357407
world_size = num_procs * (num_hosts or 1)
358408
env_vars["WORLD_SIZE"] = str(world_size)
359-
env_vars["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
409+
410+
# Set device isolation using the appropriate environment variable
411+
env_vars.update(DeviceProxy.get_isolation_env_vars(gpu_ids))
360412

361413
# Inherit Forge-relevant environment variables from the system
362414
for env_var in all_env_vars():

src/forge/observability/perf_tracker.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,18 @@ class Tracer:
5454
Tracer with multi-step timing and optional memory tracking at start/stop boundaries.
5555
Steps only affect timing; memory is tracked from start() to stop().
5656
57-
Supports non-blocking CUDA timing via CUDA events and background polling threads.
57+
Supports non-blocking accelerator timing via torch events and background polling threads.
5858
Aggregation is handled externally by the metrics system via record_metric.
5959
6060
User must call start() and stop() explicitly.
6161
Supports reuse: after calling stop(), you may call start() again to begin a new timing session.
6262
6363
Local env flag DISABLE_PERF_METRICS can be used to skip all timing operations.
64-
Local env flag METRIC_TIMER_USES_GPU can be used to set CUDA timing.
64+
Local env flag METRIC_TIMER_USES_GPU can be used to set accelerator timing.
6565
6666
Args:
6767
prefix (str): Prefix for metric names, e.g. "my_prefix" -> "{my_prefix}/{step_name}/duration_avg_s".
68-
track_memory (bool): Whether to track CUDA memory usage. Defaults to False.
68+
track_memory (bool): Whether to track accelerator memory usage. Defaults to False.
6969
timer (str): Timing backend; "cpu" (default) or "gpu".
7070
7171
Example:
@@ -138,8 +138,8 @@ def start(self) -> None:
138138
else:
139139
# Env var not set - use the timer parameter
140140
use_gpu = self.time_with_gpu
141-
time_with_gpu_events = use_gpu and torch.cuda.is_available()
142-
self._timer = _TimerCUDA() if time_with_gpu_events else _TimerCPU()
141+
time_with_gpu_events = use_gpu and torch.accelerator.is_available()
142+
self._timer = _TimerGPU() if time_with_gpu_events else _TimerCPU()
143143
self._timer.start()
144144

145145
self._active = True
@@ -176,7 +176,7 @@ def stop(self) -> None:
176176
def _start_memory_tracking(self) -> None:
177177
is_outer_scope = not _is_memory_active()
178178
should_track = (
179-
self.track_memory and is_outer_scope and torch.cuda.is_available()
179+
self.track_memory and is_outer_scope and torch.accelerator.is_available()
180180
)
181181

182182
if self.track_memory and not is_outer_scope:
@@ -185,23 +185,23 @@ def _start_memory_tracking(self) -> None:
185185

186186
if should_track:
187187
_set_memory_active(True)
188-
torch.cuda.reset_peak_memory_stats()
189-
self._start_mem = torch.cuda.memory_allocated()
188+
torch.accelerator.reset_peak_memory_stats()
189+
self._start_mem = torch.accelerator.memory_allocated()
190190
self._memory_started = True
191191

192192
def _stop_memory_tracking(self) -> None:
193193
if not self._memory_started:
194194
return
195195

196-
end_mem = torch.cuda.memory_allocated()
196+
end_mem = torch.accelerator.memory_allocated()
197197
delta = (end_mem - self._start_mem) / 1024**3
198-
peak_mem = torch.cuda.max_memory_allocated() / 1024**3
198+
peak_mem = torch.accelerator.max_memory_allocated() / 1024**3
199199
record_metric(
200200
f"{self.prefix}/memory_delta_end_start_avg_gb", delta, Reduce.MEAN
201201
)
202202
record_metric(f"{self.prefix}/memory_peak_max_gb", peak_mem, Reduce.MAX)
203203
_set_memory_active(False)
204-
torch.cuda.reset_peak_memory_stats()
204+
torch.accelerator.reset_peak_memory_stats()
205205
self._memory_started = False
206206

207207
def _record_timing_metrics(
@@ -258,12 +258,12 @@ def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
258258
return self._durations[:], stop_step_ms
259259

260260

261-
class _TimerCUDA(_TimerProtocol):
262-
"""CUDA timing backend with non-blocking events and futures.
263-
Uses a thread pool to poll CUDA events asynchronously without blocking the main thread.
261+
class _TimerGPU(_TimerProtocol):
262+
"""Accelerator timing backend with non-blocking events and futures.
263+
Uses a thread pool to poll torch events asynchronously without blocking the main thread.
264264
265265
Example:
266-
timer = _TimerCUDA()
266+
timer = _TimerGPU()
267267
timer.start()
268268
# torch.mm(a, b) # ~100ms GPU
269269
timer.step("matmul")
@@ -272,36 +272,36 @@ class _TimerCUDA(_TimerProtocol):
272272
"""
273273

274274
def __init__(self, max_workers: int = 2) -> None:
275-
if not torch.cuda.is_available():
276-
raise RuntimeError("CUDA is not available for timing")
275+
if not torch.accelerator.is_available():
276+
raise RuntimeError("Accelerator is not available for timing")
277277
self._executor = ThreadPoolExecutor(max_workers=max_workers)
278278
self._futures: list[tuple[str, Future[float], int]] = (
279279
[]
280280
) # (name, future, submission_index)
281281
self._durations: list[tuple[str, float]] = []
282-
self._chain_start: torch.cuda.Event | None = None
282+
self._chain_start: torch.Event | None = None
283283

284284
def start(self) -> None:
285285
"""Call before any steps. Clear state for reuse; record initial event on current stream."""
286286
self._futures.clear()
287287
self._durations.clear()
288-
stream = torch.cuda.current_stream()
289-
start_event = torch.cuda.Event(enable_timing=True)
288+
stream = torch.accelerator.current_stream()
289+
start_event = torch.Event(enable_timing=True)
290290
start_event.record(stream)
291291
self._chain_start = start_event
292292

293293
def step(self, name: str) -> None:
294294
"""Mark the end of a GPU workload segment and start the next, submitting async polling.
295-
Records a CUDA end event on the current stream; a background thread polls completion.
295+
Records a torch end event on the current stream; a background thread polls completion.
296296
297297
Args:
298298
name: Label for this segment's duration
299299
"""
300300
if self._chain_start is None:
301301
raise ValueError("Timer must be started before calling step")
302302

303-
stream = torch.cuda.current_stream()
304-
end_event = torch.cuda.Event(enable_timing=True)
303+
stream = torch.accelerator.current_stream()
304+
end_event = torch.Event(enable_timing=True)
305305
end_event.record(stream)
306306

307307
future = self._executor.submit(self._poll_elapsed, self._chain_start, end_event)
@@ -312,9 +312,7 @@ def step(self, name: str) -> None:
312312

313313
self._chain_start = end_event
314314

315-
def _poll_elapsed(
316-
self, start_event: torch.cuda.Event, end_event: torch.cuda.Event
317-
) -> float:
315+
def _poll_elapsed(self, start_event: torch.Event, end_event: torch.Event) -> float:
318316
"""Compute elapsed time after polling with backoff."""
319317
# Poll until ready
320318
sleep_time = 0.001 # Start at 1ms
@@ -388,8 +386,8 @@ def trace(
388386
389387
Args:
390388
prefix (str): Prefix for metric names
391-
track_memory (bool): Whether to track CUDA memory usage. Defaults to False.
392-
timer (str): Timing backend; "cpu" (default) or "gpu" (requires CUDA).
389+
track_memory (bool): Whether to track memory usage. Defaults to False.
390+
timer (str): Timing backend; "cpu" (default) or "gpu" (requires accelerator support).
393391
394392
Decorator Examples:
395393
@trace("my_prefix", track_memory=True, timer="gpu")

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ def gpu_test(gpu_count: int = 1):
1414
required amount of GPU is not available
1515
"""
1616
message = f"Not enough GPUs to run the test: requires {gpu_count}"
17-
local_gpu_count: int = torch.cuda.device_count()
17+
local_gpu_count: int = torch.accelerator.device_count()
1818
return pytest.mark.skipif(local_gpu_count < gpu_count, reason=message)

tests/unit_tests/datasets/test_stop_after_one_epoch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def test_epoch_sync_across_ranks(self):
158158

159159
batch_iter = StopAfterOneEpoch(
160160
iter=iter(dataloader),
161-
device=torch.device("cuda"),
161+
device=torch.accelerator.current_accelerator(),
162162
dp_mesh=dp_mesh,
163163
)
164164

0 commit comments

Comments (0)