T-Maze

from argparse import ArgumentParser

import wandb
import gin

from amago.envs import AMAGOEnv
from amago.envs.builtin.tmaze import TMazeAltPassive, TMazeAltActive
from amago.envs.exploration import EpsilonGreedy
from amago import cli_utils


def add_cli(parser):
    parser.add_argument("--horizon", type=int, required=True)
    return parser


@gin.configurable
class TMazeExploration(EpsilonGreedy):
    """
    The T-Maze environment is meant to evaluate recall over long context lengths without
    testing exploration, but it does this by requiring horizon - 1 deterministic actions
    to create a gap between the timestep that reveals the correct action and the timestep
    where it must be taken. This unintentionally creates a worst-case scenario for
    epsilon-greedy exploration. We use this modified epsilon-greedy schedule to answer
    the central memory question while fixing the sample-efficiency problem it creates.

    https://github.com/twni2016/Memory-RL/issues/1
    """

    def __init__(
        self,
        env,
        start_window=0,
        end_window=3,
        horizon: int = gin.REQUIRED,
        eps_start=1.0,
        eps_end=0.01,
        steps_anneal=100_000,
    ):
        self.start_window = start_window
        self.end_window = end_window
        self.horizon = horizon
        super().__init__(
            env,
            eps_start=eps_start,
            eps_end=eps_end,
            steps_anneal=steps_anneal,
        )

    def current_eps(self, local_step: int):
        current = super().current_eps(local_step)
        if (
            local_step > self.start_window
            and local_step < self.horizon - self.end_window
        ):
            # low exploration during the easy corridor section; regular during the
            # interesting early and late timesteps.
            current[:] = 0.5 / self.horizon
        return current

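# For example: with --horizon 400, the base EpsilonGreedy schedule only applies on the
# first timestep and the final `end_window` timesteps of each episode. Inside the
# corridor, epsilon is pinned to 0.5 / 400 = 0.00125, so a rollout takes roughly half
# a random action in expectation there, no matter how long the corridor gets.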

if __name__ == "__main__":
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

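    # fill in the gin.REQUIRED horizon of TMazeExploration from the CLI; the corridor
    # length, context length, and exploration schedule all scale with --horizon.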
    config = {
        "TMazeExploration.horizon": args.horizon,
    }
    # the trajectory (sequence) encoder architecture and memory size come from the CLI
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    # timestep encoder: a small feedforward network
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config,
        arch="ff",
        n_layers=2,
        d_hidden=128,
        d_output=128,
        normalize_inputs=False,
    )
    # rescale rewards and use a discount close to 1 for the long horizon
    agent_type = cli_utils.switch_agent(
        config, args.agent_type, reward_multiplier=100.0, gamma=0.9999
    )
    # lock in the config (including any extra --configs overrides from the CLI)
    cli_utils.use_config(config, args.configs)

    group_name = f"{args.run_name}_TMazePassive_H{args.horizon}"
    for trial in range(args.trials):
        # each trial gets its own run name within the shared group
        run_name = group_name + f"_trial_{trial}"
        make_env = lambda: AMAGOEnv(
            env=TMazeAltPassive(
                corridor_length=args.horizon, penalty=-1.0 / args.horizon
            ),
            env_name=f"TMazePassive-H{args.horizon}",
        )
        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_env,
            make_val_env=make_env,
            # the policy's context (and saved trajectories) cover the full episode
            max_seq_len=args.horizon + 1,
            traj_save_len=args.horizon + 1,
            group_name=group_name,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            val_timesteps_per_epoch=args.horizon + 1,
            sample_actions=False,  # even softmax prob .999 isn't good enough for this env...
            exploration_wrapper_type=TMazeExploration,
        )
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_env, timesteps=args.horizon * 5, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()
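
Because TMazeExploration is registered with @gin.configurable, its other constructor arguments (start_window, end_window, eps_start, eps_end, steps_anneal) can be overridden through the same config dict that already sets horizon. A minimal sketch, assuming you edit the config block above before cli_utils.use_config is called; the extra keys and values below are illustrative and not part of the original example:

config = {
    "TMazeExploration.horizon": args.horizon,
    # hypothetical extra overrides (values are arbitrary examples):
    "TMazeExploration.steps_anneal": 50_000,  # anneal the base epsilon faster
    "TMazeExploration.end_window": 5,  # keep normal exploration on more of the final timesteps
}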