1"""
2Support for gymnax is experimental and mainly meant to test the already_vectorized
3env API used by XLand MiniGrid (an unsolved environment) with classic gym envs.
4Many of the gymnax envs appear to be broken by recent versions of jax.
There are a couple of memory/meta-RL bsuite envs where AMAGO+Transformer
is significantly better than the gymnax reference scores, though.
7"""
8
9import os
10
11# stop jax from stealing pytorch's memory, since we're only using it for the envs
12os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
13
14from argparse import ArgumentParser
15import math
16from functools import partial
17
18import gymnax
19import torch
20import jax
21import wandb
22import numpy as np
23
24from amago.envs import AMAGOEnv
25from amago.envs.builtin.gymnax_envs import GymnaxCompatibility
26from amago.nets.cnn import GridworldCNN
27from amago.cli_utils import *
28
29
def add_cli(parser):
    """Register the gymnax-specific command line flags on *parser*.

    Adds ``--env`` (required gymnax env id), ``--max_seq_len`` (policy
    context length, default 128), and ``--eval_timesteps`` (default 1000).
    Returns the same parser for chaining.
    """
    flag_specs = [
        ("--env", dict(type=str, required=True)),
        ("--max_seq_len", dict(type=int, default=128)),
        ("--eval_timesteps", dict(type=int, default=1000)),
    ]
    for flag, kwargs in flag_specs:
        parser.add_argument(flag, **kwargs)
    return parser
35
36
def make_gymnax_amago(env_name, parallel_envs):
    """Build an AMAGOEnv around a jax-vectorized gymnax environment.

    Creates the gymnax env, wraps it in GymnaxCompatibility to expose a
    gym-style batched API, then hands it to AMAGOEnv.
    """
    base_env, base_params = gymnax.make(env_name)
    wrapped = GymnaxCompatibility(base_env, num_envs=parallel_envs, params=base_params)
    # `batched_envs` alerts the AMAGOEnv wrapper that the env is already vectorized
    amago_name = f"gymnax_{env_name}"
    return AMAGOEnv(env=wrapped, env_name=amago_name, batched_envs=parallel_envs)
44
45
def guess_tstep_encoder(config, obs_shape):
    """
    Pick a reasonable timestep encoder for an arbitrary gymnax observation.

    Image-like (3D) observations get a small gridworld CNN; anything else is
    flattened and fed to an MLP. The gymnax env collection is varied, so this
    is a heuristic — per-env tuning could do better if gymnax performance
    actually mattered.
    """
    is_image = len(obs_shape) == 3
    if not is_image:
        print(f"Guessing MLP for observation of shape {obs_shape}")
        # FFTstepEncoder will flatten the obs on input
        flat_dim = math.prod(obs_shape)
        return switch_tstep_encoder(
            config,
            "ff",
            d_hidden=max(flat_dim // 3, 128),
            n_layers=2,
            d_output=max(flat_dim // 4, 96),
        )
    print(f"Guessing CNN for observation of shape {obs_shape}")
    # assume the smallest axis is the channel axis; tell the CNN whether it
    # comes first or last
    channels_first = int(np.argmin(obs_shape)) == 0
    return switch_tstep_encoder(
        config,
        "cnn",
        cnn_type=GridworldCNN,
        channels_first=channels_first,
    )
71
72
if __name__ == "__main__":
    # Parse the shared AMAGO CLI plus the gymnax-specific flags added above.
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # "already_vectorized" will stop the training loop from trying to spawn
    # multiple instances of the env — gymnax batches envs on-device itself.
    args.env_mode = "already_vectorized"

    # Accumulate gin-style config overrides in `config`, then apply them once
    # with use_config below.
    config = {}
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    # Build a throwaway copy of the env on the CPU just to read its
    # observation shape; keeps jax from touching the GPU for this check.
    with jax.default_device(jax.devices("cpu")[0]):
        test_env, env_params = gymnax.make(args.env)
        test_obs_shape = test_env.observation_space(env_params).shape
        tstep_encoder_type = guess_tstep_encoder(config, test_obs_shape)
    agent_type = switch_agent(config, args.agent_type)

    use_config(config, args.configs)
    make_env = partial(
        make_gymnax_amago, env_name=args.env, parallel_envs=args.parallel_actors
    )
    group_name = f"{args.run_name}_{args.env}"
    # One independent training run per trial, all logged under the same group.
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_env,
            make_val_env=make_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=args.max_seq_len * 20,
            run_name=run_name,
            agent_type=agent_type,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            group_name=group_name,
            val_timesteps_per_epoch=args.eval_timesteps,
            grad_clip=2.0,
            l2_coeff=1e-4,
            save_trajs_as="npz-compressed",
        )
        experiment = switch_async_mode(experiment, args.mode)
        # Pin the jax envs to the same GPU that torch is training on.
        # NOTE(review): DEVICE.index of 0 is falsy, so this falls through to
        # torch.cuda.current_device() — confirm that's the intended behavior
        # when explicitly running on cuda:0.
        amago_device = experiment.DEVICE.index or torch.cuda.current_device()
        env_device = jax.devices("gpu")[amago_device]
        with jax.default_device(env_device):
            experiment.start()
            if args.ckpt is not None:
                experiment.load_checkpoint(args.ckpt)
            experiment.learn()
            experiment.evaluate_test(make_env, timesteps=10_000, render=False)
            experiment.delete_buffer_from_disk()
            wandb.finish()