# T-Maze

  1from argparse import ArgumentParser
  2
  3import wandb
  4
  5from amago.envs import AMAGOEnv
  6from amago.envs.builtin.tmaze import TMazeAltPassive, TMazeAltActive
  7from amago.envs.exploration import EpsilonGreedy
  8from amago.cli_utils import *
  9
 10
 11def add_cli(parser):
 12    parser.add_argument("--horizon", type=int, required=True)
 13    return parser
 14
 15
 16@gin.configurable
 17class TMazeExploration(EpsilonGreedy):
 18    """
 19    The Tmaze environment is meant to evaluate recall over long context lengths without
 20    testing exploration, but it does this by requiring horizon - 1 deterministic actions
 21    to create a gap between the timestep that reveals the correct action and the timestep
 22    it is taken. This unintentionally creates a worst-case scenario for epsilon greedy
 23    exploration. We use this epsilon greedy exploration schedule to answer the
 24    central memory question while fixing the sample efficiency problems it creates.
 25
 26    https://github.com/twni2016/Memory-RL/issues/1
 27    """
 28
 29    def __init__(
 30        self,
 31        env,
 32        start_window=0,
 33        end_window=3,
 34        horizon: int = gin.REQUIRED,
 35        eps_start=1.0,
 36        eps_end=0.01,
 37        steps_anneal=100_000,
 38    ):
 39        self.start_window = start_window
 40        self.end_window = end_window
 41        self.horizon = horizon
 42        super().__init__(
 43            env,
 44            eps_start=eps_start,
 45            eps_end=eps_end,
 46            steps_anneal=steps_anneal,
 47        )
 48
 49    def current_eps(self, local_step: int):
 50        current = super().current_eps(local_step)
 51        if (
 52            local_step > self.start_window
 53            and local_step < self.horizon - self.end_window
 54        ):
 55            # low exploration during the easy corridor section; regular during the
 56            # interesting early and late timesteps.
 57            current[:] = 0.5 / self.horizon
 58        return current
 59
 60
 61if __name__ == "__main__":
 62    parser = ArgumentParser()
 63    add_common_cli(parser)
 64    add_cli(parser)
 65    args = parser.parse_args()
 66
 67    config = {
 68        "TMazeExploration.horizon": args.horizon,
 69    }
 70    traj_encoder_type = switch_traj_encoder(
 71        config,
 72        arch=args.traj_encoder,
 73        memory_size=args.memory_size,
 74        layers=args.memory_layers,
 75    )
 76    tstep_encoder_type = switch_tstep_encoder(
 77        config,
 78        arch="ff",
 79        n_layers=2,
 80        d_hidden=128,
 81        d_output=128,
 82        normalize_inputs=False,
 83    )
 84    agent_type = switch_agent(
 85        config, args.agent_type, reward_multiplier=100.0, gamma=0.9999
 86    )
 87    use_config(config, args.configs)
 88
 89    group_name = f"{args.run_name}_TMazePassive_H{args.horizon}"
 90    for trial in range(args.trials):
 91        run_name = group_name + f"_trial_{trial}"
 92        make_env = lambda: AMAGOEnv(
 93            env=TMazeAltPassive(
 94                corridor_length=args.horizon, penalty=-1.0 / args.horizon
 95            ),
 96            env_name=f"TMazePassive-H{args.horizon}",
 97        )
 98        experiment = create_experiment_from_cli(
 99            args,
100            make_train_env=make_env,
101            make_val_env=make_env,
102            max_seq_len=args.horizon + 1,
103            traj_save_len=args.horizon + 1,
104            group_name=group_name,
105            run_name=run_name,
106            tstep_encoder_type=tstep_encoder_type,
107            traj_encoder_type=traj_encoder_type,
108            agent_type=agent_type,
109            val_timesteps_per_epoch=args.horizon + 1,
110            sample_actions=False,  # even softmax prob .999 isn't good enough for this env...
111            exploration_wrapper_type=TMazeExploration,
112        )
113        switch_async_mode(experiment, args.mode)
114        experiment.start()
115        if args.ckpt is not None:
116            experiment.load_checkpoint(args.ckpt)
117        experiment.learn()
118        experiment.evaluate_test(make_env, timesteps=args.horizon * 5, render=False)
119        experiment.delete_buffer_from_disk()
120        wandb.finish()