Meta Frozen Lake
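
This example trains an AMAGO agent on MetaFrozenLake, a meta-RL variant of the classic FrozenLake gridworld in which the agent plays several episodes (--k_episodes) on the same randomly generated lake within a single rollout, so higher returns come from adapting across those episodes. The --seq_model flag swaps the agent's memory backbone between a feedforward baseline ("ff"), a Transformer, an RNN, and Mamba, making this a small testbed for how much sequence memory helps.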

from argparse import ArgumentParser

import amago
from amago.envs.builtin.toy_gym import MetaFrozenLake
from amago.envs import AMAGOEnv
from amago.loading import DiskTrajDataset
from amago import cli_utils


def add_cli(parser):
    parser.add_argument(
        "--seq_model",
        type=str,
        choices=["ff", "transformer", "rnn", "mamba"],
        required=True,
    )
    parser.add_argument("--run_name", type=str, required=True)
    parser.add_argument("--buffer_dir", type=str, required=True)
    parser.add_argument("--log", action="store_true")
    parser.add_argument("--trials", type=int, default=1)
    parser.add_argument("--lake_size", type=int, default=5)
    parser.add_argument("--k_episodes", type=int, default=15)
    parser.add_argument("--hard_mode", action="store_true")
    parser.add_argument("--recover_mode", action="store_true")
    parser.add_argument("--max_rollout_length", type=int, default=512)
    parser.add_argument("--max_seq_len", type=int, default=512)
    return parser


if __name__ == "__main__":
    parser = ArgumentParser()
    add_cli(parser)
    args = parser.parse_args()

    if args.log:
        import wandb

    config = {}
    # configure trajectory encoder (seq2seq memory model)
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.seq_model,
        memory_size=128,
        layers=3,
    )
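    # (Note) switch_traj_encoder fills `config` with the settings for the chosen
    # sequence backbone (feedforward, Transformer, RNN, or Mamba) and returns the
    # matching trajectory-encoder class to hand to the Experiment below. Roughly,
    # memory_size sets the model's width and layers its depth.
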
    # configure timestep encoder
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config, arch="ff", n_layers=1, d_hidden=128, d_output=64, normalize_inputs=False
    )
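    # (Note) the timestep encoder runs before the sequence model: it maps each
    # timestep's inputs to a single d_output-dim embedding. A one-layer feedforward
    # net ("ff") is plenty for observations as small as FrozenLake's.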

    # this is the default exploration strategy; we spell out its settings here for the example
    exploration_wrapper_type = cli_utils.switch_exploration(
        config,
        strategy="egreedy",
        eps_start=1.0,
        eps_end=0.05,
        steps_anneal=1_000_000,
        randomize_eps=True,
    )
    cli_utils.use_config(config)
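    # (Note) the epsilon-greedy wrapper anneals exploration noise from eps_start to
    # eps_end over steps_anneal environment steps, and randomize_eps (roughly) varies
    # that schedule across the parallel actors so they don't all explore identically.
    # use_config applies everything gathered in `config` as global overrides (AMAGO's
    # configuration is gin-based), so it runs once before anything below is built.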

    group_name = f"{args.run_name}_{args.seq_model}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"

        # create a dataset on disk. envs will write finished episodes here
        dset = DiskTrajDataset(
            dset_root=args.buffer_dir, dset_name=run_name, dset_max_size=12_500
        )
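        # (Note) rollout workers write each finished trajectory to a file under a
        # dset_name subdirectory of dset_root; dset_max_size caps how many trajectory
        # files are kept, with older files deleted to stay under the cap (exact
        # eviction details depend on the AMAGO version).
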
        # save checkpoints alongside the buffer
        ckpt_dir = args.buffer_dir
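
        # (Note) MetaFrozenLake is a meta-RL spin on the classic FrozenLake gridworld:
        # roughly, each rollout samples a lake layout and the agent plays k_episodes
        # episodes on it back-to-back, so doing better on later episodes requires
        # remembering the earlier ones. hard_mode and recover_mode toggle tougher
        # variants of the task (reflected in the "_hard"/"_easy" and
        # "_recover"/"_reset" tags of env_name below). AMAGOEnv adapts the underlying
        # gym-style env to AMAGO's interface, and env_name is mainly a label used to
        # group logged metrics.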

        # wrap environment
        make_env = lambda: AMAGOEnv(
            MetaFrozenLake(
                k_episodes=args.k_episodes,
                size=args.lake_size,
                hard_mode=args.hard_mode,
                recover_mode=args.recover_mode,
            ),
            env_name=f"meta_frozen_lake_k{args.k_episodes}_{args.lake_size}x{args.lake_size}"
            + ("_hard" if args.hard_mode else "_easy")
            + ("_recover" if args.recover_mode else "_reset"),
        )

        # create `Experiment`
        experiment = amago.Experiment(
            make_train_env=make_env,
            make_val_env=make_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=args.max_rollout_length,
            dataset=dset,
            ckpt_base_dir=ckpt_dir,
            agent_type=amago.agent.Agent,
            exploration_wrapper_type=exploration_wrapper_type,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            run_name=run_name,
            dloader_workers=10,
            log_to_wandb=args.log,
            wandb_group_name=group_name,
            epochs=700 if not args.hard_mode else 900,
            parallel_actors=24,
            train_timesteps_per_epoch=350,
            train_batches_per_epoch=800,
            val_interval=20,
            val_timesteps_per_epoch=args.max_rollout_length * 2,
            ckpt_interval=50,
            env_mode="sync",
        )
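
        # (Note) a few of the settings above, roughly: max_seq_len is the longest
        # context the policy is trained on, while traj_save_len caps the length of
        # each trajectory file written to the buffer (here both come from the CLI
        # flags). Each epoch interleaves environment collection
        # (train_timesteps_per_epoch) with train_batches_per_epoch gradient updates;
        # validation runs every val_interval epochs for two rollouts' worth of steps,
        # and checkpoints are saved every ckpt_interval epochs. hard_mode gets extra
        # epochs (900 vs. 700) since the task is harder.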

        # start experiment (build envs, policies, etc.)
        experiment.start()
        # run training
        experiment.learn()
        # final evaluation on fresh test envs, then clean up this trial's buffer
        experiment.evaluate_test(make_env, timesteps=10_000)
        experiment.delete_buffer_from_disk()
        if args.log:
            wandb.finish()
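
To try it, point --buffer_dir at a directory with enough free space for the on-disk replay buffer and pick a sequence model. Assuming the script is saved as metafrozenlake.py (the filename is just a placeholder), an invocation might look like `python metafrozenlake.py --seq_model transformer --run_name frozen_lake_demo --buffer_dir ~/amago_dsets --log`. The memory-free "ff" option is the natural baseline: without memory the agent cannot carry information across the k episodes of a rollout, so comparing it against the recurrent and attention-based backbones isolates the value of in-context adaptation.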