
Another issue from Craftax #35

@zhishuailiu


Hi, I pulled the latest version and the previous issue is resolved. Here is a new issue I encountered:

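I ran the following code. The same notebook cell then samples random actions and calls env.step, which is where the error is raised (see the cell lines shown at the top of the traceback):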

import jax
from pobax.envs import get_env

rand_key = jax.random.PRNGKey(2025)
env_key, rand_key = jax.random.split(rand_key)

env, env_params = get_env("craftax_pixels", env_key)

reset_key, rand_key = jax.random.split(rand_key)
reset_keys = jax.random.split(reset_key, 10)

obs, env_state = env.reset(reset_keys, env_params)

and got the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 25
     21 action_keys = jax.random.split(action_key, 10)
     23 actions = jax.vmap(env.action_space(env_params).sample)(action_keys)
---> 25 obs, env_state, reward, done, info = env.step(step_keys, env_state, actions, env_params)

    [... skipping hidden 16 frame]

File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/pixel.py:164, in PixelCraftaxVecEnvWrapper.step(self, key, state, action, params)
    156 @functools.partial(jax.jit, static_argnums=(0,-1))
    157 def step(
    158         self,
   (...)
    162         params: Optional[environment.EnvParams] = None,
    163 ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
--> 164     obs_obj, env_state, reward, done, info = self._env.step(
    165         key, state, action, params
    166     )
    167     image_obs = obs_obj.obs
    168     image_obs = self.get_obs(image_obs, self.normalize)

    [... skipping hidden 7 frame]

File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/observation.py:17, in NamedObservationWrapper.step(self, key, state, action, params)
     16 def step(self, key, state, action, params=None):
---> 17     obs, env_state, reward, done, info = self._env.step(
     18         key, state, action, params
     19     )
     20     return Observation(obs=obs), env_state, reward, done, info

    [... skipping hidden 16 frame]

File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/gymnax.py:154, in LogWrapper.step(self, key, state, action, params)
    146 @partial(jax.jit, static_argnums=(0, -1))
    147 def step(
    148         self,
   (...)
    152         params: Optional[environment.EnvParams] = None,
    153 ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
--> 154     obs, env_state, reward, done, info = self._env.step(
    155         key, state.env_state, action, params
    156     )
    157     new_episode_return = state.episode_returns + reward
    158     new_discounted_episode_return = state.discounted_episode_returns + (self.gamma ** state.episode_lengths) * reward

    [... skipping hidden 16 frame]

File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/gymnax.py:293, in AutoResetEnvWrapper.step(self, rng, state, action, params)
    289     obs = jax.lax.select(done, obs_re, obs_st)
    291     return obs, state
--> 293 obs, state = auto_reset(done, state_re, state_st, obs_re, obs_st)
    295 return obs, state, reward, done, info

File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/pobax/envs/wrappers/gymnax.py:286, in AutoResetEnvWrapper.step.<locals>.auto_reset(done, state_re, state_st, obs_re, obs_st)
    285 def auto_reset(done, state_re, state_st, obs_re, obs_st):
--> 286     state = jax.tree_map(
    287         lambda x, y: jax.lax.select(done, x, y), state_re, state_st
    288     )
    289     obs = jax.lax.select(done, obs_re, obs_st)
    291     return obs, state

File ~/miniconda3/envs/porcsl/lib/python3.10/site-packages/jax/_src/deprecations.py:54, in deprecation_getattr.<locals>.getattr(name)
     52 message, fn = deprecations[name]
     53 if fn is None:  # Is the deprecation accelerated?
---> 54   raise AttributeError(message)
     55 warnings.warn(message, DeprecationWarning, stacklevel=2)
     56 return fn

AttributeError: jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
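Following the error message's own suggestion, here is a minimal sketch of the `auto_reset` helper from `pobax/envs/wrappers/gymnax.py` with `jax.tree_map` swapped for `jax.tree_util.tree_map`, which works in every JAX version. It is reconstructed only from the lines visible in the traceback (the enclosing `AutoResetEnvWrapper.step` is not shown), and the dict "state" in the usage lines is a hypothetical stand-in for the real Craftax environment state, not pobax's actual state type.

```python
import jax
import jax.numpy as jnp


def auto_reset(done, state_re, state_st, obs_re, obs_st):
    """Select the reset state/obs where `done`, else the stepped ones.

    Same logic as the traceback's helper, but uses jax.tree_util.tree_map
    instead of jax.tree_map (removed in JAX v0.6.0).
    """
    # Apply lax.select leaf-by-leaf across the two state pytrees.
    state = jax.tree_util.tree_map(
        lambda x, y: jax.lax.select(done, x, y), state_re, state_st
    )
    obs = jax.lax.select(done, obs_re, obs_st)
    return obs, state


# Hypothetical toy pytrees standing in for the real environment state.
state_re = {"pos": jnp.zeros(2), "t": jnp.array(0)}
state_st = {"pos": jnp.array([1.0, 2.0]), "t": jnp.array(7)}
obs, state = auto_reset(jnp.array(True), state_re, state_st, jnp.zeros(3), jnp.ones(3))
```

With `done=True` the reset branch is chosen, so `state["t"]` is 0 and `obs` is all zeros; with `done=False` the stepped values are kept.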

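I am using jax==0.6.2, as specified in the requirements file. Since jax.tree_map was removed in JAX v0.6.0, the pinned JAX version is incompatible with the jax.tree_map call in pobax/envs/wrappers/gymnax.py (AutoResetEnvWrapper.step).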
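Until the wrapper is updated, a possible local workaround is to restore `jax.tree_map` as an alias of `jax.tree_util.tree_map` before importing pobax. This is a hedged sketch, not an official fix: it assumes pobax only needs `jax.tree_map` to behave like `jax.tree_util.tree_map`, and it patches the `jax` module globally for the whole process.

```python
import jax

# Temporary shim (assumption: pobax only needs jax.tree_map to behave like
# jax.tree_util.tree_map). On JAX >= 0.6.0, accessing jax.tree_map raises
# AttributeError, so we re-add it; on older versions the alias already exists.
if not hasattr(jax, "tree_map"):
    jax.tree_map = jax.tree_util.tree_map

# Import pobax only after this point, e.g. `from pobax.envs import get_env`.
```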

pip freeze output for reproducibility:

absl-py==2.3.1
aiofiles==25.1.0
annotated-types==0.7.0
anyio==4.12.0
astroid==4.0.2
attrs==25.4.0
black==25.11.0
blinker==1.9.0
brax==0.13.0
certifi==2025.11.12
charset-normalizer==3.4.4
chex==0.1.90
click==8.3.1
cloudpickle==3.1.2
contourpy==1.3.2
craftax==1.5.0
cycler==0.12.1
decorator==5.2.1
dill==0.4.0
distrax==0.1.5
dm-env==1.6
dm-tree==0.1.9
docstring_parser==0.17.0
esquilax==2.1.0
etils==1.13.0
exceptiongroup==1.3.1
Farama-Notifications==0.0.4
filelock==3.20.0
flake8==7.3.0
Flask==3.1.2
flask-cors==6.0.1
flax==0.10.7
fonttools==4.61.0
fsspec==2025.10.0
gast==0.7.0
gitdb==4.0.12
GitPython==3.1.45
glfw==2.10.0
gymnasium==1.2.2
gymnax==0.0.9
h11==0.16.0
hf-xet==1.2.0
httpcore==1.0.9
httpx==0.28.1
huggingface_hub==1.1.7
humanize==4.14.0
idna==3.11
ImageIO==2.37.2
importlib_resources==6.5.2
iniconfig==2.3.0
ipympl==0.9.8
ipywidgets==8.1.8
isort==7.0.0
itsdangerous==2.2.0
jax==0.6.2
jax-tqdm==0.4.0
jaxlib==0.6.2
jaxopt==0.8.5
Jinja2==3.1.6
jumanji==1.1.1
jupyterlab_widgets==3.0.16
kiwisolver==1.4.9
markdown-it-py==4.0.0
MarkupSafe==3.0.3
matplotlib==3.10.7
mccabe==0.7.0
mdurl==0.1.2
ml_collections==1.1.0
ml_dtypes==0.5.4
mpmath==1.3.0
msgpack==1.1.2
mujoco==3.3.7
mujoco-mjx==3.3.7
mypy_extensions==1.1.0
Navix==0.7.4
nest-asyncio==1.6.0
networkx==3.4.2
numpy==2.2.6
opt_einsum==3.4.0
optax==0.2.6
orbax-checkpoint==0.11.30
packaging==25.0
pandas==2.3.3
pillow==12.0.0
platformdirs==4.5.0
pluggy==1.6.0
pobax @ git+https://github.com/taodav/pobax.git@0180159c780aa391351c864e03b882f36fbcf9b4
protobuf==6.33.1
psutil==7.1.3
pycodestyle==2.14.0
pydantic==2.12.5
pydantic_core==2.41.5
pyflakes==3.4.0
pygame==2.6.1
Pygments==2.19.2
pylint==4.0.4
PyOpenGL==3.1.10
pyparsing==3.2.5
pytest==9.0.1
python-dateutil==2.9.0.post0
pytokens==0.3.0
pytz==2025.2
PyYAML==6.0.3
requests==2.32.5
rich==14.2.0
rlax==0.1.7
scipy==1.15.3
seaborn==0.13.2
sentry-sdk==2.46.0
setuptools-scm==9.2.2
shellingham==1.5.4
shtab==1.8.0
simplejson==3.20.2
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
sympy==1.14.0
tensorboardX==2.6.4
tensorflow-probability==0.25.0
tensorstore==0.1.78
tomli==2.3.0
tomlkit==0.13.3
toolz==1.1.0
torch==2.9.1
tqdm==4.67.1
treescope==0.1.10
trimesh==4.10.0
typed-argument-parser==1.11.0
typeguard==4.4.4
typer-slim==0.20.0
typing-inspect==0.9.0
typing-inspection==0.4.2
typing_extensions==4.15.0
tyro==0.9.35
tzdata==2025.2
urllib3==2.5.0
wandb==0.23.0
warp==1.0.4
Werkzeug==3.1.4
widgetsnbextension==4.0.15
wrapt==2.0.1
zipp==3.23.0
