XLand Mini-Grid
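
This example trains an AMAGO agent on the jax-based XLand-MiniGrid meta-RL benchmark. A custom TstepEncoder embeds the grid and goal-token observations, and the GPU-vectorized environment runs on the same device as the PyTorch agent.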

from argparse import ArgumentParser
import os

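# keep jax/XLA (which runs the XLand-MiniGrid environment) from preallocating
# the whole GPU, so the PyTorch agent can share the device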
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import wandb
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
import gin

import amago
from amago.envs import AMAGOEnv
from amago.envs.builtin.xland_minigrid import XLandMinigridVectorizedGym
from amago.nets.utils import add_activation_log, symlog
from amago.cli_utils import *


def add_cli(parser):
    parser.add_argument(
        "--benchmark",
        type=str,
        default="small-1m",
        choices=["trivial-1m", "small-1m", "medium-1m", "high-1m", "high-3m"],
    )
    parser.add_argument("--k_shots", type=int, default=15)
    parser.add_argument("--rooms", type=int, default=1)
    parser.add_argument("--grid_size", type=int, default=9)
    parser.add_argument("--max_seq_len", type=int, default=2048)
    return parser


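# thin AMAGOEnv wrapper: names the environment for logging and tells AMAGO
# that the jax env already steps `parallel_envs` environments per call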
class XLandMiniGridAMAGO(AMAGOEnv):
    def __init__(self, env: XLandMinigridVectorizedGym):
        assert isinstance(env, XLandMinigridVectorizedGym)
        super().__init__(
            env=env,
            env_name=f"XLandMiniGrid-{env.ruleset_benchmark}-R{env.rooms}-{env.grid_size}x{env.grid_size}",
            batched_envs=env.parallel_envs,
        )


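# custom timestep encoder: fuses the multi-layer grid observation, the goal
# token sequence, and the remaining array features into one embedding per step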
@gin.configurable
class XLandMGTstepEncoder(amago.TstepEncoder):
    def __init__(
        self,
        obs_space,
        rl2_space,
        grid_id_dim: int = 8,
        grid_emb_dim: int = 128,
        goal_id_dim: int = 8,
        goal_emb_dim: int = 32,
        ff_dim: int = 256,
        out_dim: int = 128,
    ):
        super().__init__(obs_space=obs_space, rl2_space=rl2_space)

        # grid world embedding: map each grid token to an id vector, then a small CNN
        num_tokens = lambda space: (space.high.max() - space.low.min() + 1).item()
        grid_tokens = num_tokens(obs_space["grid"])
        self.grid_embedding = nn.Embedding(grid_tokens, embedding_dim=grid_id_dim)
        # the CNN consumes the embedded grid, so its channel dim is layers * grid_id_dim
        grid_h, grid_w, grid_layers = obs_space["grid"].shape
        self.grid_processor = amago.nets.cnn.GridworldCNN(
            img_shape=(grid_h, grid_w, grid_layers * grid_id_dim),
            channels_first=False,
            activation="leaky_relu",
            channels=[32, 48, 64],
        )
        grid_out_dim = self.grid_processor(self.grid_processor.blank_img).shape[-1]
        self.grid_rep_ff = nn.Linear(grid_out_dim, grid_emb_dim)

        # goal token embedding
        goal_tokens = num_tokens(obs_space["goal"])
        self.goal_embedding = nn.Embedding(goal_tokens, embedding_dim=goal_id_dim)
        goal_inp_dim = goal_id_dim * obs_space["goal"].shape[0]
        self.goal_rep_ff = nn.Sequential(
            nn.Linear(goal_inp_dim, goal_inp_dim),
            nn.LeakyReLU(),
            nn.Linear(goal_inp_dim, goal_emb_dim),
        )

        # merge grid, goal, and other array features
        # (the extra 5 inputs are the obs["direction_done"] array)
        self.merge = nn.Sequential(
            nn.Linear(
                grid_emb_dim + goal_emb_dim + 5 + self.rl2_space.shape[-1], ff_dim
            ),
            nn.LeakyReLU(),
            nn.Linear(ff_dim, out_dim),
        )
        self.out_norm = amago.nets.ff.Normalization("layer", out_dim)
        self.out_dim = out_dim

    @property
    def emb_dim(self):
        return self.out_dim

    def inner_forward(self, obs, rl2s, log_dict=None):
        # embed grid tokens, fold the embedding dim into the channel dim, then CNN
        grid_rep = self.grid_embedding(obs["grid"].long())
        grid_rep = rearrange(grid_rep, "... h w layers emb -> ... h w (layers emb)")
        grid_rep = self.grid_processor(grid_rep)
        add_activation_log("encoder-grid-rep", grid_rep, log_dict)
        grid_rep = F.leaky_relu(self.grid_rep_ff(grid_rep))
        add_activation_log("encoder-grid-rep-ff", grid_rep, log_dict)

        goal_rep = self.goal_embedding(obs["goal"].long())
        goal_rep = rearrange(goal_rep, "... length emb -> ... (length emb)")
        goal_rep = F.leaky_relu(self.goal_rep_ff(goal_rep))
        add_activation_log("encoder-goal-rep-ff", goal_rep, log_dict)

        extras = torch.cat((obs["direction_done"], symlog(rl2s)), dim=-1)
        merged_rep = torch.cat((grid_rep, goal_rep, extras), dim=-1)
        merged_rep = self.merge(merged_rep)
        add_activation_log("encoder-merged-rep", merged_rep, log_dict)
        out = self.out_norm(merged_rep)
        return out


if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

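    # return bounds for the two-hot critic: k_shots attempts per rollout,
    # scaled by the 10x reward multiplier applied below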
    config = {
        "amago.envs.exploration.EpsilonGreedy.steps_anneal": 1_000_000,
        "amago.nets.actor_critic.NCriticsTwoHot.min_return": -args.k_shots * 10.0 * 10,
        "amago.nets.actor_critic.NCriticsTwoHot.max_return": args.k_shots * 10.0 * 10,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 32,
    }

    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    agent_type = switch_agent(config, args.agent_type, reward_multiplier=10.0)
    use_config(config, args.configs)

    xland_kwargs = {
        "parallel_envs": args.parallel_actors,
        "rooms": args.rooms,
        "grid_size": args.grid_size,
        "k_shots": args.k_shots,
        "ruleset_benchmark": args.benchmark,
    }

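    # the jax env steps all of its parallel environments in a single call,
    # so AMAGO skips its usual multi-process vectorization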
    args.env_mode = "already_vectorized"
    make_train_env = lambda: XLandMiniGridAMAGO(
        XLandMinigridVectorizedGym(**xland_kwargs, train_test_split="train"),
    )
    make_val_env = lambda: XLandMiniGridAMAGO(
        XLandMinigridVectorizedGym(**xland_kwargs, train_test_split="test"),
    )
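    # build a throwaway env on the CPU just to read the benchmark's rollout length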
    with jax.default_device(jax.devices("cpu")[0]):
        traj_len = make_train_env().suggested_max_seq_len

    group_name = f"{args.run_name}_xlandmg_{args.benchmark}_R{args.rooms}_{args.grid_size}x{args.grid_size}"
    args.start_learning_at_epoch = traj_len // args.timesteps_per_epoch
    args.max_seq_len = min(args.max_seq_len, traj_len)

    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_val_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=traj_len,
            stagger_traj_file_lengths=False,
            run_name=run_name,
            tstep_encoder_type=XLandMGTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=traj_len,
            save_trajs_as="npz-compressed",
            grad_clip=2.0,
        )
        switch_async_mode(experiment, args.mode)
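        # place the jax env on the same GPU the PyTorch experiment was assigned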
        amago_device = experiment.DEVICE.index or torch.cuda.current_device()
        env_device = jax.devices("gpu")[amago_device]
        with jax.default_device(env_device):
            experiment.start()
            if args.ckpt is not None:
                experiment.load_checkpoint(args.ckpt)
            experiment.learn()
            experiment.evaluate_test(make_val_env, timesteps=20_000, render=False)
            experiment.delete_buffer_from_disk()
            wandb.finish()