Hi, I pulled the latest version and the previous issue was solved. Here is the new issue I encountered:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[1], [line 25](vscode-notebook-cell:?execution_count=1&line=25)
21 action_keys = jax.random.split(action_key, 10)
23 actions = jax.vmap(env.action_space(env_params).sample)(action_keys)
---> [25](vscode-notebook-cell:?execution_count=1&line=25) obs, env_state, reward, done, info = env.step(step_keys, env_state, actions, env_params)
[... skipping hidden 16 frame]
File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/pixel.py:164, in PixelCraftaxVecEnvWrapper.step(self, key, state, action, params)
156 @functools.partial(jax.jit, static_argnums=(0,-1))
157 def step(
158 self,
(...)
162 params: Optional[environment.EnvParams] = None,
163 ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
--> [164](https://file+.vscode-resource.vscode-cdn.net/Users/liuzhishuai/Experiments/porcsl/Notebooks/~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/pixel.py:164) obs_obj, env_state, reward, done, info = self._env.step(
165 key, state, action, params
166 )
167 image_obs = obs_obj.obs
168 image_obs = self.get_obs(image_obs, self.normalize)
[... skipping hidden 7 frame]
File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/observation.py:17, in NamedObservationWrapper.step(self, key, state, action, params)
16 def step(self, key, state, action, params=None):
---> [17](https://file+.vscode-resource.vscode-cdn.net/Users/liuzhishuai/Experiments/porcsl/Notebooks/~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/observation.py:17) obs, env_state, reward, done, info = self._env.step(
18 key, state, action, params
19 )
20 return Observation(obs=obs), env_state, reward, done, info
[... skipping hidden 16 frame]
File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/gymnax.py:154, in LogWrapper.step(self, key, state, action, params)
146 @partial(jax.jit, static_argnums=(0, -1))
147 def step(
148 self,
(...)
152 params: Optional[environment.EnvParams] = None,
153 ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
--> [154](https://file+.vscode-resource.vscode-cdn.net/Users/liuzhishuai/Experiments/porcsl/Notebooks/~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/gymnax.py:154) obs, env_state, reward, done, info = self._env.step(
155 key, state.env_state, action, params
156 )
157 new_episode_return = state.episode_returns + reward
158 new_discounted_episode_return = state.discounted_episode_returns + (self.gamma ** state.episode_lengths) * reward
[... skipping hidden 16 frame]
File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/gymnax.py:293, in AutoResetEnvWrapper.step(self, rng, state, action, params)
289 obs = jax.lax.select(done, obs_re, obs_st)
291 return obs, state
--> [293](https://file+.vscode-resource.vscode-cdn.net/Users/liuzhishuai/Experiments/porcsl/Notebooks/~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/gymnax.py:293) obs, state = auto_reset(done, state_re, state_st, obs_re, obs_st)
295 return obs, state, reward, done, info
File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/gymnax.py:286, in AutoResetEnvWrapper.step.<locals>.auto_reset(done, state_re, state_st, obs_re, obs_st)
285 def auto_reset(done, state_re, state_st, obs_re, obs_st):
--> [286](https://file+.vscode-resource.vscode-cdn.net/Users/liuzhishuai/Experiments/porcsl/Notebooks/~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/gymnax.py:286) state = jax.tree_map(
287 lambda x, y: jax.lax.select(done, x, y), state_re, state_st
288 )
289 obs = jax.lax.select(done, obs_re, obs_st)
291 return obs, state
File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/jax/_src/deprecations.py:54, in deprecation_getattr.<locals>.getattr(name)
52 message, fn = deprecations[name]
53 if fn is None: # Is the deprecation accelerated?
---> [54](https://file+.vscode-resource.vscode-cdn.net/Users/liuzhishuai/Experiments/porcsl/Notebooks/~/miniconda3/envs/porcsl/lib/python3.10/site-packages/jax/_src/deprecations.py:54) raise AttributeError(message)
55 warnings.warn(message, DeprecationWarning, stacklevel=2)
56 return fn
AttributeError: jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
Hi, I pulled the latest version and the previous issue was solved. Here is the new issue I encountered:
I ran the following code:

```python
import jax
from pobax.envs import get_env

rand_key = jax.random.PRNGKey(2025)
env_key, rand_key = jax.random.split(rand_key)
env, env_params = get_env("craftax_pixels", env_key)

reset_key, rand_key = jax.random.split(rand_key)
reset_keys = jax.random.split(rand_key, 10)
obs, env_state = env.reset(reset_keys, env_params)
```

and got the following error:
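I am using jax==0.6.2, as pinned in requirement.txt. According to the error message, `jax.tree_map` was removed in JAX v0.6.0. However, `AutoResetEnvWrapper.step` (the `auto_reset` helper in `pobax/envs/wrappers/gymnax.py`) still calls `jax.tree_map`. It would need to use `jax.tree.map` (or `jax.tree_util.tree_map`) to work with the pinned JAX version.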
pip freeze output for reproducibility: