Skip to content

Commit e116e7e

Browse files
authored
Remove unused interfaces and types (meta-pytorch#446)
1 parent a70c3fa commit e116e7e

File tree

3 files changed

+45
-151
lines changed

3 files changed

+45
-151
lines changed

src/forge/interfaces.py

Lines changed: 1 addition & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,7 @@
77
from abc import ABC, abstractmethod
88
from typing import Any, Mapping
99

10-
from monarch.actor import endpoint
11-
12-
from forge.controller import ForgeActor
13-
14-
from forge.types import Action, Message, Observation, Scalar, State
10+
from forge.types import Message, Observation, Scalar
1511

1612

1713
class Transform(ABC):
@@ -37,63 +33,6 @@ def __call__(self, observation: Observation) -> Observation:
3733
pass
3834

3935

40-
class Environment(ABC):
41-
"""Abstract base class for environments.
42-
43-
Args:
44-
transform: Optional transform that modifies observations, typically to add rewards.
45-
Can be a Transform instance or a callable for backward compatibility.
46-
"""
47-
48-
def __init__(
49-
self,
50-
transform: Transform | None = None,
51-
):
52-
self.transform = transform
53-
54-
@abstractmethod
55-
def reset(self) -> Observation:
56-
"""Reset the environment and return an initial observation."""
57-
pass
58-
59-
@abstractmethod
60-
def step(self, action: Any) -> Observation:
61-
"""Take a step in the environment and return an observation."""
62-
pass
63-
64-
@property
65-
@abstractmethod
66-
def state(self) -> State:
67-
"""Get the current state of the environment."""
68-
pass
69-
70-
def _apply_transform(self, observation: Observation) -> Observation:
71-
"""Apply the transform to an observation if one is provided."""
72-
if self.transform is not None:
73-
return self.transform(observation)
74-
return observation
75-
76-
77-
class Policy(ForgeActor, ABC):
78-
"""Abstract interface for policies."""
79-
80-
@endpoint
81-
@abstractmethod
82-
async def generate(self, request: Observation) -> Action:
83-
"""Generate an action given a state/request."""
84-
pass
85-
86-
@endpoint
87-
@abstractmethod
88-
async def update_weights(self, policy_version: int):
89-
"""Update the policy weights.
90-
91-
Args:
92-
policy_version: The version number to update to.
93-
"""
94-
pass
95-
96-
9736
class BaseTokenizer(ABC):
9837
"""
9938
Abstract token encoding model that implements ``encode`` and ``decode`` methods.
@@ -210,10 +149,3 @@ class Reward(ABC):
210149
def __call__(self, observation: Observation) -> float:
211150
"""Compute a reward for an observation."""
212151
pass
213-
214-
215-
# TODO
216-
# class RLLoss(ABC):
217-
218-
# class SFTLoss(ABC): # inherit from titan loss
219-
# from torchtitan.components.loss import LossFunction

src/forge/types.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,6 @@ class Message(TypedDict):
1515
tools: dict[str, Any] | None
1616

1717

18-
@dataclass
19-
class ForgeEnvInfo:
20-
"""Environment info returned with observations."""
21-
22-
episode_id: str | None = None
23-
step_count: int = 0
24-
metadata: dict | None = None
25-
26-
2718
@dataclass(kw_only=True)
2819
class Observation:
2920
"""Base class for environment observations.
@@ -44,50 +35,6 @@ class Observation:
4435
metadata: dict[str, Any] = field(default_factory=dict)
4536

4637

47-
@dataclass(kw_only=True)
48-
class Action:
49-
"""Base class for environment actions.
50-
51-
Contract:
52-
- Should contain all information needed to execute a step in the environment
53-
- Should be serializable/deserializable
54-
- Should be immutable (or treated as such)
55-
56-
Args:
57-
metadata: Additional data that may be useful for logging, debugging, or transforms
58-
"""
59-
60-
metadata: dict[str, Any] = field(default_factory=dict)
61-
62-
63-
@dataclass
64-
class Trajectory:
65-
"""A trajectory containing a sequence of states, actions, etc."""
66-
67-
policy_version: int
68-
states: list[Observation] = field(default_factory=list)
69-
actions: list[Action] = field(default_factory=list)
70-
71-
def __post_init__(self):
72-
assert self.policy_version >= 0
73-
74-
75-
@dataclass(kw_only=True)
76-
class State:
77-
"""Base class for environment state.
78-
79-
Contract:
80-
- Should contain all information needed to restore the environment
81-
- Should be serializable/deserializable
82-
- May contain information not exposed in observations
83-
84-
Args:
85-
metadata: Additional state information that may be useful for debugging or analysis
86-
"""
87-
88-
metadata: dict[str, Any] = field(default_factory=dict)
89-
90-
9138
class Launcher(Enum):
9239
MAST = "mast"
9340
SLURM = "slurm"

tests/unit_tests/test_replay_buffer.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,25 @@
66

77
"""Test for data/replay_buffer.py"""
88

9+
from dataclasses import dataclass
10+
911
import pytest
1012
import pytest_asyncio
1113
from forge.actors.replay_buffer import ReplayBuffer
12-
from forge.types import Trajectory
14+
15+
16+
@dataclass
17+
class TestEpisode:
18+
"""
19+
Dummy Episode containing just a policy version
20+
21+
ReplayBuffer expects any construct (typically an Episode) that contains a
22+
`policy_version`.
23+
24+
TODO: Replace with a unified interface in the future.
25+
"""
26+
27+
policy_version: int
1328

1429

1530
class TestReplayBuffer:
@@ -23,27 +38,27 @@ async def replay_buffer(self) -> ReplayBuffer:
2338

2439
@pytest.mark.asyncio
2540
async def test_add(self, replay_buffer: ReplayBuffer) -> None:
26-
trajectory = Trajectory(policy_version=0)
27-
await replay_buffer.add.call_one(trajectory)
41+
episode = TestEpisode(policy_version=0)
42+
await replay_buffer.add.call_one(episode)
2843
assert replay_buffer._numel.call_one().get() == 1
29-
assert replay_buffer._getitem.call_one(0).get() == trajectory
44+
assert replay_buffer._getitem.call_one(0).get() == episode
3045
replay_buffer.clear.call_one().get()
3146

3247
@pytest.mark.asyncio
3348
async def test_add_multiple(self, replay_buffer) -> None:
34-
trajectory_0 = Trajectory(policy_version=0)
35-
trajectory_1 = Trajectory(policy_version=1)
36-
await replay_buffer.add.call_one(trajectory_0)
37-
await replay_buffer.add.call_one(trajectory_1)
49+
episode_0 = TestEpisode(policy_version=0)
50+
episode_1 = TestEpisode(policy_version=1)
51+
await replay_buffer.add.call_one(episode_0)
52+
await replay_buffer.add.call_one(episode_1)
3853
assert replay_buffer._numel.call_one().get() == 2
39-
assert replay_buffer._getitem.call_one(0).get() == trajectory_0
40-
assert replay_buffer._getitem.call_one(1).get() == trajectory_1
54+
assert replay_buffer._getitem.call_one(0).get() == episode_0
55+
assert replay_buffer._getitem.call_one(1).get() == episode_1
4156
replay_buffer.clear.call_one().get()
4257

4358
@pytest.mark.asyncio
4459
async def test_state_dict_save_load(self, replay_buffer) -> None:
45-
trajectory = Trajectory(policy_version=0)
46-
await replay_buffer.add.call_one(trajectory)
60+
episode = TestEpisode(policy_version=0)
61+
await replay_buffer.add.call_one(episode)
4762
state_dict = replay_buffer.state_dict.call_one().get()
4863
replay_buffer.clear.call_one().get()
4964
assert replay_buffer._numel.call_one().get() == 0
@@ -53,21 +68,21 @@ async def test_state_dict_save_load(self, replay_buffer) -> None:
5368

5469
@pytest.mark.asyncio
5570
async def test_evict(self, replay_buffer) -> None:
56-
trajectory_0 = Trajectory(policy_version=0)
57-
trajectory_1 = Trajectory(policy_version=1)
58-
await replay_buffer.add.call_one(trajectory_0)
59-
await replay_buffer.add.call_one(trajectory_1)
71+
episode_0 = TestEpisode(policy_version=0)
72+
episode_1 = TestEpisode(policy_version=1)
73+
await replay_buffer.add.call_one(episode_0)
74+
await replay_buffer.add.call_one(episode_1)
6075
assert replay_buffer._numel.call_one().get() == 2
6176
await replay_buffer.evict.call_one(curr_policy_version=2)
6277
assert replay_buffer._numel.call_one().get() == 1
6378
replay_buffer.clear.call_one().get()
6479

6580
@pytest.mark.asyncio
6681
async def test_sample(self, replay_buffer) -> None:
67-
trajectory_0 = Trajectory(policy_version=0)
68-
trajectory_1 = Trajectory(policy_version=1)
69-
await replay_buffer.add.call_one(trajectory_0)
70-
await replay_buffer.add.call_one(trajectory_1)
82+
episode_0 = TestEpisode(policy_version=0)
83+
episode_1 = TestEpisode(policy_version=1)
84+
await replay_buffer.add.call_one(episode_0)
85+
await replay_buffer.add.call_one(episode_1)
7186
assert replay_buffer._numel.call_one().get() == 2
7287

7388
# Test a simple sampling
@@ -77,19 +92,19 @@ async def test_sample(self, replay_buffer) -> None:
7792
assert replay_buffer._numel.call_one().get() == 2
7893

7994
# Test sampling (not enough samples in buffer, returns None)
80-
await replay_buffer.add.call_one(trajectory_0)
95+
await replay_buffer.add.call_one(episode_0)
8196
samples = await replay_buffer.sample.call_one(curr_policy_version=1)
8297
assert samples is None
8398
replay_buffer.clear.call_one().get()
8499

85100
@pytest.mark.asyncio
86101
async def test_sample_with_evictions(self, replay_buffer) -> None:
87-
trajectory_0 = Trajectory(policy_version=0)
88-
trajectory_1 = Trajectory(policy_version=1)
89-
trajectory_2 = Trajectory(policy_version=2)
90-
await replay_buffer.add.call_one(trajectory_0)
91-
await replay_buffer.add.call_one(trajectory_1)
92-
await replay_buffer.add.call_one(trajectory_2)
102+
episode_0 = TestEpisode(policy_version=0)
103+
episode_1 = TestEpisode(policy_version=1)
104+
episode_2 = TestEpisode(policy_version=2)
105+
await replay_buffer.add.call_one(episode_0)
106+
await replay_buffer.add.call_one(episode_1)
107+
await replay_buffer.add.call_one(episode_2)
93108
assert replay_buffer._numel.call_one().get() == 3
94109
samples = await replay_buffer.sample.call_one(
95110
curr_policy_version=2,
@@ -112,8 +127,8 @@ async def test_sample_dp_size(self) -> None:
112127

113128
# Add enough trajectories to sample
114129
for i in range(10):
115-
trajectory = Trajectory(policy_version=0)
116-
await replay_buffer.add.call_one(trajectory)
130+
episode = TestEpisode(policy_version=0)
131+
await replay_buffer.add.call_one(episode)
117132

118133
# Sample and verify len(samples) == dp_size
119134
samples = await replay_buffer.sample.call_one(curr_policy_version=0)

0 commit comments

Comments (0)