66
77import logging
88import random
9+ from collections import deque
910from dataclasses import dataclass
11+ from operator import itemgetter
1012from typing import Any , Callable
1113
1214from monarch .actor import endpoint
1921logger .setLevel (logging .INFO )
2022
2123
@dataclass
class BufferEntry:
    """A single replay-buffer slot: one episode plus its sampling bookkeeping."""

    # The stored episode ("Episode" is a forward reference resolved elsewhere
    # in the project).
    data: "Episode"
    # How many times this entry has been returned by sample(); eviction
    # policies use it to drop over-sampled entries.
    sample_count: int = 0
29+
30+ def age_evict (
31+ buffer : deque , policy_version : int , max_samples : int = None , max_age : int = None
32+ ) -> list [int ]:
33+ """Buffer eviction policy, remove old or over-sampled entries"""
34+ indices = []
35+ for i , entry in enumerate (buffer ):
36+ if max_age and policy_version - entry .data .policy_version > max_age :
37+ continue
38+ if max_samples and entry .sample_count >= max_samples :
39+ continue
40+ indices .append (i )
41+ return indices
42+
43+
44+ def random_sample (buffer : deque , sample_size : int , policy_version : int ) -> list [int ]:
45+ """Buffer random sampling policy"""
46+ if sample_size > len (buffer ):
47+ return None
48+ return random .sample (range (len (buffer )), k = sample_size )
49+
50+
@dataclass
class ReplayBuffer(ForgeActor):
    """Simple in-memory replay buffer implementation."""

    # Episodes returned per data-parallel rank on each sample() call.
    batch_size: int
    # Data-parallel degree; sample() reshapes output to (dp_size, batch_size, ...).
    dp_size: int = 1
    # Evict entries older than this many policy versions (None = no age limit).
    max_policy_age: int | None = None
    # Capacity of the backing deque; None means unbounded.
    max_buffer_size: int | None = None
    # How many times an entry may be re-sampled before it is evicted
    # (0 = evict after the first sample; None = no resample cap).
    max_resample_count: int | None = 0
    # RNG seed; when None, a random seed is drawn in setup().
    seed: int | None = None
    # Post-processes a sampled batch before it is returned to the caller.
    collate: Callable = lambda batch: batch
    # Pluggable policies; both receive the buffer and the current policy version.
    eviction_policy: Callable = age_evict
    sample_policy: Callable = random_sample
3464
3565 @endpoint
3666 async def setup (self ) -> None :
37- self .buffer : list = []
67+ self .buffer : deque = deque ( maxlen = self . max_buffer_size )
3868 if self .seed is None :
3969 self .seed = random .randint (0 , 2 ** 32 )
4070 random .seed (self .seed )
41- self .sampler = random .sample
4271
    @endpoint
    async def add(self, episode: "Episode") -> None:
        """Wrap `episode` in a BufferEntry and append it to the buffer.

        If the deque was created with a maxlen and is already full, the
        oldest entry is dropped automatically by this append.
        """
        self.buffer.append(BufferEntry(episode))
        record_metric("buffer/add/count_episodes_added", 1, Reduce.SUM)
4776
4877 @endpoint
4978 @trace ("buffer_perf/sample" , track_memory = False )
5079 async def sample (
51- self , curr_policy_version : int , batch_size : int | None = None
80+ self , curr_policy_version : int
5281 ) -> tuple [tuple [Any , ...], ...] | None :
5382 """Sample from the replay buffer.
5483
5584 Args:
5685 curr_policy_version (int): The current policy version.
57- batch_size (int, optional): Number of episodes to sample. If none, defaults to batch size
58- passed in at initialization.
5986
6087 Returns:
6188 A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer.
6289 """
6390 # Record sample request metric
6491 record_metric ("buffer/sample/count_sample_requests" , 1 , Reduce .SUM )
6592
66- bsz = batch_size if batch_size is not None else self .batch_size
67- total_samples = self .dp_size * bsz
93+ total_samples = self .dp_size * self .batch_size
6894
69- # Evict old episodes
95+ # Evict episodes
7096 self ._evict (curr_policy_version )
7197
72- if total_samples > len (self .buffer ):
73- return None
74-
75- # Calculate buffer utilization
76- utilization_pct = (
77- (total_samples / len (self .buffer )) * 100 if len (self .buffer ) > 0 else 0
78- )
79-
80- record_metric (
81- "buffer/sample/avg_buffer_utilization" ,
82- len (self .buffer ),
83- Reduce .MEAN ,
84- )
85-
86- record_metric (
87- "buffer/sample/avg_buffer_utilization_pct" ,
88- utilization_pct ,
89- Reduce .MEAN ,
90- )
98+ # Calculate metrics
99+ if len (self .buffer ) > 0 :
100+ record_metric (
101+ "buffer/sample/avg_data_utilization" ,
102+ total_samples / len (self .buffer ),
103+ Reduce .MEAN ,
104+ )
105+ if self .max_buffer_size :
106+ record_metric (
107+ "buffer/sample/avg_buffer_utilization" ,
108+ len (self .buffer ) / self .max_buffer_size ,
109+ Reduce .MEAN ,
110+ )
91111
92112 # TODO: prefetch samples in advance
93- idx_to_sample = self .sampler (range (len (self .buffer )), k = total_samples )
94- # Pop episodes in descending order to avoid shifting issues
95- popped = [self .buffer .pop (i ) for i in sorted (idx_to_sample , reverse = True )]
96-
97- # Reorder popped episodes to match the original random sample order
98- sorted_idxs = sorted (idx_to_sample , reverse = True )
99- idx_to_popped = dict (zip (sorted_idxs , popped ))
100- sampled_episodes = [idx_to_popped [i ] for i in idx_to_sample ]
113+ sampled_indices = self .sample_policy (
114+ self .buffer , total_samples , curr_policy_version
115+ )
116+ if sampled_indices is None :
117+ return None
118+ sampled_episodes = []
119+ for entry in self ._collect (sampled_indices ):
120+ entry .sample_count += 1
121+ sampled_episodes .append (entry .data )
101122
102123 # Reshape into (dp_size, bsz, ...)
103124 reshaped_episodes = [
104- sampled_episodes [dp_idx * bsz : (dp_idx + 1 ) * bsz ]
125+ sampled_episodes [dp_idx * self . batch_size : (dp_idx + 1 ) * self . batch_size ]
105126 for dp_idx in range (self .dp_size )
106127 ]
107128
@@ -118,46 +139,69 @@ async def evict(self, curr_policy_version: int) -> None:
118139 """
119140 self ._evict (curr_policy_version )
120141
121- def _evict (self , curr_policy_version : int ) -> None :
142+ def _evict (self , curr_policy_version ) :
122143 buffer_len_before_evict = len (self .buffer )
123- self .buffer = [
124- trajectory
125- for trajectory in self .buffer
126- if (curr_policy_version - trajectory .policy_version ) <= self .max_policy_age
127- ]
128- buffer_len_after_evict = len (self .buffer )
144+ indices = self .eviction_policy (
145+ self .buffer ,
146+ curr_policy_version ,
147+ self .max_resample_count + 1 ,
148+ self .max_policy_age ,
149+ )
150+ self .buffer = deque (self ._collect (indices ))
129151
130152 # Record evict metrics
131- policy_staleness = [
132- curr_policy_version - ep .policy_version for ep in self .buffer
153+ policy_age = [
154+ curr_policy_version - ep .data . policy_version for ep in self .buffer
133155 ]
134- if policy_staleness :
156+ if policy_age :
135157 record_metric (
136- "buffer/evict/avg_policy_staleness " ,
137- sum (policy_staleness ) / len (policy_staleness ),
158+ "buffer/evict/avg_policy_age " ,
159+ sum (policy_age ) / len (policy_age ),
138160 Reduce .MEAN ,
139161 )
140162 record_metric (
141- "buffer/evict/max_policy_staleness " ,
142- max (policy_staleness ),
163+ "buffer/evict/max_policy_age " ,
164+ max (policy_age ),
143165 Reduce .MAX ,
144166 )
145167
146- # Record eviction metrics
147- evicted_count = buffer_len_before_evict - buffer_len_after_evict
148- if evicted_count > 0 :
149- record_metric (
150- "buffer/evict/sum_episodes_evicted" , evicted_count , Reduce .SUM
151- )
168+ evicted_count = buffer_len_before_evict - len (self .buffer )
169+ record_metric ("buffer/evict/sum_episodes_evicted" , evicted_count , Reduce .SUM )
152170
153171 logger .debug (
154172 f"maximum policy age: { self .max_policy_age } , current policy version: { curr_policy_version } , "
155- f"{ evicted_count } episodes expired, { buffer_len_after_evict } episodes left"
173+ f"{ evicted_count } episodes expired, { len ( self . buffer ) } episodes left"
156174 )
157175
176+ def _collect (self , indices : list [int ]):
177+ """Efficiently traverse deque and collect elements at each requested index"""
178+ n = len (self .buffer )
179+ if n == 0 or len (indices ) == 0 :
180+ return []
181+
182+ # Normalize indices and store with their original order
183+ indexed = [(pos , idx % n ) for pos , idx in enumerate (indices )]
184+ indexed .sort (key = itemgetter (1 ))
185+
186+ result = [None ] * len (indices )
187+ rotations = 0 # logical current index
188+ total_rotation = 0 # total net rotation applied
189+
190+ for orig_pos , idx in indexed :
191+ move = idx - rotations
192+ self .buffer .rotate (- move )
193+ total_rotation += move
194+ rotations = idx
195+ result [orig_pos ] = self .buffer [0 ]
196+
197+ # Restore original deque orientation
198+ self .buffer .rotate (total_rotation )
199+
200+ return result
201+
    @endpoint
    async def _getitem(self, idx: int):
        """Return the episode stored at buffer position `idx`.

        Unwraps the BufferEntry; does not touch its sample_count.
        """
        return self.buffer[idx].data
161205
162206 @endpoint
163207 async def _numel (self ) -> int :
0 commit comments