diff --git a/traincheck/checker.py b/traincheck/checker.py index 62045468..dd8816f5 100644 --- a/traincheck/checker.py +++ b/traincheck/checker.py @@ -153,7 +153,7 @@ def main(): trace_parent_folders = [] if args.traces is not None: logger.info("Reading traces from %s", "\n".join(args.traces)) - trace_parent_folders = [os.path.basename(os.path.commonpath(args.traces[0]))] + trace_parent_folders = [os.path.basename(os.path.commonpath(args.traces))] traces.append(read_trace_file(args.traces)) if args.trace_folders is not None: for trace_folder in args.trace_folders: diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py index 48bcfe87..b128e085 100644 --- a/traincheck/collect_trace.py +++ b/traincheck/collect_trace.py @@ -350,7 +350,7 @@ def main(): parser.add_argument( "--model-tracker-style", type=str, - choices=["sampler", "proxy"], + choices=["sampler", "proxy", "proxyparameter"], default="proxy", ) parser.add_argument( @@ -371,6 +371,11 @@ def main(): action="store_true", help="Disable automatic variable instrumentation, necessary when the default behavior of the instrumentor is not desired (e.g. 
cause segmentation fault)", ) + parser.add_argument( + "--use-torch-compile", + action="store_true", + help="Indicate whether to use torch.compile to speed up the model, necessary to realize compatibility", + ) args = parser.parse_args() @@ -444,7 +449,7 @@ def main(): scan_proxy_in_args = not args.disable_scan_proxy_in_args # if no proxy tracking specified in the arguments, disable the scan_proxy_in_args - if not args.models_to_track or args.model_tracker_style != "proxy": + if not args.models_to_track or args.model_tracker_style == "sampler": scan_proxy_in_args = False if args.invariants: @@ -481,6 +486,7 @@ def main(): output_dir=output_dir, instr_descriptors=args.instr_descriptors, no_auto_var_instr=args.no_auto_var_instr, + use_torch_compile=args.use_torch_compile, ) else: source_code = instrumentor.instrument_file( @@ -496,6 +502,7 @@ def main(): output_dir=output_dir, instr_descriptors=args.instr_descriptors, no_auto_var_instr=args.no_auto_var_instr, + use_torch_compile=args.use_torch_compile, ) if args.copy_all_files: diff --git a/traincheck/config/config.py b/traincheck/config/config.py index 51c457af..55a6295d 100644 --- a/traincheck/config/config.py +++ b/traincheck/config/config.py @@ -238,6 +238,7 @@ def should_disable_proxy_dumping() -> bool: INSTR_DESCRIPTORS = False +USE_TORCH_COMPILE = False ALL_STAGE_NAMES = { "init", @@ -249,3 +250,12 @@ def should_disable_proxy_dumping() -> bool: "preprocessing", "postprocessing", } + +COMPILE_INTERNAL_MODULE = ( + "torch.fx", + # "torch._dynamo", + "torch._inductor", + "torch._subclasses", + "torch._higher_order_ops", + "torch.utils._sympy", +) diff --git a/traincheck/instrumentor/dumper.py b/traincheck/instrumentor/dumper.py index f6bf03fc..04935e8a 100644 --- a/traincheck/instrumentor/dumper.py +++ b/traincheck/instrumentor/dumper.py @@ -18,12 +18,14 @@ # if torch.cuda.is_available(): from traincheck.proxy_wrapper.hash import tensor_hash +from traincheck.proxy_wrapper.proxy_basics import is_fake_tensor from
traincheck.proxy_wrapper.proxy_config import ( attribute_black_list, primitive_types, + proxy_attribute, tensor_dump_format, ) -from traincheck.utils import get_timestamp_ns, typename +from traincheck.utils import get_timestamp_ns, typename, typename_compile DEBUG = os.environ.get("ML_DAIKON_DEBUG", False) THREAD_DATA = threading.local() @@ -44,12 +46,48 @@ logger = logging.getLogger(__name__) +def _json_default(o): + try: + if type(o).__name__ in ("SymInt", "SymFloat", "SymBool"): + return str(o) + + if isinstance(o, torch.device): + return str(o) + if isinstance(o, torch.dtype): + return str(o) + if isinstance(o, torch.Size): + out = [] + for d in o: + try: + out.append(int(d)) + except Exception: + out.append(str(d)) + return out + except Exception: + pass + + if isinstance(o, set): + return list(o) + if isinstance(o, tuple): + return list(o) + + try: + import numpy as np + + if isinstance(o, (np.generic,)): + return o.item() + except Exception: + pass + + return repr(o) + + def serialize(obj_dict: dict[str, object | str]) -> str: try: - return orjson.dumps(obj_dict).decode("utf-8") + return orjson.dumps(obj_dict, default=_json_default).decode("utf-8") except Exception: # if orjson fails (e.g. cannot handle ints larger than 64-bit), fallback to json - return json.dumps(obj_dict) + return json.dumps(obj_dict, default=_json_default) def monitor_main_thread(main_thread, stop_event): @@ -335,6 +373,9 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict ): continue + if attr_name in proxy_attribute: + continue + if attr_name in attribute_black_list: continue @@ -346,12 +387,17 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict attr = safe_getattr(var, attr_name) if attr is NOT_FOUND: - logger.warning( - f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute." 
- ) - if var_type not in skip_attrs_due_to_errs: - skip_attrs_due_to_errs[var_type] = set() - skip_attrs_due_to_errs[var_type].add(attr_name) + if not ( + attr_name == "data" + and isinstance(var, torch.Tensor) + and not include_tensor_data + ): + logger.warning( + f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute." + ) + if var_type not in skip_attrs_due_to_errs: + skip_attrs_due_to_errs[var_type] = set() + skip_attrs_due_to_errs[var_type].add(attr_name) continue attr_name = str(attr_name) @@ -395,7 +441,25 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict return result +def convert_fake_tensor_to_dict(var): + try: + shape = tuple(var.shape) + except Exception: + shape = None + try: + dtype = str(var.dtype) + except Exception: + dtype = None + return { + "fake": True, + "shape": shape, + "dtype": dtype, + } + + def obj_to_serializable(obj, dump_config=None) -> dict[str, object]: + if is_fake_tensor(obj): + return {typename_compile(obj): convert_fake_tensor_to_dict(obj)} if ( type(obj) in skip_type_due_to_recursion and skip_type_due_to_recursion[type(obj)] > RECURSION_ERR_THRESHOLD @@ -429,6 +493,9 @@ def var_to_serializable(obj, dump_config=None) -> dict[str, object]: If you want to dump the `data` attribute of a tensor, use `convert_var_to_dict` and set `include_tensor_data=True`. 
""" + if is_fake_tensor(obj): + return {typename_compile(obj): convert_fake_tensor_to_dict(obj)} + if issubclass(type(obj), dict) and type(obj) != dict: # noqa E721 return obj_to_serializable(obj, dump_config=dump_config) diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 4de57416..9f2bcb43 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -271,7 +271,7 @@ def get_child_parent_map(root) -> dict[ast.AST, ast.AST]: def instrument_all_model_assignments( - source_code: str, model_name: str, mode: str + source_code: str, model_name: str, mode: str | None ) -> str: """ Finds all assignment statements to `model` and inserts a Proxy statement or a VarSampler statement @@ -292,8 +292,15 @@ def instrument_all_model_assignments( instr_statement = ast.parse( f"{model_name}_sampler = VarSampler({model_name}, var_name='{model_name}')" ) + elif mode == "proxyparameter": + instr_statement = ast.parse( + f"proxy_parameter({model_name}, logdir=proxy_config.proxy_log_dir, parent_name='{model_name}')" + ) + else: - raise ValueError(f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler']") + raise ValueError( + f"Invalid mode: {mode}. 
Must be one of ['proxy', 'sampler', 'proxyparameter']" + ) # find all assignment statements to `model` assignments = [] @@ -348,6 +355,7 @@ def instrument_model_tracker_proxy( models_to_track: list[str], adjusted_proxy_config: list[dict[str, int | bool | str]], no_auto_var_instr: bool, + model_tracker_style: str | None, ): auto_observer_config: dict[str, int | bool | str] = adjusted_proxy_config[0] proxy_basic_config: dict[str, int | bool | str] = adjusted_proxy_config[1] @@ -373,8 +381,13 @@ def instrument_model_tracker_proxy( tensor_dump_format.update({tensor_dump_format}) """ - proxy_start_code += """ + if model_tracker_style == "proxy": + proxy_start_code += """ from traincheck.proxy_wrapper.proxy import Proxy +""" + else: + proxy_start_code += """ +from traincheck.proxy_wrapper.subclass import proxy_parameter """ if auto_observer_config["enable_auto_observer"]: @@ -435,7 +448,7 @@ def instrument_model_tracker_proxy( if not no_auto_var_instr: for model in models_to_track: instrumented_source = instrument_all_model_assignments( - instrumented_source, model, "proxy" + instrumented_source, model, model_tracker_style ) code_head, code_tail = get_code_head_and_tail(instrumented_source) @@ -797,6 +810,7 @@ def instrument_file( output_dir: str, instr_descriptors: bool, no_auto_var_instr: bool, + use_torch_compile: bool, ) -> str: """ Instruments the given file and returns the instrumented source code. 
@@ -834,19 +848,26 @@ def instrument_file( import traincheck.config.config as general_config general_config.INSTR_DESCRIPTORS = {instr_descriptors} """ + if use_torch_compile: + torch_compile_config_update = """ +general_config.USE_TORCH_COMPILE = True +""" + general_config_update = general_config_update + torch_compile_config_update # TODO: move the INSTR_DESCRIPTORS to the instr_opts file if models_to_track: assert model_tracker_style in [ "proxy", "sampler", + "proxyparameter", ], f"Invalid model tracker style: {model_tracker_style}, must be one of ['proxy', 'sampler']" - if model_tracker_style == "proxy": + if model_tracker_style == "proxy" or model_tracker_style == "proxyparameter": instrumented_source = instrument_model_tracker_proxy( instrumented_source, models_to_track, adjusted_proxy_config, no_auto_var_instr, + model_tracker_style, ) else: instrumented_source = instrument_model_tracker_sampler( diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py index cf28785a..a4812b7d 100644 --- a/traincheck/instrumentor/tracer.py +++ b/traincheck/instrumentor/tracer.py @@ -29,7 +29,11 @@ funcs_to_be_replaced, is_funcs_to_be_unproxied, ) -from traincheck.proxy_wrapper.proxy_basics import is_proxied, unproxy_func +from traincheck.proxy_wrapper.proxy_basics import ( + is_proxied, + is_proxyparameter, + unproxy_func, +) from traincheck.proxy_wrapper.proxy_config import enable_C_level_observer from traincheck.proxy_wrapper.proxy_registry import get_global_registry from traincheck.utils import get_timestamp_ns, get_unique_id, typename @@ -215,7 +219,7 @@ def global_wrapper( def find_proxy_in_args(args): for i, arg in enumerate(args): - if is_proxied(arg): + if is_proxied(arg) or is_proxyparameter(arg): proxy_in_args.append(arg) elif type(arg) in [list, tuple]: find_proxy_in_args(arg) @@ -234,9 +238,14 @@ def find_proxy_in_args(args): if "proxy_obj_names" not in pre_record: pre_record["proxy_obj_names"] = [] for proxy in proxy_in_args: - 
pre_record["proxy_obj_names"].append( - [proxy.__dict__["var_name"], type(proxy._obj).__name__] - ) + if is_proxyparameter(proxy): + pre_record["proxy_obj_names"].append( + [proxy.__dict__["var_name"], "Parameter"] + ) + else: + pre_record["proxy_obj_names"].append( + [proxy.__dict__["var_name"], type(proxy._obj).__name__] + ) if dump_args: dict_args_kwargs = to_dict_args_kwargs(args, kwargs, dump_args_config) pre_record["args"] = dict_args_kwargs["args"] diff --git a/traincheck/invariant/precondition.py b/traincheck/invariant/precondition.py index ad507dcb..b2040d83 100644 --- a/traincheck/invariant/precondition.py +++ b/traincheck/invariant/precondition.py @@ -537,7 +537,6 @@ def find_precondition_from_single_group( if len(example) == 0: raise ValueError("Empty example found in positive examples") - # HACK: in ConsistencyRelation in order to avoid the field used in the invariant, we need to skip the field in the precondition. It is up to the caller to provide the keys to skip. We should try to refactor this to have a more generic solution. 
earliest_time = example[0]["time"] process_id = example[0]["process_id"] thread_id = example[0]["thread_id"] diff --git a/traincheck/proxy_wrapper/proxy_basics.py b/traincheck/proxy_wrapper/proxy_basics.py index 11f8162b..dd3014bb 100644 --- a/traincheck/proxy_wrapper/proxy_basics.py +++ b/traincheck/proxy_wrapper/proxy_basics.py @@ -4,9 +4,52 @@ import astor +import traincheck.config.config as config + + +def is_compile_internal_module(obj): + mod = getattr(type(obj), "__module__", "") or "" + if any(mod.startswith(p) for p in config.COMPILE_INTERNAL_MODULE): + return True + name = type(obj).__name__ + if mod.startswith("torch._dynamo") and name != "OptimizedModule": + return True + return False + + +def is_fake_tensor(x) -> bool: + if not config.USE_TORCH_COMPILE: + return False + try: + from torch._subclasses.fake_tensor import FakeTensor + from torch.fx import Proxy as FxProxy + + if isinstance(x, FakeTensor): + return True + if isinstance(x, FxProxy): + return True + except Exception: + pass + + try: + if is_compile_internal_module(x): + return True + except Exception: + return True + + try: + if x.device.type == "meta": + return True + except Exception: + return True + + return False + def is_proxied(obj): try: + if is_fake_tensor(obj): + return False if obj is not None and "is_traincheck_proxied_obj" in obj.__dict__: return True except Exception: @@ -14,6 +57,17 @@ def is_proxied(obj): return False +def is_proxyparameter(obj): + try: + if is_fake_tensor(obj): + return False + if obj is not None and "is_traincheck_proxyparameter" in obj.__dict__: + return True + except Exception: + return False + return False + + def unproxy_arg(arg, inspect_torch_module=False): if is_proxied(arg): diff --git a/traincheck/proxy_wrapper/proxy_config.py b/traincheck/proxy_wrapper/proxy_config.py index 57c2d4d1..66ce6d7c 100644 --- a/traincheck/proxy_wrapper/proxy_config.py +++ b/traincheck/proxy_wrapper/proxy_config.py @@ -49,3 +49,14 @@ "real", ] attribute_black_list = 
tensor_attribute_black_list +# TODO +proxy_attribute = [ + "process_id", + "thread_id", + "logdir", + "log_level", + "loglevel", + "is_traincheck_proxyparameter", + "var_name", + "last_update_timestamp", +] diff --git a/traincheck/proxy_wrapper/proxy_observer.py b/traincheck/proxy_wrapper/proxy_observer.py index 06afcc2b..5316fed6 100644 --- a/traincheck/proxy_wrapper/proxy_observer.py +++ b/traincheck/proxy_wrapper/proxy_observer.py @@ -2,18 +2,22 @@ import typing from traincheck.config.config import should_disable_proxy_dumping +from traincheck.proxy_wrapper.subclass import ProxyParameter from traincheck.utils import typename if typing.TYPE_CHECKING: from traincheck.proxy_wrapper.proxy import Proxy -from .proxy_basics import is_proxied, unproxy_func + from traincheck.proxy_wrapper.subclass import ProxyParameter + +from .proxy_basics import is_proxied, is_proxyparameter, unproxy_func def observe_proxy_var( - var: "Proxy", + var: typing.Union["Proxy", "ProxyParameter"], phase, observe_api_name: str, ): + # update the proxy object's timestamp var.update_timestamp() @@ -37,9 +41,9 @@ def wrapper(*args, **kwargs): # if the arg is list or tuple, check if it contains proxied object if type(arg) in [list, tuple]: for element in arg: - if is_proxied(element): + if is_proxied(element) or is_proxyparameter(element): proxied_vars.append(element) - if is_proxied(arg): + if is_proxied(arg) or is_proxyparameter(arg): proxied_vars.append(arg) # pre observe diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py new file mode 100644 index 00000000..f335ba50 --- /dev/null +++ b/traincheck/proxy_wrapper/subclass.py @@ -0,0 +1,235 @@ +import logging +import os +import threading + +import torch +from torch import nn + +from traincheck.instrumentor.dumper import dump_trace_VAR +from traincheck.instrumentor.tracer import TraceLineType +from traincheck.proxy_wrapper.dumper import dump_attributes, get_meta_vars +from traincheck.utils import 
get_timestamp_ns + +from .proxy_basics import is_fake_tensor + +# from .proxy_registry import get_global_registry +# from .utils import print_debug + + +def in_dynamo() -> bool: + try: + import torch._dynamo as dynamo + + return bool(dynamo.is_compiling()) + except Exception: + return False + + +class ProxyParameter(torch.nn.Parameter): + loglevel = logging.INFO + + def __new__( + cls, + data, + logdir="proxy_log.log", + log_level=logging.INFO, + # TODO + # recurse=False, + var_name="", + should_dump_trace=True, + from_call=False, + from_iter=False, + # TODO + # from_copy=False, + ): + if isinstance(data, ProxyParameter): + return data + + if in_dynamo() or is_fake_tensor(data): + # we do not proxy the parameter if we are in dynamo or the tensor is a fake tensor + if isinstance(data, nn.Parameter): + return data + return nn.Parameter(data, requires_grad=data.requires_grad) + + requires_grad = getattr(data, "requires_grad", False) + tensor_grad = getattr(data, "grad", None) + + # When wrapping an existing Parameter we need to preserve any Python level + # attributes (e.g. hooks, user defined flags, ``grad``) so that the proxy + # behaves identically to the original parameter. ``Parameter.__new__`` + # returns a fresh instance, so we snapshot the metadata from ``data`` and + # replay it on the new ProxyParameter via the base Tensor ``__setattr__`` + # to avoid triggering the logging logic implemented in this class. + snapshot: dict = {} + + if isinstance(data, nn.Parameter): + snapshot = dict(getattr(data, "__dict__", {})) + base_tensor = data.detach() + elif isinstance(data, torch.Tensor): + base_tensor = data.detach() + else: + base_tensor = torch.as_tensor(data) + + proxied = super().__new__(cls, base_tensor, requires_grad=requires_grad) + + if snapshot: + tensor_setattr = torch.Tensor.__setattr__ + for name, value in snapshot.items(): + if name == "grad": + continue + try: + tensor_setattr(proxied, name, value) + except AttributeError: + # Some slots (e.g. 
torch internals) are read-only; skip them. + continue + + if tensor_grad is not None: + torch.Tensor.__setattr__(proxied, "grad", tensor_grad) + + return proxied + + def __init__( + self, + data, + logdir="proxy_log.log", + log_level=logging.INFO, + # TODO + # recurse=False, + var_name="", + should_dump_trace=True, + from_call=False, + from_iter=False, + # TODO + # from_copy=False, + ): + super().__init__() + # Access proxy attribute: since we are wrapping the getattr method, we need to access the attribute directly + self.__dict__["process_id"] = os.getpid() + self.__dict__["thread_id"] = threading.current_thread().ident + self.__dict__["logdir"] = logdir + self.__dict__["log_level"] = log_level + # TODO + # self.__dict__["meta_vars"] = {} + # self.__dict__["is_traincheck_proxied_obj"] = True + self.__dict__["is_traincheck_proxyparameter"] = True + # TODO + # self.__dict__["recurse"] = recurse + self.__dict__["var_name"] = var_name + # TODO + # self.__dict__["old_value"] = None + # self.__dict__["old_meta_vars"] = None + + current_time = get_timestamp_ns() + + self.__dict__["last_update_timestamp"] = current_time + + # print(f"init: {self.var_name}") + if should_dump_trace: + if from_call: + phase = "call" + + elif from_iter: + phase = "iter" + # if the object is generated from getattr, then do not dump it + else: + phase = "update" + self.dump_trace(phase=phase, dump_loc="initing") + + def __setattr__(self, name, value): + # print(f"paremeter: {self.var_name}, name = {name}, value = {value}") + super().__setattr__(name, value) + self.update_timestamp() + self.dump_trace( + phase="update", + dump_loc=f"__setattr__ (attribute '{name}')", + ) + + def __deepcopy__(self, memo): + data = self.data + if in_dynamo() or is_fake_tensor(self): + return self + return type(self)( + data.clone(memory_format=torch.preserve_format).requires_grad_(self.requires_grad), + var_name=self.var_name, + ) + + def update_timestamp(self): + # Update the timestamp of the object, should be called when the object is updated,
e.g. __setattr__ and observer + current_time = get_timestamp_ns() + self.__dict__["last_update_timestamp"] = current_time + # TODO: + # Proxy.var_dict[self.__dict__["var_name"]].last_update_timestamp = current_time + + def register_object(self): + # get_global_registry().add_var(self, self.__dict__["var_name"]) + # TODO: implement the registry, we will need to make sure the registered timestamp is updated and is consistent with the timestamp in the object + pass + + def dump_trace(self, phase, dump_loc): + # print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}") + # TODO + var_name = self.__dict__["var_name"] + # assert var_name is not None # '' is allowed as a var_name (root object) + # filter_by_tensor_version = proxy_config.dump_info_config[ + # "filter_by_tensor_version" + # ] + # if filter_by_tensor_version and phase == "update": + # if hasattr(obj, "_version"): + # if obj._version == Proxy.var_dict[self.__dict__["var_name"]].version: + # return + + last_update_timestamp = self.__dict__["last_update_timestamp"] + + # TODO + # if not isinstance(obj, torch.nn.Module): + dump_trace_VAR( + { + "process_id": self.process_id, + "thread_id": self.thread_id, + "time": last_update_timestamp, + "meta_vars": get_meta_vars(self), + "var_name": var_name, + "var_type": "torch.nn.Parameter", + "mode": phase, + "dump_loc": dump_loc, + "attributes": dump_attributes(self, self), + "type": TraceLineType.STATE_CHANGE, + } + ) + + +def proxy_parameter( + module: nn.Module, + logdir="proxy_log.log", + log_level=logging.INFO, + # TODO + # recurse=False, + parent_name="", + should_dump_trace=True, + from_call=False, + from_iter=False, + # TODO + # from_copy=False, +): + if in_dynamo(): + return + for name, t in list(module.named_parameters(recurse=False)): + module._parameters[name] = ProxyParameter( + t, + logdir, + log_level, + parent_name + "."
+ name, + should_dump_trace, + from_call, + from_iter, + ) + for name, child in module.named_children(): + proxy_parameter( + child, + logdir, + log_level, + parent_name + "." + name, + should_dump_trace, + from_call, + from_iter, + ) diff --git a/traincheck/utils.py b/traincheck/utils.py index 9e332094..944fd989 100644 --- a/traincheck/utils.py +++ b/traincheck/utils.py @@ -35,6 +35,14 @@ def safe_getattr(obj, attr, default=None): raise +def typename_compile(o): + try: + mod = getattr(type(o), "__module__", "") or "" + return f"{mod}.{type(o).__name__}" + except Exception: + return "compile_stage" + + def typename(o, is_runtime=False): if isinstance(o, torch.nn.Parameter): return "torch.nn.Parameter"