from argparse import ArgumentParser
import os

# Keep JAX from preallocating most of the GPU's memory so the PyTorch agent and
# the JAX-based XLand-MiniGrid environment can share the same device.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import wandb
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
import gin

import amago
from amago.envs import AMAGOEnv
from amago.envs.builtin.xland_minigrid import XLandMinigridVectorizedGym
from amago.nets.utils import add_activation_log, symlog

# provides add_common_cli, switch_traj_encoder, switch_agent, use_config,
# switch_async_mode, and create_experiment_from_cli used below
from amago.cli_utils import *


def add_cli(parser):
    parser.add_argument(
        "--benchmark",
        type=str,
        default="small-1m",
        choices=["trivial-1m", "small-1m", "medium-1m", "high-1m", "high-3m"],
    )
    parser.add_argument("--k_shots", type=int, default=15)
    parser.add_argument("--rooms", type=int, default=1)
    parser.add_argument("--grid_size", type=int, default=9)
    parser.add_argument("--max_seq_len", type=int, default=2048)
    return parser


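# Wrap the jax-vectorized XLand-MiniGrid gym in AMAGO's environment interface.
# `batched_envs` tells AMAGO that this single env object already steps
# `parallel_envs` copies at once, and `env_name` labels the benchmark/room/grid
# combination in logs and eval metrics.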
class XLandMiniGridAMAGO(AMAGOEnv):
    def __init__(self, env: XLandMinigridVectorizedGym):
        assert isinstance(env, XLandMinigridVectorizedGym)
        super().__init__(
            env=env,
            env_name=f"XLandMiniGrid-{env.ruleset_benchmark}-R{env.rooms}-{env.grid_size}x{env.grid_size}",
            batched_envs=env.parallel_envs,
        )


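# Custom timestep encoder: turns each timestep's dict observation (grid tokens,
# goal tokens, direction/done flags) plus the previous-action/reward "rl2"
# features into a single flat embedding. @gin.configurable allows the keyword
# defaults below to be overridden by the gin configs applied via use_config.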
@gin.configurable
class XLandMGTstepEncoder(amago.TstepEncoder):
    def __init__(
        self,
        obs_space,
        rl2_space,
        grid_id_dim: int = 8,
        grid_emb_dim: int = 128,
        goal_id_dim: int = 8,
        goal_emb_dim: int = 32,
        ff_dim: int = 256,
        out_dim: int = 128,
    ):
        super().__init__(obs_space=obs_space, rl2_space=rl2_space)

        # grid world embedding
        num_tokens = lambda space: (space.high.max() - space.low.min() + 1).item()
        grid_tokens = num_tokens(obs_space["grid"])
        self.grid_embedding = nn.Embedding(grid_tokens, embedding_dim=grid_id_dim)
        self.grid_processor = amago.nets.cnn.GridworldCNN(
            img_shape=obs_space["grid"].shape,
            channels_first=False,
            activation="leaky_relu",
            channels=[32, 48, 64],
        )
        grid_out_dim = self.grid_processor(self.grid_processor.blank_img).shape[-1]
        self.grid_rep_ff = nn.Linear(grid_out_dim, grid_emb_dim)

        # goal token embedding
        goal_tokens = num_tokens(obs_space["goal"])
        self.goal_embedding = nn.Embedding(goal_tokens, embedding_dim=goal_id_dim)
        goal_inp_dim = goal_id_dim * obs_space["goal"].shape[0]
        self.goal_rep_ff = nn.Sequential(
            nn.Linear(goal_inp_dim, goal_inp_dim),
            nn.LeakyReLU(),
            nn.Linear(goal_inp_dim, goal_emb_dim),
        )

        # merge grid, goal, and other array features
        self.merge = nn.Sequential(
            nn.Linear(
                grid_emb_dim + goal_emb_dim + 5 + self.rl2_space.shape[-1], ff_dim
            ),
            nn.LeakyReLU(),
            nn.Linear(ff_dim, out_dim),
        )
        self.out_norm = amago.nets.ff.Normalization("layer", out_dim)
        self.out_dim = out_dim

    @property
    def emb_dim(self):
        return self.out_dim

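    # inner_forward receives observation tensors with a [batch, timesteps, ...]
    # leading shape and returns one embedding of size `emb_dim` per timestep,
    # which AMAGO's trajectory (sequence) encoder consumes downstream.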
    def inner_forward(self, obs, rl2s, log_dict=None):
        # embed grid tokens, flatten the token layers into channels, then run the CNN
        grid_rep = self.grid_embedding(obs["grid"].long())
        grid_rep = rearrange(grid_rep, "... h w layers emb -> ... h w (layers emb)")
        grid_rep = self.grid_processor(grid_rep)
        add_activation_log("encoder-grid-rep", grid_rep, log_dict)
        grid_rep = F.leaky_relu(self.grid_rep_ff(grid_rep))
        add_activation_log("encoder-grid-rep-ff", grid_rep, log_dict)

        # embed and flatten the goal token sequence
        goal_rep = self.goal_embedding(obs["goal"].long())
        goal_rep = rearrange(goal_rep, "... length emb -> ... (length emb)")
        goal_rep = F.leaky_relu(self.goal_rep_ff(goal_rep))
        add_activation_log("encoder-goal-rep-ff", goal_rep, log_dict)

        # merge with direction/done flags and symlog-squashed rl2 (prev action/reward) features
        extras = torch.cat((obs["direction_done"], symlog(rl2s)), dim=-1)
        merged_rep = torch.cat((grid_rep, goal_rep, extras), dim=-1)
        merged_rep = self.merge(merged_rep)
        add_activation_log("encoder-merged-rep", merged_rep, log_dict)
        out = self.out_norm(merged_rep)
        return out


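# Example launch (a sketch; the script name is a placeholder and most flags come
# from add_common_cli, so check amago.cli_utils for their exact names):
#   python xland_minigrid.py --benchmark small-1m --k_shots 15 --rooms 1 \
#       --grid_size 9 --max_seq_len 2048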
if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    config = {
        "amago.envs.exploration.EpsilonGreedy.steps_anneal": 1_000_000,
        "amago.nets.actor_critic.NCriticsTwoHot.min_return": -args.k_shots * 10.0 * 10,
        "amago.nets.actor_critic.NCriticsTwoHot.max_return": args.k_shots * 10.0 * 10,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 32,
    }

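    # Epsilon-greedy exploration is annealed over the first 1M environment steps.
    # The two-hot critic's return range scales with the number of episode attempts
    # (k_shots) and appears to leave headroom for the 10x reward multiplier used
    # by switch_agent below.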
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    agent_type = switch_agent(config, args.agent_type, reward_multiplier=10.0)
    use_config(config, args.configs)

    xland_kwargs = {
        "parallel_envs": args.parallel_actors,
        "rooms": args.rooms,
        "grid_size": args.grid_size,
        "k_shots": args.k_shots,
        "ruleset_benchmark": args.benchmark,
    }

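    # Each of the `parallel_actors` environments is simulated by the jax-vectorized
    # XLand-MiniGrid backend; `k_shots` sets how many episode attempts the agent
    # gets at each sampled task within a single meta-RL rollout.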
    # the env steps all of its copies internally, so AMAGO should not add its own
    # vectorization on top
    args.env_mode = "already_vectorized"
    make_train_env = lambda: XLandMiniGridAMAGO(
        XLandMinigridVectorizedGym(**xland_kwargs, train_test_split="train"),
    )
    make_val_env = lambda: XLandMiniGridAMAGO(
        XLandMinigridVectorizedGym(**xland_kwargs, train_test_split="test"),
    )
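    # Build a throwaway env pinned to the CPU just to read its suggested rollout
    # length, so JAX does not touch the GPU before PyTorch is initialized.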
    with jax.default_device(jax.devices("cpu")[0]):
        traj_len = make_train_env().suggested_max_seq_len

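    # Group all trials of this setting under one wandb group, delay learning until
    # the replay buffer holds at least one full-length rollout, and cap the training
    # sequence length at the rollout length.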
    group_name = f"{args.run_name}_xlandmg_{args.benchmark}_R{args.rooms}_{args.grid_size}x{args.grid_size}"
    args.start_learning_at_epoch = traj_len // args.timesteps_per_epoch
    args.max_seq_len = min(args.max_seq_len, traj_len)

    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_val_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=traj_len,
            stagger_traj_file_lengths=False,
            run_name=run_name,
            tstep_encoder_type=XLandMGTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=traj_len,
            save_trajs_as="npz-compressed",
            grad_clip=2.0,
        )
        switch_async_mode(experiment, args.mode)
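        # Pin the JAX environment to the same GPU the PyTorch experiment is using,
        # so both frameworks share one device (memory preallocation was disabled above).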
        amago_device = experiment.DEVICE.index or torch.cuda.current_device()
        env_device = jax.devices("gpu")[amago_device]
        with jax.default_device(env_device):
            experiment.start()
            if args.ckpt is not None:
                experiment.load_checkpoint(args.ckpt)
            experiment.learn()
            experiment.evaluate_test(make_val_env, timesteps=20_000, render=False)
            experiment.delete_buffer_from_disk()
            wandb.finish()