HalfCheetah(v4)-Velocity

from argparse import ArgumentParser
import math
import random

import wandb

import amago
from amago.envs import AMAGOEnv
from amago.envs.builtin.half_cheetah_v4_vel import HalfCheetahV4_MetaVelocity
from amago.cli_utils import *


def add_cli(parser):
    parser.add_argument(
        "--policy_seq_len", type=int, default=32, help="Policy sequence length."
    )
    parser.add_argument(
        "--eval_episodes_per_actor",
        type=int,
        default=1,
        help="Validation episodes per parallel actor.",
    )
    parser.add_argument(
        "--task_min_velocity",
        type=float,
        default=0.0,
        help="Min running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 0.",
    )
    parser.add_argument(
        "--task_max_velocity",
        type=float,
        default=3.0,
        help="Max running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 3. Agents in the default locomotion env (no reward randomization) reach > 10.",
    )
    return parser

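# Example invocation (hypothetical filename; `add_common_cli` supplies the remaining
# flags used below, such as the run name, trial count, and sequence-model settings):
#   python half_cheetah_v4_vel_example.py \
#       --policy_seq_len 32 --eval_episodes_per_actor 2 \
#       --task_min_velocity 0.0 --task_max_velocity 3.0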

"""
Because this task is so similar to the other gymnasium examples, this script is
deliberately verbose about showing how you could customize the environment and
create a train/test split.

If you don't edit anything, it is just a longer way to train/test on the default
task distribution, which samples a velocity uniformly from
[args.task_min_velocity, args.task_max_velocity].
"""


class MyCustomHalfCheetahTrain(HalfCheetahV4_MetaVelocity):

    def sample_target_velocity(self) -> float:
        # be sure to use `random` (or handle np's default_rng carefully) so that
        # tasks are different across async parallel actors!
        vel = super().sample_target_velocity()  # random.uniform(min_vel, max_vel)
        return vel


class MyCustomHalfCheetahEval(HalfCheetahV4_MetaVelocity):

    def sample_target_velocity(self) -> float:
        vel = super().sample_target_velocity()
        # or, to create OOD eval tasks:
        # vel = random.uniform(self.task_min_velocity, self.task_max_velocity * 10.0)
        # or random.choice([0., 1., self.task_max_velocity * 1.2]), etc.
        return vel

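
# A more concrete out-of-distribution split, as a sketch (hypothetical class, not
# used below; it assumes the parent class keeps `task_min_velocity` /
# `task_max_velocity` as attributes, as the commented lines above suggest):
class MyOODHalfCheetahEval(HalfCheetahV4_MetaVelocity):

    def sample_target_velocity(self) -> float:
        # hold out velocities above the training range for evaluation
        return random.uniform(self.task_max_velocity, self.task_max_velocity * 2.0)
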

class AMAGOEnvWithVelocityName(AMAGOEnv):
    """
    Every eval metric gets logged under the current `env_name`. You could use
    this to log metrics for different tasks separately. Metrics are averaged
    over all evals that share a name, so you want a small discrete set of names
    that will each get a sample size > 1.
    """

    @property
    def env_name(self) -> str:
        current_task_vel = self.env.unwrapped.target_velocity
        # the velocity needs to be discretized somehow; this is just one example
        low, high = math.floor(current_task_vel), math.ceil(current_task_vel)
        return f"HalfCheetahVelocity-Vel-[{low}, {high}]"

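
# A coarser alternative, as a sketch (hypothetical subclass, not used below):
# round the task velocity to the nearest 0.5 so fewer distinct names are logged.
# Any discretization works as long as each name is evaluated more than once.
class AMAGOEnvWithCoarseVelocityName(AMAGOEnvWithVelocityName):

    @property
    def env_name(self) -> str:
        vel_bucket = round(self.env.unwrapped.target_velocity * 2) / 2
        return f"HalfCheetahVelocity-Vel-{vel_bucket:.1f}"
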

if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # set up the environments
    make_train_env = lambda: AMAGOEnvWithVelocityName(
        MyCustomHalfCheetahTrain(
            task_min_velocity=args.task_min_velocity,
            task_max_velocity=args.task_max_velocity,
        ),
        # the env_name is totally arbitrary and only impacts logging / data filenames
        env_name="HalfCheetahV4Velocity",
    )

    make_val_env = lambda: AMAGOEnvWithVelocityName(
        MyCustomHalfCheetahEval(
            task_min_velocity=args.task_min_velocity,
            task_max_velocity=args.task_max_velocity,
        ),
        # this name gets overridden by the env_name property defined above
        env_name="HalfCheetahV4VelocityEval",
    )

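    # Optional sanity check (commented out; assumes the wrapper forwards the standard
    # gymnasium reset() and that a task velocity is sampled on reset):
    # env = make_val_env()
    # env.reset()
    # print(env.env_name)  # e.g. "HalfCheetahVelocity-Vel-[1, 2]"
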
    config = {}
    # switch sequence model
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    # switch agent
    agent_type = switch_agent(
        config,
        args.agent_type,
        reward_multiplier=1.0,  # gym locomotion returns are already large
        gamma=0.99,  # locomotion policies don't need long horizons, so fall back to the default
        tau=0.005,
    )
    # "egreedy" exploration in continuous control is just the epsilon-scheduled random
    # (normal) noise from most TD3/DDPG implementations.
    exploration_type = switch_exploration(config, "egreedy", steps_anneal=500_000)
    use_config(config, args.configs)
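    # Note: `config` holds the keyword overrides filled in by the switch_* helpers
    # above; `use_config` applies them, along with any extra config files named in
    # args.configs, before the experiment is created.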

    group_name = args.run_name
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,  # different train/val envs
            make_val_env=make_val_env,
            max_seq_len=args.policy_seq_len,
            traj_save_len=args.policy_seq_len * 6,
            run_name=run_name,
            tstep_encoder_type=amago.nets.tstep_encoders.FFTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_type,
            agent_type=agent_type,
            group_name=group_name,
            # HalfCheetah episodes are 1000 steps long
            val_timesteps_per_epoch=args.eval_episodes_per_actor * 1001,
            grad_clip=2.0,
            learning_rate=3e-4,
        )
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        # final test rollouts on the eval env (~10 episodes per parallel actor)
        experiment.evaluate_test(make_val_env, timesteps=10_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()