66
77"""Test for data/replay_buffer.py"""
88
9+ from dataclasses import dataclass
10+
911import pytest
1012import pytest_asyncio
1113from forge .actors .replay_buffer import ReplayBuffer
12- from forge .types import Trajectory
14+
15+
@dataclass
class TestEpisode:
    """
    Dummy Episode containing just a policy version.

    ReplayBuffer expects any construct (typically an Episode) that contains a
    `policy_version`.

    TODO: Replace with a unified interface in the future.
    """

    # The name starts with "Test", so pytest's default class discovery would
    # try to collect this helper and warn about the dataclass-generated
    # ``__init__``. Opt out explicitly. The attribute is deliberately left
    # unannotated so that it does not become a dataclass field.
    __test__ = False

    policy_version: int
1328
1429
1530class TestReplayBuffer :
@@ -23,27 +38,27 @@ async def replay_buffer(self) -> ReplayBuffer:
2338
@pytest.mark.asyncio
async def test_add(self, replay_buffer: ReplayBuffer) -> None:
    """A single added episode is counted and readable back at index 0."""
    stored = TestEpisode(policy_version=0)
    await replay_buffer.add.call_one(stored)

    buffer_size = replay_buffer._numel.call_one().get()
    assert buffer_size == 1

    first_item = replay_buffer._getitem.call_one(0).get()
    assert first_item == stored

    # Call `clear` last so the buffer is emptied before the test returns.
    replay_buffer.clear.call_one().get()
3146
@pytest.mark.asyncio
async def test_add_multiple(self, replay_buffer) -> None:
    """Two added episodes are counted and read back in insertion order."""
    first = TestEpisode(policy_version=0)
    second = TestEpisode(policy_version=1)
    for item in (first, second):
        await replay_buffer.add.call_one(item)

    buffer_size = replay_buffer._numel.call_one().get()
    assert buffer_size == 2

    # Index 0 must hold the first episode added, index 1 the second.
    item_at_0 = replay_buffer._getitem.call_one(0).get()
    assert item_at_0 == first
    item_at_1 = replay_buffer._getitem.call_one(1).get()
    assert item_at_1 == second

    # Call `clear` last so the buffer is emptied before the test returns.
    replay_buffer.clear.call_one().get()
4257
4358 @pytest .mark .asyncio
4459 async def test_state_dict_save_load (self , replay_buffer ) -> None :
45- trajectory = Trajectory (policy_version = 0 )
46- await replay_buffer .add .call_one (trajectory )
60+ episode = TestEpisode (policy_version = 0 )
61+ await replay_buffer .add .call_one (episode )
4762 state_dict = replay_buffer .state_dict .call_one ().get ()
4863 replay_buffer .clear .call_one ().get ()
4964 assert replay_buffer ._numel .call_one ().get () == 0
@@ -53,21 +68,21 @@ async def test_state_dict_save_load(self, replay_buffer) -> None:
5368
@pytest.mark.asyncio
async def test_evict(self, replay_buffer) -> None:
    """Evicting at policy version 2 leaves one of two added episodes."""
    older = TestEpisode(policy_version=0)
    newer = TestEpisode(policy_version=1)
    for item in (older, newer):
        await replay_buffer.add.call_one(item)

    size_before = replay_buffer._numel.call_one().get()
    assert size_before == 2

    await replay_buffer.evict.call_one(curr_policy_version=2)

    # Only the entry count is checked here, not which episode remains.
    size_after = replay_buffer._numel.call_one().get()
    assert size_after == 1

    # Call `clear` last so the buffer is emptied before the test returns.
    replay_buffer.clear.call_one().get()
6479
6580 @pytest .mark .asyncio
6681 async def test_sample (self , replay_buffer ) -> None :
67- trajectory_0 = Trajectory (policy_version = 0 )
68- trajectory_1 = Trajectory (policy_version = 1 )
69- await replay_buffer .add .call_one (trajectory_0 )
70- await replay_buffer .add .call_one (trajectory_1 )
82+ episode_0 = TestEpisode (policy_version = 0 )
83+ episode_1 = TestEpisode (policy_version = 1 )
84+ await replay_buffer .add .call_one (episode_0 )
85+ await replay_buffer .add .call_one (episode_1 )
7186 assert replay_buffer ._numel .call_one ().get () == 2
7287
7388 # Test a simple sampling
@@ -77,19 +92,19 @@ async def test_sample(self, replay_buffer) -> None:
7792 assert replay_buffer ._numel .call_one ().get () == 2
7893
7994 # Test sampling (not enough samples in buffer, returns None)
80- await replay_buffer .add .call_one (trajectory_0 )
95+ await replay_buffer .add .call_one (episode_0 )
8196 samples = await replay_buffer .sample .call_one (curr_policy_version = 1 )
8297 assert samples is None
8398 replay_buffer .clear .call_one ().get ()
8499
85100 @pytest .mark .asyncio
86101 async def test_sample_with_evictions (self , replay_buffer ) -> None :
87- trajectory_0 = Trajectory (policy_version = 0 )
88- trajectory_1 = Trajectory (policy_version = 1 )
89- trajectory_2 = Trajectory (policy_version = 2 )
90- await replay_buffer .add .call_one (trajectory_0 )
91- await replay_buffer .add .call_one (trajectory_1 )
92- await replay_buffer .add .call_one (trajectory_2 )
102+ episode_0 = TestEpisode (policy_version = 0 )
103+ episode_1 = TestEpisode (policy_version = 1 )
104+ episode_2 = TestEpisode (policy_version = 2 )
105+ await replay_buffer .add .call_one (episode_0 )
106+ await replay_buffer .add .call_one (episode_1 )
107+ await replay_buffer .add .call_one (episode_2 )
93108 assert replay_buffer ._numel .call_one ().get () == 3
94109 samples = await replay_buffer .sample .call_one (
95110 curr_policy_version = 2 ,
@@ -112,8 +127,8 @@ async def test_sample_dp_size(self) -> None:
112127
113128 # Add enough trajectories to sample
114129 for i in range (10 ):
115- trajectory = Trajectory (policy_version = 0 )
116- await replay_buffer .add .call_one (trajectory )
130+ episode = TestEpisode (policy_version = 0 )
131+ await replay_buffer .add .call_one (episode )
117132
118133 # Sample and verify len(samples) == dp_size
119134 samples = await replay_buffer .sample .call_one (curr_policy_version = 0 )
0 commit comments