Skip to content

Commit c001440

Browse files
committed
Convert to decorator syntax
1 parent 2c8d03e commit c001440

File tree

1 file changed

+27
-21
lines changed

1 file changed

+27
-21
lines changed

switch_model/utilities/__init__.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77
from __future__ import print_function
88

9+
import functools
910
import os, types, sys, argparse, time, datetime, traceback, subprocess, platform
1011
import warnings
1112

@@ -882,35 +883,40 @@ def query_yes_no(question, default="yes"):
882883
sys.stdout.write("Please respond with 'yes' or 'no' "
883884
"(or 'y' or 'n').\n")
884885

885-
def catch_exceptions(func, *args, should_catch=True, warning_msg=None, **kwargs):
886-
"""Catches exceptions thrown by function."""
887-
if not should_catch:
888-
return func(*args, **kwargs)
889886

890-
try:
891-
return func(*args, **kwargs)
892-
except:
893-
if warning_msg is not None:
894-
warnings.warn(warning_msg)
887+
def catch_exceptions(warning_msg=None, should_catch=True):
888+
"""Decorator that catches exceptions."""
889+
890+
def decorator(func):
891+
@functools.wraps(func)
892+
def wrapper(*args, **kwargs):
893+
if not should_catch:
894+
return func(*args, **kwargs)
895+
896+
try:
897+
return func(*args, **kwargs)
898+
except:
899+
if warning_msg is not None:
900+
warnings.warn(warning_msg)
901+
902+
return wrapper
903+
904+
return decorator
895905

896906

897907
def run_command(command):
898908
return subprocess.check_output(command.split(" "), cwd=os.path.dirname(__file__)).strip().decode("UTF-8")
899909

900910

901-
def get_git_branch(warning_msg="Failed to get Git Branch."):
902-
return catch_exceptions(
903-
run_command,
904-
"git rev-parse --abbrev-ref HEAD",
905-
warning_msg=warning_msg
906-
)
911+
@catch_exceptions("Failed to get Git Branch.")
912+
def get_git_branch():
913+
return run_command("git rev-parse --abbrev-ref HEAD")
914+
915+
916+
@catch_exceptions("Failed to get Git Commit Hash.")
917+
def get_git_commit():
918+
return run_command("git rev-parse HEAD")
907919

908-
def get_git_commit(warning_msg="Failed to get Git Commit Hash."):
909-
return catch_exceptions(
910-
run_command,
911-
"git rev-parse HEAD",
912-
warning_msg=warning_msg
913-
)
914920

915921
def add_git_info():
916922
commit_num = get_git_commit()

0 commit comments

Comments
 (0)