Multi-Task BabyAI

  1from argparse import ArgumentParser
  2from functools import partial
  3
  4import wandb
  5import torch
  6from torch import nn
  7import gymnasium as gym
  8import gin
  9
 10import amago
 11from amago import TstepEncoder
 12from amago.envs.builtin.babyai import MultitaskMetaBabyAI, ALL_BABYAI_TASKS
 13from amago.envs import AMAGOEnv
 14from amago.nets.utils import add_activation_log, symlog
 15from amago.cli_utils import *
 16
 17
 18def add_cli(parser):
 19    parser.add_argument(
 20        "--obs_kind",
 21        choices=["partial-grid", "full-grid", "partial-image", "full-image"],
 22        default="partial-grid",
 23    )
 24    parser.add_argument("--k_episodes", type=int, default=2)
 25    parser.add_argument("--train_seeds", type=int, default=5_000)
 26    parser.add_argument("--max_seq_len", type=int, default=512)
 27    return parser
 28
 29
# Held-in BabyAI task IDs used to generate training data.
# NOTE(review): the held-out evaluation set below is the full BabyAI suite,
# so it presumably contains these tasks plus tasks never seen in training —
# confirm against ALL_BABYAI_TASKS in amago.envs.builtin.babyai.
TRAIN_TASKS = [
    "BabyAI-GoToLocalS7N5-v0",
    "BabyAI-GoToObjMaze-v0",
    "BabyAI-KeyCorridor-v0",
    "BabyAI-KeyCorridorS3R3-v0",
    "BabyAI-GoToRedBall-v0",
    "BabyAI-KeyCorridorS3R2-v0",
    "BabyAI-KeyCorridorS3R1-v0",
    "BabyAI-Unlock-v0",
    "BabyAI-GoToLocalS8N4-v0",
    "BabyAI-GoToObjMazeOpen-v0",
    "BabyAI-KeyCorridorS4R3-v0",
    "BabyAI-UnlockLocal-v0",
    "BabyAI-GoToObjMazeS5-v0",
    "BabyAI-GoToObjMazeS4R2-v0",
    "BabyAI-GoToLocal-v0",
    "BabyAI-PickupLoc-v0",
    "BabyAI-UnlockPickup-v0",
    "BabyAI-GoTo-v0",
    "BabyAI-FindObjS6-v0",
    "BabyAI-BlockedUnlockPickup-v0",
    "BabyAI-KeyCorridorS5R3-v0",
    "BabyAI-GoToObjS6-v0",
    "BabyAI-KeyInBox-v0",
    "BabyAI-Open-v0",
    "BabyAI-GoToOpen-v0",
    "BabyAI-GoToDoor-v0",
    "BabyAI-FindObjS7-v0",
    "BabyAI-OpenRedDoor-v0",
    "BabyAI-PickupDist-v0",
    "BabyAI-GoToImpUnlock-v0",
    "BabyAI-UnblockPickup-v0",
    "BabyAI-OpenDoor-v0",
    "BabyAI-GoToObjMazeS4-v0",
    "BabyAI-OneRoomS12-v0",
    "BabyAI-GoToObjMazeS6-v0",
    "BabyAI-GoToRedBallNoDists-v0",
    "BabyAI-OpenDoorDebug-v0",
    "BabyAI-GoToLocalS8N5-v0",
    "BabyAI-OneRoomS20-v0",
    "BabyAI-Pickup-v0",
    "BabyAI-GoToRedBlueBall-v0",
    "BabyAI-OpenDoorColor-v0",
    "BabyAI-PickupAbove-v0",
    "BabyAI-GoToObjDoor-v0",
    "BabyAI-OpenRedBlueDoors-v0",
    "BabyAI-UnlockToUnlock-v0",
    "BabyAI-OneRoomS16-v0",
    "BabyAI-GoToLocalS8N6-v0",
    "BabyAI-OneRoomS8-v0",
    "BabyAI-PickupDistDebug-v0",
]
# Evaluate on every registered BabyAI task (training tasks included).
TEST_TASKS = ALL_BABYAI_TASKS
 83
 84
 85class BabyAIAMAGOEnv(AMAGOEnv):
 86    def __init__(self, env: gym.Env):
 87        assert isinstance(env, MultitaskMetaBabyAI)
 88        super().__init__(
 89            env=env,
 90        )
 91
 92    @property
 93    def env_name(self):
 94        return self.env.current_task
 95
 96
 97@gin.configurable
 98class BabyTstepEncoder(TstepEncoder):
 99    def __init__(
100        self,
101        obs_space,
102        rl2_space,
103        obs_kind: str = "partial-grid",
104        extras_dim: int = 16,
105        mission_dim: int = 48,
106        emb_dim: int = 300,
107    ):
108        super().__init__(obs_space=obs_space, rl2_space=rl2_space)
109        self.obs_kind = obs_kind
110        if obs_kind in ["partial-image", "full-image"]:
111            cnn_type = amago.nets.cnn.NatureishCNN
112        else:
113            cnn_type = amago.nets.cnn.GridworldCNN
114        self.img_processor = cnn_type(
115            img_shape=obs_space["image"].shape,
116            channels_first=False,
117            activation="leaky_relu",
118        )
119        img_out_dim = self.img_processor(self.img_processor.blank_img).shape[-1]
120
121        low_token = obs_space["mission"].low.min()
122        high_token = obs_space["mission"].high.max()
123        self.mission_processor = amago.nets.goal_embedders.TokenGoalEmb(
124            goal_length=9,
125            goal_dim=1,
126            min_token=low_token,
127            max_token=high_token,
128            goal_emb_dim=mission_dim,
129            embedding_dim=18,
130            hidden_size=96,
131        )
132        self.extras_processor = nn.Sequential(
133            nn.Linear(obs_space["extra"].shape[-1] + rl2_space.shape[-1], 32),
134            nn.LeakyReLU(),
135            nn.Linear(32, extras_dim),
136            nn.LeakyReLU(),
137        )
138        self.out = nn.Linear(img_out_dim + mission_dim + extras_dim, emb_dim)
139        self.out_norm = amago.nets.ff.Normalization("layer", emb_dim)
140        self._emb_dim = emb_dim
141
142    @property
143    def emb_dim(self):
144        return self._emb_dim
145
146    def inner_forward(self, obs, rl2s, log_dict=None):
147        rl2s = symlog(rl2s)
148        extras = torch.cat((rl2s, obs["extra"]), dim=-1)
149        extras_rep = self.extras_processor(extras)
150        add_activation_log("encoder-extras-rep", extras_rep, log_dict)
151        mission_rep = self.mission_processor(obs["mission"].unsqueeze(-1))
152        add_activation_log("encoder-mission-rep", extras_rep, log_dict)
153        img_rep = self.img_processor(obs["image"])
154        add_activation_log("encoder-img-rep", extras_rep, log_dict)
155        merged_rep = torch.cat((img_rep, mission_rep, extras_rep), dim=-1)
156        out = self.out_norm(self.out(merged_rep))
157        return out
158
159
if __name__ == "__main__":
    # Shared AMAGO flags plus the BabyAI-specific ones defined above.
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # Gin overrides layered on top of the CLI configuration.
    config = {
        "amago.nets.actor_critic.NCriticsTwoHot.min_return": None,
        "amago.nets.actor_critic.NCriticsTwoHot.max_return": None,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 32,
        "BabyTstepEncoder.obs_kind": args.obs_kind,
    }
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    exploration_type = switch_exploration(config, "egreedy", steps_anneal=500_000)
    agent_type = switch_agent(config, args.agent_type, reward_multiplier=1000.0)
    use_config(config, args.configs)

    def make_train_env():
        # Training envs sample held-in tasks from the low seed range.
        return BabyAIAMAGOEnv(
            MultitaskMetaBabyAI(
                task_names=TRAIN_TASKS,
                seed_range=(0, args.train_seeds),
                k_episodes=args.k_episodes,
                observation_type=args.obs_kind,
            )
        )

    def make_val_env():
        # Validation envs sample from the full task suite with seeds
        # disjoint from the training range.
        return BabyAIAMAGOEnv(
            MultitaskMetaBabyAI(
                task_names=TEST_TASKS,
                seed_range=(args.train_seeds + 1, 1_000_000),
                k_episodes=args.k_episodes,
                observation_type=args.obs_kind,
            )
        )

    group_name = f"{args.run_name}_babyai_{args.obs_kind}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_val_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=args.max_seq_len * 3,
            stagger_traj_file_lengths=True,
            run_name=run_name,
            tstep_encoder_type=BabyTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=6000,
            save_trajs_as="npz",
        )
        switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_val_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()