Skip to content

Commit 3dbdadc

Browse files
committed
Convert to decorator syntax
1 parent 3b7a534 commit 3dbdadc

File tree

1 file changed

+27
-21
lines changed

1 file changed

+27
-21
lines changed

switch_model/utilities/__init__.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77
from __future__ import print_function
88

9+
import functools
910
import os, types, sys, argparse, time, datetime, traceback, subprocess, platform
1011
import warnings
1112

@@ -951,16 +952,25 @@ def query_yes_no(question, default="yes"):
951952
else:
952953
sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n")
953954

954-
def catch_exceptions(func, *args, should_catch=True, warning_msg=None, **kwargs):
955-
"""Catches exceptions thrown by function."""
956-
if not should_catch:
957-
return func(*args, **kwargs)
958955

959-
try:
960-
return func(*args, **kwargs)
961-
except:
962-
if warning_msg is not None:
963-
warnings.warn(warning_msg)
956+
def catch_exceptions(warning_msg=None, should_catch=True):
957+
"""Decorator that catches exceptions."""
958+
959+
def decorator(func):
960+
@functools.wraps(func)
961+
def wrapper(*args, **kwargs):
962+
if not should_catch:
963+
return func(*args, **kwargs)
964+
965+
try:
966+
return func(*args, **kwargs)
967+
except:
968+
if warning_msg is not None:
969+
warnings.warn(warning_msg)
970+
971+
return wrapper
972+
973+
return decorator
964974

965975

966976
def run_command(command):
@@ -972,19 +982,15 @@ def run_command(command):
972982

973983

974984

975-
def get_git_branch(warning_msg="Failed to get Git Branch."):
976-
return catch_exceptions(
977-
run_command,
978-
"git rev-parse --abbrev-ref HEAD",
979-
warning_msg=warning_msg
980-
)
985+
@catch_exceptions("Failed to get Git Branch.")
986+
def get_git_branch():
987+
return run_command("git rev-parse --abbrev-ref HEAD")
988+
989+
990+
@catch_exceptions("Failed to get Git Commit Hash.")
991+
def get_git_commit():
992+
return run_command("git rev-parse HEAD")
981993

982-
def get_git_commit(warning_msg="Failed to get Git Commit Hash."):
983-
return catch_exceptions(
984-
run_command,
985-
"git rev-parse HEAD",
986-
warning_msg=warning_msg
987-
)
988994

989995
def add_git_info():
990996
commit_num = get_git_commit()

0 commit comments

Comments
 (0)