Meta Frozen Lake

Meta Frozen Lake

  1import amago
  2from amago.envs.builtin.toy_gym import MetaFrozenLake
  3from amago.envs import AMAGOEnv
  4from amago.loading import DiskTrajDataset
  5from amago.cli_utils import *
  6
  7
  8def add_cli(parser):
  9    parser.add_argument(
 10        "--seq_model",
 11        type=str,
 12        choices=["ff", "transformer", "rnn", "mamba"],
 13        required=True,
 14    )
 15    parser.add_argument("--run_name", type=str, required=True)
 16    parser.add_argument("--buffer_dir", type=str, required=True)
 17    parser.add_argument("--log", action="store_true")
 18    parser.add_argument("--trials", type=int, default=1)
 19    parser.add_argument("--lake_size", type=int, default=5)
 20    parser.add_argument("--k_episodes", type=int, default=15)
 21    parser.add_argument("--hard_mode", action="store_true")
 22    parser.add_argument("--recover_mode", action="store_true")
 23    parser.add_argument("--max_rollout_length", type=int, default=512)
 24    parser.add_argument("--max_seq_len", type=int, default=512)
 25    return parser
 26
 27
 28if __name__ == "__main__":
 29    parser = ArgumentParser()
 30    add_cli(parser)
 31    args = parser.parse_args()
 32
 33    if args.log:
 34        import wandb
 35
 36    config = {}
 37    # configure trajectory encoder (seq2seq memory model)
 38    traj_encoder_type = switch_traj_encoder(
 39        config,
 40        arch=args.seq_model,
 41        memory_size=128,
 42        layers=3,
 43    )
 44    # configure timestep encoder
 45    tstep_encoder_type = switch_tstep_encoder(
 46        config, arch="ff", n_layers=1, d_hidden=128, d_output=64, normalize_inputs=False
 47    )
 48
 49    # we're using the default exploration strategy but being overly verbose about it for the example
 50    exploration_wrapper_type = switch_exploration(
 51        config,
 52        strategy="egreedy",
 53        eps_start=1.0,
 54        eps_end=0.05,
 55        steps_anneal=1_000_000,
 56        randomize_eps=True,
 57    )
 58    use_config(config)
 59
 60    group_name = f"{args.run_name}_{args.seq_model}"
 61    for trial in range(args.trials):
 62        run_name = group_name + f"_trial_{trial}"
 63
 64        # create a dataset on disk. envs will write finished episodes here
 65        dset = DiskTrajDataset(
 66            dset_root=args.buffer_dir, dset_name=run_name, dset_max_size=12_500
 67        )
 68        # save checkpoints alongside the buffer
 69        ckpt_dir = args.buffer_dir
 70
 71        # wrap environment
 72        make_env = lambda: AMAGOEnv(
 73            MetaFrozenLake(
 74                k_episodes=args.k_episodes,
 75                size=args.lake_size,
 76                hard_mode=args.hard_mode,
 77                recover_mode=args.recover_mode,
 78            ),
 79            env_name=f"meta_frozen_lake_k{args.k_episodes}_{args.lake_size}x{args.lake_size}"
 80            + ("_hard" if args.hard_mode else "_easy")
 81            + ("_recover" if args.recover_mode else "_reset"),
 82        )
 83
 84        # create `Experiment`
 85        experiment = amago.Experiment(
 86            make_train_env=make_env,
 87            make_val_env=make_env,
 88            max_seq_len=args.max_seq_len,
 89            traj_save_len=args.max_rollout_length,
 90            dataset=dset,
 91            ckpt_base_dir=ckpt_dir,
 92            agent_type=amago.agent.Agent,
 93            exploration_wrapper_type=exploration_wrapper_type,
 94            tstep_encoder_type=tstep_encoder_type,
 95            traj_encoder_type=traj_encoder_type,
 96            run_name=run_name,
 97            dloader_workers=10,
 98            log_to_wandb=args.log,
 99            wandb_group_name=group_name,
100            epochs=700 if not args.hard_mode else 900,
101            parallel_actors=24,
102            train_timesteps_per_epoch=350,
103            train_batches_per_epoch=800,
104            val_interval=20,
105            val_timesteps_per_epoch=args.max_rollout_length * 2,
106            ckpt_interval=50,
107            env_mode="sync",
108        )
109
110        # start experiment (build envs, policies, etc.)
111        experiment.start()
112        # run training
113        experiment.learn()
114        experiment.evaluate_test(make_env, timesteps=10_000)
115        experiment.delete_buffer_from_disk()
116        wandb.finish()