Dark Room Key Door
  1from argparse import ArgumentParser
  2
  3import wandb
  4
  5from amago.envs.builtin.toy_gym import RoomKeyDoor
  6from amago.envs import AMAGOEnv
  7from amago import cli_utils
  8
  9
 10def add_cli(parser):
 11    parser.add_argument(
 12        "--meta_horizon",
 13        type=int,
 14        default=500,
 15        help="Total meta-adaptation timestep budget for the agent to explore the same room layout.",
 16    )
 17    parser.add_argument(
 18        "--room_size",
 19        type=int,
 20        default=8,
 21        help="Size of the room. Exploration is sparse and difficulty scales quickly with room size.",
 22    )
 23    parser.add_argument(
 24        "--episode_length",
 25        type=int,
 26        default=50,
 27        help="Maximum length of a single episode in the environment.",
 28    )
 29    parser.add_argument(
 30        "--light_room_observation",
 31        action="store_true",
 32        help="Demonstrate how meta-RL relies on partial observability by revealing the goal location as part of the observation. This version of the environment can be solved without memory!",
 33    )
 34    parser.add_argument(
 35        "--randomize_actions",
 36        action="store_true",
 37        help="Randomize the agent's action space to make the task harder.",
 38    )
 39    return parser
 40
 41
 42if __name__ == "__main__":
 43    parser = ArgumentParser()
 44    cli_utils.add_common_cli(parser)
 45    add_cli(parser)
 46    args = parser.parse_args()
 47
 48    config = {}
 49    tstep_encoder_type = cli_utils.switch_tstep_encoder(
 50        config, arch="ff", n_layers=2, d_hidden=128, d_output=64
 51    )
 52    traj_encoder_type = cli_utils.switch_traj_encoder(
 53        config,
 54        arch=args.traj_encoder,
 55        memory_size=args.memory_size,
 56        layers=args.memory_layers,
 57    )
 58    agent_type = cli_utils.switch_agent(
 59        config, args.agent_type, reward_multiplier=100.0
 60    )
 61    # the fancier exploration schedule mentioned in the appendix can help
 62    # when the domain is a true meta-RL problem and the "horizon" time limit
 63    # (above) is actually relevant for resetting the task.
 64    exploration_type = cli_utils.switch_exploration(
 65        config, "bilevel", steps_anneal=500_000, rollout_horizon=args.meta_horizon
 66    )
 67    cli_utils.use_config(config, args.configs)
 68
 69    group_name = f"{args.run_name}_dark_key_door"
 70    for trial in range(args.trials):
 71        run_name = group_name + f"_trial_{trial}"
 72        make_train_env = lambda: AMAGOEnv(
 73            env=RoomKeyDoor(
 74                size=args.room_size,
 75                max_episode_steps=args.episode_length,
 76                meta_rollout_horizon=args.meta_horizon,
 77                dark=not args.light_room_observation,
 78                randomize_actions=args.randomize_actions,
 79            ),
 80            env_name=f"Dark-Key-To-Door-{args.room_size}x{args.room_size}",
 81        )
 82        experiment = cli_utils.create_experiment_from_cli(
 83            args,
 84            agent_type=agent_type,
 85            tstep_encoder_type=tstep_encoder_type,
 86            traj_encoder_type=traj_encoder_type,
 87            make_train_env=make_train_env,
 88            make_val_env=make_train_env,
 89            max_seq_len=args.meta_horizon,
 90            traj_save_len=args.meta_horizon,
 91            group_name=group_name,
 92            run_name=run_name,
 93            val_timesteps_per_epoch=args.meta_horizon * 4,
 94            exploration_wrapper_type=exploration_type,
 95        )
 96        experiment = cli_utils.switch_async_mode(experiment, args.mode)
 97        experiment.start()
 98        if args.ckpt is not None:
 99            experiment.load_checkpoint(args.ckpt)
100        experiment.learn()
101        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
102        experiment.delete_buffer_from_disk()
103        wandb.finish()