|
4 | 4 | """Script to handle switch <cmd> calls from the command line.""" |
5 | 5 | from __future__ import print_function |
6 | 6 |
|
7 | | -import sys, os |
| 7 | +import argparse |
| 8 | +import importlib |
| 9 | +import sys |
8 | 10 | import switch_model |
| 11 | +from switch_model.utilities import get_git_branch |
| 12 | + |
| 13 | +def version(): |
| 14 | + print("Switch model version " + switch_model.__version__) |
| 15 | + branch = get_git_branch() |
| 16 | + if branch is not None: |
| 17 | + print(f"Switch Git branch: {branch}") |
| 18 | + return 0 |
| 19 | + |
| 20 | +def help_text(): |
| 21 | + print( |
| 22 | + f"Must specifiy one of the following commands: {list(cmds.keys())}.\nE.g. Run 'switch solve' or 'switch get_inputs'.") |
| 23 | + |
| 24 | + |
| 25 | +def get_module_runner(module): |
| 26 | + def runner(): |
| 27 | + importlib.import_module(module).main() |
| 28 | + return runner |
| 29 | + |
| 30 | + |
| 31 | +cmds = { |
| 32 | + "solve": get_module_runner("switch_model.solve"), |
| 33 | + "solve-scenarios": get_module_runner("switch_model.solve_scenarios"), |
| 34 | + "test": get_module_runner("switch_model.test"), |
| 35 | + "upgrade": get_module_runner("switch_model.upgrade"), |
| 36 | + "get_inputs": get_module_runner("switch_model.wecc.get_inputs"), |
| 37 | + "drop": get_module_runner("switch_model.tools.drop"), |
| 38 | + "new": get_module_runner("switch_model.tools.new"), |
| 39 | + "graph": get_module_runner("switch_model.tools.graph.cli_graph"), |
| 40 | + "compare": get_module_runner("switch_model.tools.graph.cli_compare"), |
| 41 | + "db": get_module_runner("switch_model.wecc.__main__"), |
| 42 | + "help": help_text |
| 43 | +} |
9 | 44 |
|
10 | 45 |
|
11 | 46 | def main(): |
12 | | - # TODO make a proper command line tool with help information for each option |
13 | | - cmds = [ |
14 | | - "solve", |
15 | | - "solve-scenarios", |
16 | | - "test", |
17 | | - "upgrade", |
18 | | - "get_inputs", |
19 | | - "--version", |
20 | | - "drop", |
21 | | - "new", |
22 | | - "graph", |
23 | | - "compare", |
24 | | - "sampling", |
25 | | - ] |
26 | | - if len(sys.argv) >= 2 and sys.argv[1] in cmds: |
27 | | - # If users run a script from the command line, the location of the script |
28 | | - # gets added to the start of sys.path; if they call a module from the |
29 | | - # command line then an empty entry gets added to the start of the path, |
30 | | - # indicating the current working directory. This module is often called |
31 | | - # from a command-line script, but we want the current working |
32 | | - # directory in the path because users may try to load local modules via |
33 | | - # the configuration files, so we make sure that's always in the path. |
34 | | - sys.path[0] = "" |
35 | | - |
36 | | - # adjust the argument list to make it look like someone ran "python -m <module>" directly |
37 | | - cmd = sys.argv[1] |
38 | | - sys.argv[0] += " " + cmd |
| 47 | + parser = argparse.ArgumentParser(add_help=False) |
| 48 | + parser.add_argument("--version", default=False, action="store_true", help="Get version info") |
| 49 | + parser.add_argument("subcommand", choices=cmds.keys(), help="The possible switch subcommands", nargs="?", |
| 50 | + default="help") |
| 51 | + |
| 52 | + # If users run a script from the command line, the location of the script |
| 53 | + # gets added to the start of sys.path; if they call a module from the |
| 54 | + # command line then an empty entry gets added to the start of the path, |
| 55 | + # indicating the current working directory. This module is often called |
| 56 | + # from a command-line script, but we want the current working |
| 57 | + # directory in the path because users may try to load local modules via |
| 58 | + # the configuration files, so we make sure that's always in the path. |
| 59 | + sys.path[0] = "" |
| 60 | + |
| 61 | + args, remaining_args = parser.parse_known_args() |
| 62 | + |
| 63 | + if args.version: |
| 64 | + return version() |
| 65 | + |
| 66 | + # adjust the argument list to make it look like someone ran "python -m <module>" directly |
| 67 | + if len(sys.argv) > 1: |
| 68 | + sys.argv[0] += " " + sys.argv[1] |
39 | 69 | del sys.argv[1] |
40 | | - if cmd == "--version": |
41 | | - print("Switch model version " + switch_model.__version__) |
42 | | - from switch_model.utilities import get_git_branch |
43 | | - branch = get_git_branch() |
44 | | - if branch is not None: |
45 | | - print(f"Switch Git branch: {branch}") |
46 | | - return 0 |
47 | | - if cmd == "solve": |
48 | | - from switch_model.solve import main |
49 | | - elif cmd == "solve-scenarios": |
50 | | - from switch_model.solve_scenarios import main |
51 | | - elif cmd == "test": |
52 | | - from switch_model.test import main |
53 | | - elif cmd == "upgrade": |
54 | | - from switch_model.upgrade import main |
55 | | - elif cmd == "get_inputs": |
56 | | - from switch_model.wecc.get_inputs import main |
57 | | - elif cmd == "sampling": |
58 | | - from switch_model.wecc.sampling import main |
59 | | - elif cmd == "drop": |
60 | | - from switch_model.tools.drop import main |
61 | | - elif cmd == "new": |
62 | | - from switch_model.tools.new import main |
63 | | - elif cmd == "graph": |
64 | | - from switch_model.tools.graph.cli_graph import main |
65 | | - elif cmd == "compare": |
66 | | - from switch_model.tools.graph.cli_compare import main |
67 | | - main() |
68 | | - else: |
69 | | - print( |
70 | | - "Usage: {} {{{}}} ...".format( |
71 | | - os.path.basename(sys.argv[0]), ", ".join(cmds) |
72 | | - ) |
73 | | - ) |
74 | | - print("Use one of these commands with --help for more information.") |
| 70 | + cmds[args.subcommand]() |
75 | 71 |
|
76 | 72 |
|
77 | 73 | if __name__ == "__main__": |
|
0 commit comments