HalfCheetah(v4)-Velocity

  1from argparse import ArgumentParser
  2import math
  3import random
  4
  5import wandb
  6
  7import amago
  8from amago.envs import AMAGOEnv
  9from amago.envs.builtin.half_cheetah_v4_vel import HalfCheetahV4_MetaVelocity
 10from amago import cli_utils
 11
 12
 13def add_cli(parser):
 14    parser.add_argument(
 15        "--policy_seq_len", type=int, default=32, help="Policy sequence length."
 16    )
 17    parser.add_argument(
 18        "--eval_episodes_per_actor",
 19        type=int,
 20        default=1,
 21        help="Validation episodes per parallel actor.",
 22    )
 23    parser.add_argument(
 24        "--task_min_velocity",
 25        type=float,
 26        default=0.0,
 27        help="Min running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 0.",
 28    )
 29    parser.add_argument(
 30        "--task_max_velocity",
 31        type=float,
 32        default=3.0,
 33        help="Max running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 3. Agents in the default locomotion env (no reward randomization) reach > 10.",
 34    )
 35    return parser
 36
 37
"""
Because this task is so similar to the other gymnasium examples, this script is
deliberately verbose about showing how you could customize the environment and
create a train/test split.

If you don't edit anything, this is simply a longer way to train/test on the default
task distribution (which samples a velocity uniformly from
[args.task_min_velocity, args.task_max_velocity]).
"""
 45
 46
 47class MyCustomHalfCheetahTrain(HalfCheetahV4_MetaVelocity):
 48    def sample_target_velocity(self) -> float:
 49        # be sure to use `random` or be careful about np default_rng to ensure
 50        # tasks are different across async parallel actors!
 51        vel = super().sample_target_velocity()  # random.uniform(min_vel, max_vel)
 52        return vel
 53
 54
 55class MyCustomHalfCheetahEval(HalfCheetahV4_MetaVelocity):
 56    def sample_target_velocity(self) -> float:
 57        vel = super().sample_target_velocity()
 58        # or, to create OOD eval tasks:
 59        # vel = random.uniform(self.task_min_velocity, self.task_max_velocity * 10.0)
 60        # or random.choice([0., 1., self.task_max_velocity * 1.2]), etc.
 61        return vel
 62
 63
 64class AMAGOEnvWithVelocityName(AMAGOEnv):
 65    """
 66    Every eval metric gets logged based on the current
 67    `env_name`. You could use this to log metrics for
 68    different tasks separately. They get averaged over
 69    all the evals with the same name, so you want a discrete
 70    number of names that will get sample sizes > 1.
 71    """
 72
 73    @property
 74    def env_name(self) -> str:
 75        current_task_vel = self.env.unwrapped.target_velocity
 76        # need to discretize this somehow; just one example
 77        low, high = math.floor(current_task_vel), math.ceil(current_task_vel)
 78        return f"HalfCheetahVelocity-Vel-[{low}, {high}]"
 79
 80
 81if __name__ == "__main__":
 82    parser = ArgumentParser()
 83    cli_utils.add_common_cli(parser)
 84    add_cli(parser)
 85    args = parser.parse_args()
 86
 87    # setup environment
 88    make_train_env = lambda: AMAGOEnvWithVelocityName(
 89        MyCustomHalfCheetahTrain(
 90            task_min_velocity=args.task_min_velocity,
 91            task_max_velocity=args.task_max_velocity,
 92        ),
 93        # the env_name is totally arbitrary and only impacts logging / data filenames
 94        env_name=f"HalfCheetahV4Velocity",
 95    )
 96
 97    make_val_env = lambda: AMAGOEnvWithVelocityName(
 98        MyCustomHalfCheetahEval(
 99            task_min_velocity=args.task_min_velocity,
100            task_max_velocity=args.task_max_velocity,
101        ),
102        # this would get replaced by the env_name property
103        # defined above.
104        env_name=f"HalfCheetahV4VelocityEval",
105    )
106
107    config = {}
108    # switch sequence model
109    traj_encoder_type = cli_utils.switch_traj_encoder(
110        config,
111        arch=args.traj_encoder,
112        memory_size=args.memory_size,
113        layers=args.memory_layers,
114    )
115    # switch agent
116    agent_type = cli_utils.switch_agent(
117        config,
118        args.agent_type,
119        reward_multiplier=1.0,  # gym locomotion returns are already large
120        gamma=0.99,  # locomotion policies don't need long horizons - fall back to the default
121        tau=0.005,
122    )
123    # "egreedy" exploration in continuous control is just the epsilon-scheduled random (normal)
124    # noise from most TD3/DPPG implementations.
125    exploration_type = cli_utils.switch_exploration(
126        config, "egreedy", steps_anneal=500_000
127    )
128    cli_utils.use_config(config, args.configs)
129
130    group_name = args.run_name
131    for trial in range(args.trials):
132        run_name = group_name + f"_trial_{trial}"
133        experiment = cli_utils.create_experiment_from_cli(
134            args,
135            make_train_env=make_train_env,  # different train/val envs
136            make_val_env=make_val_env,
137            max_seq_len=args.policy_seq_len,
138            traj_save_len=args.policy_seq_len * 6,
139            run_name=run_name,
140            tstep_encoder_type=amago.nets.tstep_encoders.FFTstepEncoder,
141            traj_encoder_type=traj_encoder_type,
142            exploration_wrapper_type=exploration_type,
143            agent_type=agent_type,
144            group_name=group_name,
145            val_timesteps_per_epoch=args.eval_episodes_per_actor * 1001,
146            grad_clip=2.0,
147            learning_rate=3e-4,
148        )
149        experiment.start()
150        if args.ckpt is not None:
151            experiment.load_checkpoint(args.ckpt)
152        experiment.learn()
153        experiment.evaluate_test(make_val_env, timesteps=10_000, render=False)
154        experiment.delete_buffer_from_disk()
155        wandb.finish()