from argparse import ArgumentParser
from functools import partial

import wandb
import torch
from torch import nn
import gymnasium as gym
import gin

import amago
from amago import TstepEncoder
from amago.envs.builtin.babyai import MultitaskMetaBabyAI, ALL_BABYAI_TASKS
from amago.envs import AMAGOEnv
from amago.nets.utils import add_activation_log, symlog
from amago.cli_utils import *


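# Example-specific command-line flags, layered on top of AMAGO's shared CLI
# (added via add_common_cli in the __main__ block below).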
def add_cli(parser):
    parser.add_argument(
        "--obs_kind",
        choices=["partial-grid", "full-grid", "partial-image", "full-image"],
        default="partial-grid",
    )
    parser.add_argument("--k_episodes", type=int, default=2)
    parser.add_argument("--train_seeds", type=int, default=5_000)
    parser.add_argument("--max_seq_len", type=int, default=512)
    return parser


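# Training uses this fixed subset of BabyAI tasks; evaluation uses the full
# registered task list (ALL_BABYAI_TASKS), which can include tasks never seen
# during training.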
TRAIN_TASKS = [
    "BabyAI-GoToLocalS7N5-v0",
    "BabyAI-GoToObjMaze-v0",
    "BabyAI-KeyCorridor-v0",
    "BabyAI-KeyCorridorS3R3-v0",
    "BabyAI-GoToRedBall-v0",
    "BabyAI-KeyCorridorS3R2-v0",
    "BabyAI-KeyCorridorS3R1-v0",
    "BabyAI-Unlock-v0",
    "BabyAI-GoToLocalS8N4-v0",
    "BabyAI-GoToObjMazeOpen-v0",
    "BabyAI-KeyCorridorS4R3-v0",
    "BabyAI-UnlockLocal-v0",
    "BabyAI-GoToObjMazeS5-v0",
    "BabyAI-GoToObjMazeS4R2-v0",
    "BabyAI-GoToLocal-v0",
    "BabyAI-PickupLoc-v0",
    "BabyAI-UnlockPickup-v0",
    "BabyAI-GoTo-v0",
    "BabyAI-FindObjS6-v0",
    "BabyAI-BlockedUnlockPickup-v0",
    "BabyAI-KeyCorridorS5R3-v0",
    "BabyAI-GoToObjS6-v0",
    "BabyAI-KeyInBox-v0",
    "BabyAI-Open-v0",
    "BabyAI-GoToOpen-v0",
    "BabyAI-GoToDoor-v0",
    "BabyAI-FindObjS7-v0",
    "BabyAI-OpenRedDoor-v0",
    "BabyAI-PickupDist-v0",
    "BabyAI-GoToImpUnlock-v0",
    "BabyAI-UnblockPickup-v0",
    "BabyAI-OpenDoor-v0",
    "BabyAI-GoToObjMazeS4-v0",
    "BabyAI-OneRoomS12-v0",
    "BabyAI-GoToObjMazeS6-v0",
    "BabyAI-GoToRedBallNoDists-v0",
    "BabyAI-OpenDoorDebug-v0",
    "BabyAI-GoToLocalS8N5-v0",
    "BabyAI-OneRoomS20-v0",
    "BabyAI-Pickup-v0",
    "BabyAI-GoToRedBlueBall-v0",
    "BabyAI-OpenDoorColor-v0",
    "BabyAI-PickupAbove-v0",
    "BabyAI-GoToObjDoor-v0",
    "BabyAI-OpenRedBlueDoors-v0",
    "BabyAI-UnlockToUnlock-v0",
    "BabyAI-OneRoomS16-v0",
    "BabyAI-GoToLocalS8N6-v0",
    "BabyAI-OneRoomS8-v0",
    "BabyAI-PickupDistDebug-v0",
]
TEST_TASKS = ALL_BABYAI_TASKS


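# Thin AMAGOEnv wrapper around the multi-task BabyAI environment. `env_name`
# reports the currently active BabyAI task so rollout metrics can be grouped
# per task.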
class BabyAIAMAGOEnv(AMAGOEnv):
    def __init__(self, env: gym.Env):
        assert isinstance(env, MultitaskMetaBabyAI)
        super().__init__(
            env=env,
        )

    @property
    def env_name(self):
        return self.env.current_task


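# Custom per-timestep encoder: encodes the image observation, the tokenized
# mission string, and the low-dimensional "extra" features (plus AMAGO's rl2
# inputs) separately, then fuses them into a single embedding for the
# sequence model.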
@gin.configurable
class BabyTstepEncoder(TstepEncoder):
    def __init__(
        self,
        obs_space,
        rl2_space,
        obs_kind: str = "partial-grid",
        extras_dim: int = 16,
        mission_dim: int = 48,
        emb_dim: int = 300,
    ):
        super().__init__(obs_space=obs_space, rl2_space=rl2_space)
        self.obs_kind = obs_kind
        # pixel observations go through a Nature-style CNN, while symbolic grid
        # observations use the gridworld CNN
        if obs_kind in ["partial-image", "full-image"]:
            cnn_type = amago.nets.cnn.NatureishCNN
        else:
            cnn_type = amago.nets.cnn.GridworldCNN
        self.img_processor = cnn_type(
            img_shape=obs_space["image"].shape,
            channels_first=False,
            activation="leaky_relu",
        )
        img_out_dim = self.img_processor(self.img_processor.blank_img).shape[-1]

        # BabyAI missions are short token sequences; embed them with the token
        # goal embedder
        low_token = obs_space["mission"].low.min()
        high_token = obs_space["mission"].high.max()
        self.mission_processor = amago.nets.goal_embedders.TokenGoalEmb(
            goal_length=9,
            goal_dim=1,
            min_token=low_token,
            max_token=high_token,
            goal_emb_dim=mission_dim,
            embedding_dim=18,
            hidden_size=96,
        )
        # small MLP over the low-dimensional "extra" features concatenated with
        # AMAGO's rl2 inputs
        self.extras_processor = nn.Sequential(
            nn.Linear(obs_space["extra"].shape[-1] + rl2_space.shape[-1], 32),
            nn.LeakyReLU(),
            nn.Linear(32, extras_dim),
            nn.LeakyReLU(),
        )
        # project the concatenated features to the final timestep embedding
        self.out = nn.Linear(img_out_dim + mission_dim + extras_dim, emb_dim)
        self.out_norm = amago.nets.ff.Normalization("layer", emb_dim)
        self._emb_dim = emb_dim

    @property
    def emb_dim(self):
        return self._emb_dim

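    # inner_forward receives batched sequences of observations and rl2 inputs;
    # each modality is encoded separately and the concatenation is projected
    # to emb_dim.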
    def inner_forward(self, obs, rl2s, log_dict=None):
        rl2s = symlog(rl2s)
        extras = torch.cat((rl2s, obs["extra"]), dim=-1)
        extras_rep = self.extras_processor(extras)
        add_activation_log("encoder-extras-rep", extras_rep, log_dict)
        mission_rep = self.mission_processor(obs["mission"].unsqueeze(-1))
        add_activation_log("encoder-mission-rep", mission_rep, log_dict)
        img_rep = self.img_processor(obs["image"])
        add_activation_log("encoder-img-rep", img_rep, log_dict)
        merged_rep = torch.cat((img_rep, mission_rep, extras_rep), dim=-1)
        out = self.out_norm(self.out(merged_rep))
        return out


if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

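    # gin overrides applied before the experiment is built: two-hot critic
    # return settings and the observation format for the custom TstepEncoder
    # defined above.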
    config = {
        "amago.nets.actor_critic.NCriticsTwoHot.min_return": None,
        "amago.nets.actor_critic.NCriticsTwoHot.max_return": None,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 32,
        "BabyTstepEncoder.obs_kind": args.obs_kind,
    }
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    exploration_type = switch_exploration(config, "egreedy", steps_anneal=500_000)
    agent_type = switch_agent(config, args.agent_type, reward_multiplier=1000.0)
    use_config(config, args.configs)

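    # Environment constructors. Training samples tasks from TRAIN_TASKS with
    # seeds below --train_seeds; validation samples from the full BabyAI task
    # list with a disjoint (higher) seed range, so evaluation sees unseen seeds
    # and, potentially, unseen tasks. Each rollout strings together k_episodes
    # episodes of the sampled task (the meta-RL setup).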
    make_train_env = lambda: BabyAIAMAGOEnv(
        MultitaskMetaBabyAI(
            task_names=TRAIN_TASKS,
            seed_range=(0, args.train_seeds),
            k_episodes=args.k_episodes,
            observation_type=args.obs_kind,
        )
    )

    make_val_env = lambda: BabyAIAMAGOEnv(
        MultitaskMetaBabyAI(
            task_names=TEST_TASKS,
            seed_range=(args.train_seeds + 1, 1_000_000),
            k_episodes=args.k_episodes,
            observation_type=args.obs_kind,
        )
    )

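    # Run --trials independent training runs, grouped under a shared name for
    # logging.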
    group_name = f"{args.run_name}_babyai_{args.obs_kind}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_val_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=args.max_seq_len * 3,
            stagger_traj_file_lengths=True,
            run_name=run_name,
            tstep_encoder_type=BabyTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=6000,
            save_trajs_as="npz",
        )
        switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_val_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()
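
# Example invocation (the script filename and any flags not defined in add_cli
# come from AMAGO's shared CLI and are shown here only for illustration):
#   python babyai_multitask.py --run_name demo --trials 1 \
#       --obs_kind partial-grid --k_episodes 2 --max_seq_len 512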