1from argparse import ArgumentParser
2
3import wandb
4
5from amago.envs import AMAGOEnv
6from amago.envs.builtin.alchemy import SymbolicAlchemy
7from amago import cli_utils
8
9
if __name__ == "__main__":
    # Script entry point: train and evaluate an AMAGO agent on DeepMind's
    # Symbolic Alchemy, repeating the whole run once per requested trial.
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    args = parser.parse_args()

    # Each `switch_*` helper receives the shared `config` dict and returns the
    # component type that is handed to the experiment below.
    # NOTE(review): presumably the helpers also record their keyword settings
    # in `config` -- confirm against `amago.cli_utils`.
    config = {}
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    exploration_wrapper_type = cli_utils.switch_exploration(
        config, "bilevel", rollout_horizon=200, steps_anneal=2_500_000
    )
    agent_type = cli_utils.switch_agent(
        config, args.agent_type, reward_multiplier=100.0
    )
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config, arch="ff", n_layers=2, d_hidden=256, d_output=256
    )

    cli_utils.use_config(config, args.configs)

    def make_train_env():
        """Return a new SymbolicAlchemy environment wrapped in AMAGOEnv."""
        # A named function instead of a lambda bound to a name (PEP 8): it
        # shows up under its own name in tracebacks and can carry a docstring.
        return AMAGOEnv(
            env=SymbolicAlchemy(),
            env_name="dm_symbolic_alchemy",
        )

    group_name = f"{args.run_name}_symbolic_dm_alchemy"
    for trial in range(args.trials):
        # Every trial shares `group_name` and gets its own suffixed run name.
        run_name = f"{group_name}_trial_{trial}"
        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            # The same factory is used for training and validation envs.
            make_val_env=make_train_env,
            # NOTE(review): 201 looks like rollout_horizon (200) + 1 -- confirm
            # before changing either value independently.
            max_seq_len=201,
            traj_save_len=201,
            group_name=group_name,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_wrapper_type,
            agent_type=agent_type,
            val_timesteps_per_epoch=2000,
        )
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        # A checkpoint is only loaded when one was given on the command line.
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
        # Per-trial cleanup before the next trial starts.
        experiment.delete_buffer_from_disk()
        wandb.finish()