Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion traincheck/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def main():
trace_parent_folders = []
if args.traces is not None:
logger.info("Reading traces from %s", "\n".join(args.traces))
trace_parent_folders = [os.path.basename(os.path.commonpath(args.traces[0]))]
trace_parent_folders = [os.path.basename(os.path.commonpath(args.traces))]
traces.append(read_trace_file(args.traces))
if args.trace_folders is not None:
for trace_folder in args.trace_folders:
Expand Down
11 changes: 9 additions & 2 deletions traincheck/collect_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def main():
parser.add_argument(
"--model-tracker-style",
type=str,
choices=["sampler", "proxy"],
choices=["sampler", "proxy", "proxyparameter"],
default="proxy",
)
parser.add_argument(
Expand All @@ -371,6 +371,11 @@ def main():
action="store_true",
help="Disable automatic variable instrumentation, necessary when the default behavior of the instrumentor is not desired (e.g. cause segmentation fault)",
)
parser.add_argument(
"--use-torch-compile",
action="store_true",
help="Indicate wthether use torch.compile to speed the model, necessary to realize compatibility",
)

args = parser.parse_args()

Expand Down Expand Up @@ -444,7 +449,7 @@ def main():
scan_proxy_in_args = not args.disable_scan_proxy_in_args

# if no proxy tracking specified in the arguments, disable the scan_proxy_in_args
if not args.models_to_track or args.model_tracker_style != "proxy":
if not args.models_to_track or args.model_tracker_style == "sampler":
scan_proxy_in_args = False

if args.invariants:
Expand Down Expand Up @@ -481,6 +486,7 @@ def main():
output_dir=output_dir,
instr_descriptors=args.instr_descriptors,
no_auto_var_instr=args.no_auto_var_instr,
use_torch_compile=args.use_torch_compile,
)
else:
source_code = instrumentor.instrument_file(
Expand All @@ -496,6 +502,7 @@ def main():
output_dir=output_dir,
instr_descriptors=args.instr_descriptors,
no_auto_var_instr=args.no_auto_var_instr,
use_torch_compile=args.use_torch_compile,
)

if args.copy_all_files:
Expand Down
10 changes: 10 additions & 0 deletions traincheck/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def should_disable_proxy_dumping() -> bool:


INSTR_DESCRIPTORS = False
USE_TORCH_COMPILE = False

ALL_STAGE_NAMES = {
"init",
Expand All @@ -249,3 +250,12 @@ def should_disable_proxy_dumping() -> bool:
"preprocessing",
"postprocessing",
}

COMPILE_INTERNAL_MODULE = (
"torch.fx",
# "torch._dynamo",
"torch._inductor",
"torch._subclasses",
"torch._higher_order_ops",
"torch.utils._sympy",
)
85 changes: 76 additions & 9 deletions traincheck/instrumentor/dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@

# if torch.cuda.is_available():
from traincheck.proxy_wrapper.hash import tensor_hash
from traincheck.proxy_wrapper.proxy_basics import is_fake_tensor
from traincheck.proxy_wrapper.proxy_config import (
attribute_black_list,
primitive_types,
proxy_attribute,
tensor_dump_format,
)
from traincheck.utils import get_timestamp_ns, typename
from traincheck.utils import get_timestamp_ns, typename, typename_compile

DEBUG = os.environ.get("ML_DAIKON_DEBUG", False)
THREAD_DATA = threading.local()
Expand All @@ -44,12 +46,48 @@
logger = logging.getLogger(__name__)


def _json_default(o):
try:
if type(o).__name__ in ("SymInt", "SymFloat", "SymBool"):
return str(o)

if isinstance(o, torch.device):
return str(o)
if isinstance(o, torch.dtype):
return str(o)
if isinstance(o, torch.Size):
out = []
for d in o:
try:
out.append(int(d))
except Exception:
out.append(str(d))
return out
except Exception:
pass

if isinstance(o, set):
return list(o)
if isinstance(o, tuple):
return list(o)

try:
import numpy as np

if isinstance(o, (np.generic,)):
return o.item()
except Exception:
pass

return repr(o)


def serialize(obj_dict: dict[str, object | str]) -> str:
try:
return orjson.dumps(obj_dict).decode("utf-8")
return orjson.dumps(obj_dict, default=_json_default).decode("utf-8")
except Exception:
# if orjson fails (e.g. cannot handle ints larger than 64-bit), fallback to json
return json.dumps(obj_dict)
return json.dumps(obj_dict, default=_json_default)


def monitor_main_thread(main_thread, stop_event):
Expand Down Expand Up @@ -335,6 +373,9 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict
):
continue

if attr_name in proxy_attribute:
continue

if attr_name in attribute_black_list:
continue

Expand All @@ -346,12 +387,17 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict

attr = safe_getattr(var, attr_name)
if attr is NOT_FOUND:
logger.warning(
f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute."
)
if var_type not in skip_attrs_due_to_errs:
skip_attrs_due_to_errs[var_type] = set()
skip_attrs_due_to_errs[var_type].add(attr_name)
if not (
attr_name == "data"
and isinstance(var, torch.Tensor)
and not include_tensor_data
):
logger.warning(
f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute."
)
if var_type not in skip_attrs_due_to_errs:
skip_attrs_due_to_errs[var_type] = set()
skip_attrs_due_to_errs[var_type].add(attr_name)
continue

attr_name = str(attr_name)
Expand Down Expand Up @@ -395,7 +441,25 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict
return result


def convert_fake_tensor_to_dict(var):
try:
shape = tuple(var.shape)
except Exception:
shape = None
try:
dtype = str(var.dtype)
except Exception:
dtype = None
return {
"fake": True,
"shape": shape,
"dtype": dtype,
}


def obj_to_serializable(obj, dump_config=None) -> dict[str, object]:
if is_fake_tensor(obj):
return {typename_compile(obj): convert_fake_tensor_to_dict(obj)}
if (
type(obj) in skip_type_due_to_recursion
and skip_type_due_to_recursion[type(obj)] > RECURSION_ERR_THRESHOLD
Expand Down Expand Up @@ -429,6 +493,9 @@ def var_to_serializable(obj, dump_config=None) -> dict[str, object]:
If you want to dump the `data` attribute of a tensor, use `convert_var_to_dict` and set `include_tensor_data=True`.
"""

if is_fake_tensor(obj):
return {typename_compile(obj): convert_fake_tensor_to_dict(obj)}

if issubclass(type(obj), dict) and type(obj) != dict: # noqa E721
return obj_to_serializable(obj, dump_config=dump_config)

Expand Down
31 changes: 26 additions & 5 deletions traincheck/instrumentor/source_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def get_child_parent_map(root) -> dict[ast.AST, ast.AST]:


def instrument_all_model_assignments(
source_code: str, model_name: str, mode: str
source_code: str, model_name: str, mode: str | None
) -> str:
"""
Finds all assignment statements to `model` and inserts a Proxy statement or a VarSampler statement
Expand All @@ -292,8 +292,15 @@ def instrument_all_model_assignments(
instr_statement = ast.parse(
f"{model_name}_sampler = VarSampler({model_name}, var_name='{model_name}')"
)
elif mode == "proxyparameter":
instr_statement = ast.parse(
f"proxy_parameter({model_name}, logdir=proxy_config.proxy_log_dir, parent_name='{model_name}')"
)

else:
raise ValueError(f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler']")
raise ValueError(
f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler', 'proxyparameter']"
)

# find all assignment statements to `model`
assignments = []
Expand Down Expand Up @@ -348,6 +355,7 @@ def instrument_model_tracker_proxy(
models_to_track: list[str],
adjusted_proxy_config: list[dict[str, int | bool | str]],
no_auto_var_instr: bool,
model_tracker_style: str | None,
):
auto_observer_config: dict[str, int | bool | str] = adjusted_proxy_config[0]
proxy_basic_config: dict[str, int | bool | str] = adjusted_proxy_config[1]
Expand All @@ -373,8 +381,13 @@ def instrument_model_tracker_proxy(
tensor_dump_format.update({tensor_dump_format})
"""

proxy_start_code += """
if model_tracker_style == "proxy":
proxy_start_code += """
from traincheck.proxy_wrapper.proxy import Proxy
"""
else:
proxy_start_code += """
from traincheck.proxy_wrapper.subclass import proxy_parameter
"""

if auto_observer_config["enable_auto_observer"]:
Expand Down Expand Up @@ -435,7 +448,7 @@ def instrument_model_tracker_proxy(
if not no_auto_var_instr:
for model in models_to_track:
instrumented_source = instrument_all_model_assignments(
instrumented_source, model, "proxy"
instrumented_source, model, model_tracker_style
)

code_head, code_tail = get_code_head_and_tail(instrumented_source)
Expand Down Expand Up @@ -797,6 +810,7 @@ def instrument_file(
output_dir: str,
instr_descriptors: bool,
no_auto_var_instr: bool,
use_torch_compile: bool,
) -> str:
"""
Instruments the given file and returns the instrumented source code.
Expand Down Expand Up @@ -834,19 +848,26 @@ def instrument_file(
import traincheck.config.config as general_config
general_config.INSTR_DESCRIPTORS = {instr_descriptors}
"""
if use_torch_compile:
torch_compile_config_update = """
general_config.USE_TORCH_COMPILE = True
"""
general_config_update = general_config_update + torch_compile_config_update
# TODO: move the INSTR_DESCRIPTORS to the instr_opts file

if models_to_track:
assert model_tracker_style in [
"proxy",
"sampler",
"proxyparameter",
], f"Invalid model tracker style: {model_tracker_style}, must be one of ['proxy', 'sampler']"
if model_tracker_style == "proxy":
if model_tracker_style == "proxy" or model_tracker_style == "proxyparameter":
instrumented_source = instrument_model_tracker_proxy(
instrumented_source,
models_to_track,
adjusted_proxy_config,
no_auto_var_instr,
model_tracker_style,
)
else:
instrumented_source = instrument_model_tracker_sampler(
Expand Down
19 changes: 14 additions & 5 deletions traincheck/instrumentor/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
funcs_to_be_replaced,
is_funcs_to_be_unproxied,
)
from traincheck.proxy_wrapper.proxy_basics import is_proxied, unproxy_func
from traincheck.proxy_wrapper.proxy_basics import (
is_proxied,
is_proxyparameter,
unproxy_func,
)
from traincheck.proxy_wrapper.proxy_config import enable_C_level_observer
from traincheck.proxy_wrapper.proxy_registry import get_global_registry
from traincheck.utils import get_timestamp_ns, get_unique_id, typename
Expand Down Expand Up @@ -215,7 +219,7 @@ def global_wrapper(

def find_proxy_in_args(args):
for i, arg in enumerate(args):
if is_proxied(arg):
if is_proxied(arg) or is_proxyparameter(arg):
proxy_in_args.append(arg)
elif type(arg) in [list, tuple]:
find_proxy_in_args(arg)
Expand All @@ -234,9 +238,14 @@ def find_proxy_in_args(args):
if "proxy_obj_names" not in pre_record:
pre_record["proxy_obj_names"] = []
for proxy in proxy_in_args:
pre_record["proxy_obj_names"].append(
[proxy.__dict__["var_name"], type(proxy._obj).__name__]
)
if is_proxyparameter(proxy):
pre_record["proxy_obj_names"].append(
[proxy.__dict__["var_name"], "Parameter"]
)
else:
pre_record["proxy_obj_names"].append(
[proxy.__dict__["var_name"], type(proxy._obj).__name__]
)
if dump_args:
dict_args_kwargs = to_dict_args_kwargs(args, kwargs, dump_args_config)
pre_record["args"] = dict_args_kwargs["args"]
Expand Down
1 change: 0 additions & 1 deletion traincheck/invariant/precondition.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,6 @@ def find_precondition_from_single_group(
if len(example) == 0:
raise ValueError("Empty example found in positive examples")

# HACK: in ConsistencyRelation in order to avoid the field used in the invariant, we need to skip the field in the precondition. It is up to the caller to provide the keys to skip. We should try to refactor this to have a more generic solution.
earliest_time = example[0]["time"]
process_id = example[0]["process_id"]
thread_id = example[0]["thread_id"]
Expand Down
Loading
Loading