File: optuna_profile.py

package info (click to toggle)
python-cmaes 0.12.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 544 kB
  • sloc: python: 5,136; sh: 88; makefile: 4
file content (41 lines) | stat: -rw-r--r-- 1,103 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import argparse
import cProfile
import logging
import pstats
import optuna

# Command-line options controlling the size and storage backend of the run.
# Help strings are provided so that `--help` explains what each knob does.
parser = argparse.ArgumentParser(
    description=(
        "Profile an Optuna study that uses CmaEsSampler on a quadratic objective."
    )
)
parser.add_argument(
    "--storage",
    choices=["memory", "sqlite"],
    default="memory",
    help=(
        "study storage backend: Optuna's default storage or a local SQLite "
        "file (default: %(default)s)"
    ),
)
parser.add_argument(
    "--params",
    type=int,
    default=100,
    help="number of parameters suggested in each trial (default: %(default)s)",
)
parser.add_argument(
    "--trials",
    type=int,
    default=1000,
    help="number of trials passed to study.optimize (default: %(default)s)",
)
# Parsed once at import time; `objective` and `main` read this module-level value.
args = parser.parse_args()


def objective(trial: optuna.Trial) -> float:
    """Return the sum of ``(x_i - 2) ** 2`` over ``args.params`` suggested values.

    Each parameter is named by its index (``"0"``, ``"1"``, ...) and is
    requested from the trial with bounds -4 and 4.

    Args:
        trial: The Optuna trial used to suggest every parameter value.

    Returns:
        The accumulated squared distance of the suggested values from 2.
    """
    val = 0.0
    for i in range(args.params):
        # suggest_float is the replacement for suggest_uniform, which has been
        # deprecated since Optuna 3.0 and warns on every call.
        xi = trial.suggest_float(str(i), -4, 4)
        val += (xi - 2) ** 2
    return val


def main():
    """Optimize ``objective`` under cProfile and print the top profile entries.

    The study uses ``CmaEsSampler``. When ``--storage sqlite`` is selected, the
    storage URL points at a SQLite file named after the trial and parameter
    counts; otherwise ``storage=None`` is passed to ``optuna.create_study``.
    Profile data is written to ``profile.stats`` and then read back to print
    the first five entries sorted by the ``"time"`` key.
    """
    # Suppress log records at INFO level and below for the whole process.
    logging.disable(level=logging.INFO)

    storage_url = (
        f"sqlite:///db-{args.trials}-{args.params}.sqlite3"
        if args.storage == "sqlite"
        else None
    )
    study = optuna.create_study(
        sampler=optuna.samplers.CmaEsSampler(),
        storage=storage_url,
    )

    # Only the optimize call itself is profiled; study setup is excluded.
    prof = cProfile.Profile()
    prof.runcall(
        study.optimize,
        objective,
        n_trials=args.trials,
        gc_after_trial=False,
    )
    prof.dump_stats("profile.stats")

    report = pstats.Stats("profile.stats")
    report.sort_stats("time")
    report.print_stats(5)


# Run the profiling benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()