HalfCheetah(v4)-Velocity

HalfCheetah(v4)-Velocity

  1from argparse import ArgumentParser
  2
  3import wandb
  4
  5import amago
  6from amago.envs import AMAGOEnv
  7from amago.envs.builtin.half_cheetah_v4_vel import HalfCheetahV4_MetaVelocity
  8from amago import cli_utils
  9
 10
 11def add_cli(parser):
 12    parser.add_argument(
 13        "--policy_seq_len", type=int, default=32, help="Policy sequence length."
 14    )
 15    parser.add_argument(
 16        "--eval_episodes_per_actor",
 17        type=int,
 18        default=1,
 19        help="Validation episodes per parallel actor.",
 20    )
 21    parser.add_argument(
 22        "--task_min_velocity",
 23        type=float,
 24        default=0.0,
 25        help="Min running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 0.",
 26    )
 27    parser.add_argument(
 28        "--task_max_velocity",
 29        type=float,
 30        default=3.0,
 31        help="Max running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 3. Agents in the default locomotion env (no reward randomization) reach > 10.",
 32    )
 33    parser.add_argument(
 34        "--inner_episode_steps",
 35        type=int,
 36        default=200,
 37        help="Step horizon of each inner episode. Default 200 (combined with the default --k_train_episodes=3) keeps total trial length at 600 steps. Set 1000 with --k_train_episodes=1 to recover the unwrapped task.",
 38    )
 39    parser.add_argument(
 40        "--k_train_episodes",
 41        type=int,
 42        default=3,
 43        help="Inner episodes per meta-trial during training. Default 3 makes the env a true meta-RL trial: a single hidden target velocity persists across 3 inner episodes (soft resets between). Set 1 to recover the unwrapped task.",
 44    )
 45    parser.add_argument(
 46        "--k_eval_episodes",
 47        type=int,
 48        default=None,
 49        help="Inner episodes per meta-trial at eval time. Defaults to --k_train_episodes. Larger values (e.g. 10) probe how well the agent keeps adapting beyond its training horizon.",
 50    )
 51    return parser
 52
 53
"""
Because this task is so similar to the other gymnasium examples, this example script is
deliberately verbose: it shows how you could customize the environment and create a
train/test split.

If you don't edit anything, this is simply a longer way to train/test on the default task
distribution, which samples a target velocity uniformly from the interval
[args.task_min_velocity, args.task_max_velocity].
"""
 61
 62
 63class MyCustomHalfCheetahTrain(HalfCheetahV4_MetaVelocity):
 64    def sample_target_velocity(self) -> float:
 65        # be sure to use `random` or be careful about np default_rng to ensure
 66        # tasks are different across async parallel actors!
 67        vel = super().sample_target_velocity()  # random.uniform(min_vel, max_vel)
 68        return vel
 69
 70
 71class MyCustomHalfCheetahEval(HalfCheetahV4_MetaVelocity):
 72    def sample_target_velocity(self) -> float:
 73        vel = super().sample_target_velocity()
 74        # or, to create OOD eval tasks:
 75        # vel = random.uniform(self.task_min_velocity, self.task_max_velocity * 10.0)
 76        # or random.choice([0., 1., self.task_max_velocity * 1.2]), etc.
 77        return vel
 78
 79
 80if __name__ == "__main__":
 81    parser = ArgumentParser()
 82    cli_utils.add_common_cli(parser)
 83    add_cli(parser)
 84    args = parser.parse_args()
 85
 86    k_eval = (
 87        args.k_eval_episodes
 88        if args.k_eval_episodes is not None
 89        else args.k_train_episodes
 90    )
 91
 92    def make_train_env():
 93        return AMAGOEnv(
 94            MyCustomHalfCheetahTrain(
 95                task_min_velocity=args.task_min_velocity,
 96                task_max_velocity=args.task_max_velocity,
 97                max_episode_steps=args.inner_episode_steps,
 98                k_episodes=args.k_train_episodes,
 99            ),
100            env_name="HalfCheetahV4Velocity",
101        )
102
103    def make_val_env():
104        return AMAGOEnv(
105            MyCustomHalfCheetahEval(
106                task_min_velocity=args.task_min_velocity,
107                task_max_velocity=args.task_max_velocity,
108                max_episode_steps=args.inner_episode_steps,
109                k_episodes=k_eval,
110            ),
111            env_name="HalfCheetahV4Velocity",
112        )
113
114    config = {
115        "amago.nets.traj_encoders.TformerTrajEncoder.pos_emb": "rope",
116    }
117    # switch sequence model
118    traj_encoder_type = cli_utils.switch_traj_encoder(
119        config,
120        arch=args.traj_encoder,
121        memory_size=args.memory_size,
122        layers=args.memory_layers,
123    )
124    # switch agent
125    agent_type = cli_utils.switch_agent(
126        config,
127        args.agent_type,
128        reward_multiplier=1.0,  # gym locomotion returns are already large
129        gamma=0.99,  # locomotion policies don't need long horizons - fall back to the default
130        tau=0.005,
131    )
132    # "egreedy" exploration in continuous control is just the epsilon-scheduled random (normal)
133    # noise from most TD3/DPPG implementations.
134    exploration_type = cli_utils.switch_exploration(
135        config, "egreedy", steps_anneal=500_000
136    )
137    cli_utils.use_config(config, args.configs)
138
139    group_name = args.run_name
140    for trial in range(args.trials):
141        run_name = group_name + f"_trial_{trial}"
142        experiment = cli_utils.create_experiment_from_cli(
143            args,
144            make_train_env=make_train_env,  # different train/val envs
145            make_val_env=make_val_env,
146            max_seq_len=args.policy_seq_len,
147            traj_save_len=args.policy_seq_len * 6,
148            run_name=run_name,
149            tstep_encoder_type=amago.nets.tstep_encoders.FFTstepEncoder,
150            traj_encoder_type=traj_encoder_type,
151            exploration_wrapper_type=exploration_type,
152            agent_type=agent_type,
153            group_name=group_name,
154            val_timesteps_per_epoch=args.eval_episodes_per_actor
155            * (args.inner_episode_steps * k_eval + 1),
156            grad_clip=2.0,
157            learning_rate=3e-4,
158        )
159        experiment.start()
160        if args.ckpt is not None:
161            experiment.load_checkpoint(args.ckpt)
162        experiment.learn()
163        experiment.evaluate_test(make_val_env, timesteps=10_000, render=False)
164        experiment.delete_buffer_from_disk()
165        wandb.finish()