Dark Room Key Door

Dark Room Key Door

  1from argparse import ArgumentParser
  2
  3import wandb
  4
  5from amago.envs.builtin.toy_gym import RoomKeyDoor
  6from amago.envs import AMAGOEnv
  7from amago import cli_utils
  8
  9
 10def add_cli(parser):
 11    parser.add_argument(
 12        "--k_episodes",
 13        type=int,
 14        default=8,
 15        help="Number of episodes per meta-rollout. Effective sequence length = k_episodes * episode_length.",
 16    )
 17    parser.add_argument(
 18        "--room_size",
 19        type=int,
 20        default=8,
 21        help="Size of the room. Exploration is sparse and difficulty scales quickly with room size.",
 22    )
 23    parser.add_argument(
 24        "--episode_length",
 25        type=int,
 26        default=50,
 27        help="Maximum length of a single episode in the environment.",
 28    )
 29    parser.add_argument(
 30        "--light_room_observation",
 31        action="store_true",
 32        help="Demonstrate how meta-RL relies on partial observability by revealing the goal location as part of the observation. This version of the environment can be solved without memory!",
 33    )
 34    parser.add_argument(
 35        "--randomize_actions",
 36        action="store_true",
 37        help="Randomize the agent's action space to make the task harder.",
 38    )
 39    parser.add_argument(
 40        "--finite_horizon",
 41        action="store_true",
 42        help="Use finite-horizon mode: include time in observations and signal meta-done as terminated. Default is infinite-horizon (no time in obs, meta-done as truncated).",
 43    )
 44    return parser
 45
 46
 47if __name__ == "__main__":
 48    parser = ArgumentParser()
 49    cli_utils.add_common_cli(parser)
 50    add_cli(parser)
 51    args = parser.parse_args()
 52
 53    config = {}
 54    tstep_encoder_type = cli_utils.switch_tstep_encoder(
 55        config,
 56        arch="ff",
 57        n_layers=2,
 58        d_hidden=128,
 59        d_output=64,
 60        specify_obs_keys=["observed", "prev_action", "prev_reward"],
 61        hide_rl2s=True,
 62        normalize_inputs=False,
 63    )
 64    traj_encoder_type = cli_utils.switch_traj_encoder(
 65        config,
 66        arch=args.traj_encoder,
 67        memory_size=args.memory_size,
 68        layers=args.memory_layers,
 69        pos_emb="rope",
 70    )
 71    agent_type = cli_utils.switch_agent(
 72        config, args.agent_type, reward_multiplier=100.0
 73    )
 74    horizon_type = "finite" if args.finite_horizon else "infinite"
 75    dummy_env = RoomKeyDoor(
 76        size=args.room_size,
 77        max_episode_steps=args.episode_length,
 78        k_episodes=args.k_episodes,
 79        horizon_type=horizon_type,
 80    )
 81    meta_horizon = dummy_env.meta_horizon
 82    args.timesteps_per_epoch = meta_horizon
 83    # the fancier exploration schedule mentioned in the appendix can help
 84    # when the domain is a true meta-RL problem and the "horizon" time limit
 85    # (above) is actually relevant for resetting the task.
 86    exploration_type = cli_utils.switch_exploration(
 87        config, "bilevel", steps_anneal=500_000, rollout_horizon=meta_horizon
 88    )
 89    cli_utils.use_config(config, args.configs)
 90
 91    group_name = f"{args.run_name}_dark_key_door"
 92    for trial in range(args.trials):
 93        run_name = group_name + f"_trial_{trial}"
 94        make_train_env = lambda: AMAGOEnv(
 95            env=RoomKeyDoor(
 96                size=args.room_size,
 97                max_episode_steps=args.episode_length,
 98                k_episodes=args.k_episodes,
 99                dark=not args.light_room_observation,
100                randomize_actions=args.randomize_actions,
101                horizon_type=horizon_type,
102            ),
103            env_name=f"Dark-Key-To-Door-{args.room_size}x{args.room_size}-{horizon_type}",
104        )
105        experiment = cli_utils.create_experiment_from_cli(
106            args,
107            agent_type=agent_type,
108            tstep_encoder_type=tstep_encoder_type,
109            traj_encoder_type=traj_encoder_type,
110            make_train_env=make_train_env,
111            make_val_env=make_train_env,
112            max_seq_len=meta_horizon,
113            traj_save_len=meta_horizon * 10,
114            group_name=group_name,
115            run_name=run_name,
116            val_timesteps_per_epoch=meta_horizon * 4,
117            exploration_wrapper_type=exploration_type,
118            stagger_traj_file_lengths=False,
119            wandb_project="z-room-key-door",
120        )
121        experiment = cli_utils.switch_async_mode(experiment, args.mode)
122        experiment.start()
123        if args.ckpt is not None:
124            experiment.load_checkpoint(args.ckpt)
125        experiment.learn()
126        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
127        experiment.delete_buffer_from_disk()
128        wandb.finish()