# Basic Gymnasium

 1from argparse import ArgumentParser
 2
 3import gymnasium as gym
 4import wandb
 5
 6import amago
 7from amago.envs import AMAGOEnv
 8from amago.cli_utils import *
 9
10
def add_cli(parser):
    """Register this example's command-line options on *parser*.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible) instance.

    Returns:
        The same parser, with ``--env``, ``--max_seq_len``, and
        ``--eval_timesteps`` options added.
    """
    add = parser.add_argument
    add("--env", type=str, required=True, help="Environment name for `gym.make`")
    add("--max_seq_len", type=int, default=128, help="Policy sequence length.")
    add(
        "--eval_timesteps",
        type=int,
        default=1000,
        help=(
            "Timesteps per actor per evaluation. Tune based on the episode "
            "length of the environment (to be at least one full episode)."
        ),
    )
    return parser
25
26
if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # Setup environment. gym.make IDs may contain "/" (e.g. "ALE/Breakout-v5"),
    # which is awkward in run/group names, so sanitize it first.
    env_name = args.env.replace("/", "_")

    # Fix: a lambda assigned to a name (PEP 8 E731) replaced with a def,
    # which also gets a real name in tracebacks and a docstring.
    def make_train_env():
        """Build a fresh AMAGO-wrapped instance of the requested environment."""
        return AMAGOEnv(gym.make(args.env), env_name=env_name)

    config = {
        # Dictionary that sets default values for kwargs of classes that are
        # marked as `gin.configurable` -- see `tutorial.md` for more
        # information. For example:
        "amago.nets.policy_dists.Discrete.clip_prob_high": 1.0,
        "amago.nets.policy_dists.Discrete.clip_prob_low": 1e-6,
    }
    # switch sequence model
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    # switch agent
    agent_type = switch_agent(
        config,
        args.agent_type,
        reward_multiplier=1.0,
    )
    use_config(config, args.configs)

    group_name = f"{args.run_name}_{env_name}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            # validation uses the same environment factory as training
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=args.max_seq_len * 8,
            run_name=run_name,
            tstep_encoder_type=amago.nets.tstep_encoders.FFTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=args.eval_timesteps,
        )
        experiment = switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        # NOTE(review): final test eval uses a fixed 10_000 timesteps rather
        # than --eval_timesteps -- presumably intentional (longer final eval);
        # confirm before changing.
        experiment.evaluate_test(make_train_env, timesteps=10_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()