Meta Frozen Lake
  1from argparse import ArgumentParser
  2
  3import amago
  4from amago.envs.builtin.toy_gym import MetaFrozenLake
  5from amago.envs import AMAGOEnv
  6from amago.loading import DiskTrajDataset
  7from amago import cli_utils
  8
  9
 10def add_cli(parser):
 11    parser.add_argument(
 12        "--seq_model",
 13        type=str,
 14        choices=["ff", "transformer", "rnn", "mamba"],
 15        required=True,
 16    )
 17    parser.add_argument("--run_name", type=str, required=True)
 18    parser.add_argument("--buffer_dir", type=str, required=True)
 19    parser.add_argument("--log", action="store_true")
 20    parser.add_argument("--trials", type=int, default=1)
 21    parser.add_argument("--lake_size", type=int, default=5)
 22    parser.add_argument("--k_episodes", type=int, default=10)
 23    parser.add_argument("--hard_mode", action="store_true")
 24    parser.add_argument("--recover_mode", action="store_true")
 25    parser.add_argument("--slip_chance", type=float, default=0.0)
 26    parser.add_argument(
 27        "--max_episode_steps",
 28        type=int,
 29        default=None,
 30        help="Max steps per attempt. Default: N² (standard) or 2*N² (hard).",
 31    )
 32    parser.add_argument(
 33        "--hide_k_progress",
 34        action="store_true",
 35        help="Hide current_k/k_episodes from observations (for length extrapolation tests).",
 36    )
 37    parser.add_argument(
 38        "--max_seq_len",
 39        type=int,
 40        default=None,
 41        help="Training sequence length. Default: max_episode_steps * k_episodes (full trajectory).",
 42    )
 43    return parser
 44
 45
 46if __name__ == "__main__":
 47    parser = ArgumentParser()
 48    add_cli(parser)
 49    args = parser.parse_args()
 50
 51    if args.log:
 52        import wandb
 53
 54    lake_kwargs = dict(
 55        size=args.lake_size,
 56        k_episodes=args.k_episodes,
 57        hard_mode=args.hard_mode,
 58        recover_mode=args.recover_mode,
 59        max_episode_steps=args.max_episode_steps,
 60        show_k_progress=not args.hide_k_progress,
 61        slip_chance=args.slip_chance,
 62    )
 63    max_ep_steps = MetaFrozenLake(**lake_kwargs).max_episode_steps
 64    max_rollout_length = max_ep_steps * args.k_episodes
 65    max_seq_len = args.max_seq_len or max_rollout_length
 66
 67    config = {}
 68    # configure trajectory encoder (seq2seq memory model)
 69    traj_encoder_type = cli_utils.switch_traj_encoder(
 70        config,
 71        arch=args.seq_model,
 72        memory_size=128,
 73        layers=3,
 74    )
 75    # configure timestep encoder
 76    tstep_encoder_type = cli_utils.switch_tstep_encoder(
 77        config, arch="ff", n_layers=1, d_hidden=128, d_output=64, normalize_inputs=False
 78    )
 79    # we're using the default exploration strategy but being overly verbose about it for the example
 80    exploration_wrapper_type = cli_utils.switch_exploration(
 81        config,
 82        strategy="egreedy",
 83        eps_start=1.0,
 84        eps_end=0.05,
 85        steps_anneal=1_000_000,
 86        randomize_eps=True,
 87    )
 88    agent_type = cli_utils.switch_agent(config, "agent", tau=0.004)
 89    cli_utils.use_config(config)
 90
 91    group_name = f"{args.run_name}_{args.seq_model}"
 92    for trial in range(args.trials):
 93        run_name = group_name + f"_trial_{trial}"
 94
 95        # create a dataset on disk. envs will write finished episodes here
 96        dset = DiskTrajDataset(
 97            dset_root=args.buffer_dir, dset_name=run_name, dset_max_size=12_500
 98        )
 99        # save checkpoints alongside the buffer
100        ckpt_dir = args.buffer_dir
101        # wrap environment
102        make_env = lambda: AMAGOEnv(
103            MetaFrozenLake(**lake_kwargs),
104            env_name=f"meta_frozen_lake_k{args.k_episodes}_{args.lake_size}x{args.lake_size}"
105            + ("_hard" if args.hard_mode else "_easy")
106            + ("_recover" if args.recover_mode else "_reset"),
107        )
108
109        experiment = amago.Experiment(
110            make_train_env=make_env,
111            make_val_env=make_env,
112            max_seq_len=max_seq_len,
113            traj_save_len=max_rollout_length,
114            dataset=dset,
115            ckpt_base_dir=args.buffer_dir,
116            agent_type=agent_type,
117            exploration_wrapper_type=exploration_wrapper_type,
118            tstep_encoder_type=tstep_encoder_type,
119            traj_encoder_type=traj_encoder_type,
120            run_name=run_name,
121            dloader_workers=10,
122            log_to_wandb=args.log,
123            wandb_group_name=group_name,
124            epochs=700 if not args.hard_mode else 900,
125            parallel_actors=32,
126            train_timesteps_per_epoch=max_rollout_length,
127            train_batches_per_epoch=1000,
128            val_interval=20,
129            val_timesteps_per_epoch=max_rollout_length * 2,
130            ckpt_interval=200,
131            env_mode="sync",
132        )
133
134        experiment.start()
135        experiment.learn()
136        experiment.evaluate_test(make_env, timesteps=10_000)
137        experiment.delete_buffer_from_disk()
138        if args.log:
139            wandb.finish()