@@ -54,18 +54,18 @@ class Tracer:
5454 Tracer with multi-step timing and optional memory tracking at start/stop boundaries.
5555 Steps only affect timing; memory is tracked from start() to stop().
5656
57- Supports non-blocking CUDA timing via CUDA events and background polling threads.
57+ Supports non-blocking accelerator timing via torch events and background polling threads.
5858 Aggregation is handled externally by the metrics system via record_metric.
5959
6060 User must call start() and stop() explicitly.
6161 Supports reuse: after calling stop(), you may call start() again to begin a new timing session.
6262
6363 Local env flag DISABLE_PERF_METRICS can be used to skip all timing operations.
64- Local env flag METRIC_TIMER_USES_GPU can be used to set CUDA timing.
64+ Local env flag METRIC_TIMER_USES_GPU can be used to set accelerator timing.
6565
6666 Args:
6767 prefix (str): Prefix for metric names, e.g. "my_prefix" -> "{my_prefix}/{step_name}/duration_avg_s".
68- track_memory (bool): Whether to track CUDA memory usage. Defaults to False.
68+ track_memory (bool): Whether to track accelerator memory usage. Defaults to False.
6969 timer (str): Timing backend; "cpu" (default) or "gpu".
7070
7171 Example:
@@ -138,8 +138,8 @@ def start(self) -> None:
138138 else :
139139 # Env var not set - use the timer parameter
140140 use_gpu = self .time_with_gpu
141- time_with_gpu_events = use_gpu and torch .cuda .is_available ()
142- self ._timer = _TimerCUDA () if time_with_gpu_events else _TimerCPU ()
141+ time_with_gpu_events = use_gpu and torch .accelerator .is_available ()
142+ self ._timer = _TimerGPU () if time_with_gpu_events else _TimerCPU ()
143143 self ._timer .start ()
144144
145145 self ._active = True
@@ -176,7 +176,7 @@ def stop(self) -> None:
176176 def _start_memory_tracking (self ) -> None :
177177 is_outer_scope = not _is_memory_active ()
178178 should_track = (
179- self .track_memory and is_outer_scope and torch .cuda .is_available ()
179+ self .track_memory and is_outer_scope and torch .accelerator .is_available ()
180180 )
181181
182182 if self .track_memory and not is_outer_scope :
@@ -185,23 +185,23 @@ def _start_memory_tracking(self) -> None:
185185
186186 if should_track :
187187 _set_memory_active (True )
188- torch .cuda .reset_peak_memory_stats ()
189- self ._start_mem = torch .cuda .memory_allocated ()
188+ torch .accelerator .reset_peak_memory_stats ()
189+ self ._start_mem = torch .accelerator .memory_allocated ()
190190 self ._memory_started = True
191191
192192 def _stop_memory_tracking (self ) -> None :
193193 if not self ._memory_started :
194194 return
195195
196- end_mem = torch .cuda .memory_allocated ()
196+ end_mem = torch .accelerator .memory_allocated ()
197197 delta = (end_mem - self ._start_mem ) / 1024 ** 3
198- peak_mem = torch .cuda .max_memory_allocated () / 1024 ** 3
198+ peak_mem = torch .accelerator .max_memory_allocated () / 1024 ** 3
199199 record_metric (
200200 f"{ self .prefix } /memory_delta_end_start_avg_gb" , delta , Reduce .MEAN
201201 )
202202 record_metric (f"{ self .prefix } /memory_peak_max_gb" , peak_mem , Reduce .MAX )
203203 _set_memory_active (False )
204- torch .cuda .reset_peak_memory_stats ()
204+ torch .accelerator .reset_peak_memory_stats ()
205205 self ._memory_started = False
206206
207207 def _record_timing_metrics (
@@ -258,12 +258,12 @@ def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
258258 return self ._durations [:], stop_step_ms
259259
260260
261- class _TimerCUDA (_TimerProtocol ):
262- """CUDA timing backend with non-blocking events and futures.
263- Uses a thread pool to poll CUDA events asynchronously without blocking the main thread.
261+ class _TimerGPU (_TimerProtocol ):
262+ """Accelerator timing backend with non-blocking events and futures.
263+ Uses a thread pool to poll torch events asynchronously without blocking the main thread.
264264
265265 Example:
266- timer = _TimerCUDA ()
266+ timer = _TimerGPU ()
267267 timer.start()
268268 # torch.mm(a, b) # ~100ms GPU
269269 timer.step("matmul")
@@ -272,36 +272,36 @@ class _TimerCUDA(_TimerProtocol):
272272 """
273273
274274 def __init__ (self , max_workers : int = 2 ) -> None :
275- if not torch .cuda .is_available ():
276- raise RuntimeError ("CUDA is not available for timing" )
275+ if not torch .accelerator .is_available ():
276+ raise RuntimeError ("Accelerator is not available for timing" )
277277 self ._executor = ThreadPoolExecutor (max_workers = max_workers )
278278 self ._futures : list [tuple [str , Future [float ], int ]] = (
279279 []
280280 ) # (name, future, submission_index)
281281 self ._durations : list [tuple [str , float ]] = []
282- self ._chain_start : torch .cuda . Event | None = None
282+ self ._chain_start : torch .Event | None = None
283283
284284 def start (self ) -> None :
285285 """Call before any steps. Clear state for reuse; record initial event on current stream."""
286286 self ._futures .clear ()
287287 self ._durations .clear ()
288- stream = torch .cuda .current_stream ()
289- start_event = torch .cuda . Event (enable_timing = True )
288+ stream = torch .accelerator .current_stream ()
289+ start_event = torch .Event (enable_timing = True )
290290 start_event .record (stream )
291291 self ._chain_start = start_event
292292
293293 def step (self , name : str ) -> None :
294294 """Mark the end of a GPU workload segment and start the next, submitting async polling.
295- Records a CUDA end event on the current stream; a background thread polls completion.
295+ Records a torch end event on the current stream; a background thread polls completion.
296296
297297 Args:
298298 name: Label for this segment's duration
299299 """
300300 if self ._chain_start is None :
301301 raise ValueError ("Timer must be started before calling step" )
302302
303- stream = torch .cuda .current_stream ()
304- end_event = torch .cuda . Event (enable_timing = True )
303+ stream = torch .accelerator .current_stream ()
304+ end_event = torch .Event (enable_timing = True )
305305 end_event .record (stream )
306306
307307 future = self ._executor .submit (self ._poll_elapsed , self ._chain_start , end_event )
@@ -312,9 +312,7 @@ def step(self, name: str) -> None:
312312
313313 self ._chain_start = end_event
314314
315- def _poll_elapsed (
316- self , start_event : torch .cuda .Event , end_event : torch .cuda .Event
317- ) -> float :
315+ def _poll_elapsed (self , start_event : torch .Event , end_event : torch .Event ) -> float :
318316 """Compute elapsed time after polling with backoff."""
319317 # Poll until ready
320318 sleep_time = 0.001 # Start at 1ms
@@ -388,8 +386,8 @@ def trace(
388386
389387 Args:
390388 prefix (str): Prefix for metric names
391- track_memory (bool): Whether to track CUDA memory usage. Defaults to False.
392- timer (str): Timing backend; "cpu" (default) or "gpu" (requires CUDA ).
389+ track_memory (bool): Whether to track memory usage. Defaults to False.
390+ timer (str): Timing backend; "cpu" (default) or "gpu" (requires accelerator support ).
393391
394392 Decorator Examples:
395393 @trace("my_prefix", track_memory=True, timer="gpu")
0 commit comments