|
19 | 19 | import numpy as np |
20 | 20 | import os |
21 | 21 | import tensor_comprehensions as tc |
| 22 | +import torch |
22 | 23 |
|
23 | 24 | from caffe2.python import core, dyndep, workspace, utils |
24 | 25 |
|
@@ -85,18 +86,22 @@ def main(): |
85 | 86 |
|
86 | 87 | @utils.debug |
87 | 88 | def tune(args): |
88 | | - fc = tc.define(FC_LANG, name="func_fc") |
89 | | - options = fc.autotune( |
90 | | - (args.batch_size, args.input_dim), |
91 | | - (args.output_dim, args.input_dim), |
92 | | - (args.output_dim,), |
93 | | - cache = args.tuner_cache_file, |
94 | | - threads = args.tuner_threads, |
95 | | - generations = args.tuner_gen_generations, |
96 | | - pop_size = args.tuner_gen_pop_size, |
97 | | - ) |
98 | | - print(options.toString()) |
99 | | - return options |
| 89 | + tuner_config = ( |
| 90 | + tc.TunerConfig() |
| 91 | + .generations(args.tuner_gen_generations) |
| 92 | + .devices(args.tuner_devices) |
| 93 | + .threads(args.tuner_threads) |
| 94 | + .pop_size(args.tuner_gen_pop_size)) |
| 95 | + return tc.autotune( |
| 96 | + FC_LANG, |
| 97 | + 'func_fc', |
| 98 | + torch.randn(args.batch_size, args.input_dim, device='cuda'), |
| 99 | + torch.randn(args.output_dim, args.input_dim, device='cuda'), |
| 100 | + torch.randn(args.output_dim, device='cuda'), |
| 101 | + starting_options = tc.MappingOptions('naive'), |
| 102 | + tuner_config = tuner_config, |
| 103 | + cache_filename = args.tuner_cache_file, |
| 104 | + store_to_cache = True) |
100 | 105 |
|
101 | 106 |
|
102 | 107 | @utils.debug |
|
0 commit comments