From e4c6af966db403c50c053eca681ff518adecd0cd Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Sat, 1 Nov 2025 21:54:04 +0100 Subject: [PATCH 01/17] fix(huggingface_bridge) correct per split constant detection ; feat(sample) add string support to globals --- src/plaid/bridges/huggingface_bridge.py | 49 ++++++++++++++++++++++--- src/plaid/containers/features.py | 5 +++ 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index dd1e71e6..0b5fa266 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -98,7 +98,9 @@ def infer_hf_features_from_value(value: Any) -> Union[Value, Sequence]: dtype, np.int64 ): # very important to satisfy the CGNS standard return Value("int64") - elif np.issubdtype(dtype, np.dtype("|S1")): # pragma: no cover + elif np.issubdtype(dtype, np.dtype("|S1")) or np.issubdtype( + dtype, np.dtype(" 0, f"split {split_name} has no sample" + + dataset_size_bytes = ds.data.nbytes + target_shard_size_bytes = target_shard_size_mb * 1024 * 1024 + + n_shards = max( + 1, + (dataset_size_bytes + target_shard_size_bytes - 1) + // target_shard_size_bytes, + ) + num_shards[split_name] = min(n_samples, int(n_shards)) + + hf_dataset_dict.push_to_hub(repo_id, num_shards=num_shards, *args, **kwargs) def push_infos_to_hub( diff --git a/src/plaid/containers/features.py b/src/plaid/containers/features.py index cadf5770..484497fb 100644 --- a/src/plaid/containers/features.py +++ b/src/plaid/containers/features.py @@ -914,6 +914,11 @@ def add_global( else: base_node = self.init_base(1, 1, "Global", time) + if isinstance(global_array, str): # pragma: no cover + global_array = np.frombuffer( + global_array.encode("ascii"), dtype="S1", count=len(global_array) + ) + if CGU.getValueByPath(base_node, name) is None: CGL.newDataArray(base_node, name, value=global_array) else: From 4345030c26af200f10762a2a5214f80d4bfcb85d Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Sun, 2 Nov 2025 16:06:47 +0100 Subject: [PATCH 02/17] continue --- src/plaid/bridges/huggingface_bridge.py | 264 +++++++++-------------- tests/bridges/test_huggingface_bridge.py | 21 +- 2 files changed, 113 insertions(+), 172 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 0b5fa266..19064ee2 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -44,7 +44,6 @@ flatten_cgns_tree, unflatten_cgns_tree, ) -from plaid.utils.deprecation import deprecated logger = logging.getLogger(__name__) @@ -495,7 +494,7 @@ def to_plaid_sample( else: if isinstance(value, pa.ListArray): row[name] = np.stack(value.to_numpy(zero_copy_only=False)) - elif isinstance(value, pa.StringArray): + elif isinstance(value, pa.StringArray): # pragma: no cover row[name] = value.to_numpy(zero_copy_only=False) else: row[name] = value.to_numpy(zero_copy_only=True) @@ -1633,195 +1632,142 @@ def huggingface_description_to_infos( return infos -@deprecated( - "will be removed (this hf format will not be not maintained)", - version="0.1.9", - removal="0.2.0", -) -def create_string_for_huggingface_dataset_card( - description: dict, - download_size_bytes: int, - dataset_size_bytes: int, - nb_samples: int, - owner: str, - license: str, - zenodo_url: Optional[str] = None, - arxiv_paper_url: Optional[str] = None, +def update_dataset_card( + dataset_card: str, + license: str = "cc-by-sa-4.0", + infos: Optional[dict[str, dict[str, 
str]]] = None, pretty_name: Optional[str] = None, - size_categories: Optional[list[str]] = None, - task_categories: Optional[list[str]] = None, - tags: Optional[list[str]] = None, dataset_long_description: Optional[str] = None, - url_illustration: Optional[str] = None, + illustration_urls: Optional[list[str]] = None, + arxiv_paper_urls: Optional[list[str]] = None, ) -> str: - """Use this function for creating a dataset card, to upload together with the datase on the Hugging Face hub. - - Doing so ensure that load_dataset from the hub will populate the hf-dataset.description field, and be compatible for conversion to plaid. - - Without a dataset_card, the description field is lost. - - The parameters download_size_bytes and dataset_size_bytes can be determined after a - dataset has been uploaded on Hugging Face: - - manually by reading their values on the dataset page README.md, - - automatically as shown in the example below - - See `the hugginface examples `__ for a concrete use. + r"""Update a dataset card with PLAID-specific metadata and documentation. Args: - description (dict): Hugging Face dataset description. Obtained from - - description = hf_dataset.description - - description = generate_huggingface_description(infos, problem_definition) - download_size_bytes (int): the size of the dataset when downloaded from the hub - dataset_size_bytes (int): the size of the dataset when loaded in RAM - nb_samples (int): the number of samples in the dataset - owner (str): the owner of the dataset, usually a username or organization name on Hugging Face - license (str): the license of the dataset, e.g. "CC-BY-4.0", "CC0-1.0", etc. - zenodo_url (str, optional): the Zenodo URL of the dataset, if available - arxiv_paper_url (str, optional): the arxiv paper URL of the dataset, if available - pretty_name (str, optional): a human-readable name for the dataset, e.g. "PLAID Dataset" - size_categories (list[str], optional): size categories of the dataset, e.g. ["small", "medium", "large"] - task_categories (list[str], optional): task categories of the dataset, e.g. ["image-classification", "text-generation"] - tags (list[str], optional): tags for the dataset, e.g. ["3D", "simulation", "mesh"] - dataset_long_description (str, optional): a long description of the dataset, providing more details about its content and purpose - url_illustration (str, optional): a URL to an illustration image for the dataset, e.g. a screenshot or a sample mesh + dataset_card (str): The original dataset card content to update. + license (str, optional): The dataset license identifier. Defaults to "cc-by-sa-4.0". + infos (dict[str, dict[str, str]], optional): Dictionary containing dataset information + with "legal" and "data_production" sections. Defaults to None. + pretty_name (str, optional): A human-readable name for the dataset. Defaults to None. + dataset_long_description (str, optional): Detailed description of the dataset's content, + purpose, and characteristics. Defaults to None. + illustration_urls (list[str], optional): List of URLs to images illustrating the dataset. + Defaults to None. + arxiv_paper_urls (list[str], optional): List of URLs to related arXiv papers. + Defaults to None. Returns: - dataset (Dataset): the converted dataset - problem_definition (ProblemDefinition): the problem definition generated from the Hugging Face dataset + str: The updated dataset card content as a string. Example: - .. 
code-block:: python - - hf_dataset.push_to_hub("chanel/dataset") - - from datasets import load_dataset_builder - - datasetInfo = load_dataset_builder("chanel/dataset").__getstate__()['info'] - - from huggingface_hub import DatasetCard + ```python + # Create initial dataset card + card = "---\ndataset_name: my_dataset\n---" + + # Update with PLAID-specific content + updated_card = update_dataset_card( + dataset_card=card, + license="mit", + pretty_name="My PLAID Dataset", + dataset_long_description="This dataset contains...", + illustration_urls=["https://example.com/image.png"], + arxiv_paper_urls=["https://arxiv.org/abs/..."] + ) - card_text = create_string_for_huggingface_dataset_card( - description = description, - download_size_bytes = datasetInfo.download_size, - dataset_size_bytes = datasetInfo.dataset_size, - ...) - dataset_card = DatasetCard(card_text) - dataset_card.push_to_hub("chanel/dataset") + # Push to Hugging Face Hub + from huggingface_hub import DatasetCard + dataset_card = DatasetCard(updated_card) + dataset_card.push_to_hub("username/dataset") + ``` """ - str__ = f"""--- -license: {license} -""" - - if size_categories: - str__ += f"""size_categories: - {size_categories} -""" + lines = dataset_card.splitlines() - if task_categories: - str__ += f"""task_categories: - {task_categories} -""" + indices = [i for i, line in enumerate(lines) if line.strip() == "---"] + assert len(indices) >= 2, ( + "Cannot find two instances of '---', you should try to update a correct dataset_card." + ) + lines = lines[: indices[1] + 1] + + count = 1 + lines.insert(count, f"license: {license}") + count += 1 + lines.insert(count, "task_categories:") + count += 1 + lines.insert(count, "- graph-ml") + count += 1 if pretty_name: - str__ += f"""pretty_name: {pretty_name} -""" - - if tags: - str__ += f"""tags: - {tags} -""" - - str__ += f"""configs: - - config_name: default - data_files: - - split: all_samples - path: data/all_samples-* -dataset_info: - description: {description} - features: - - name: sample - dtype: binary - splits: - - name: all_samples - num_bytes: {dataset_size_bytes} - num_examples: {nb_samples} - download_size: {download_size_bytes} - dataset_size: {dataset_size_bytes} ---- - -# Dataset Card -""" - if url_illustration: - str__ += f"""![image/png]({url_illustration}) - -This dataset contains a single Hugging Face split, named 'all_samples'. - -The samples contains a single Hugging Face feature, named called "sample". - -Samples are instances of [plaid.containers.sample.Sample](https://plaid-lib.readthedocs.io/en/latest/autoapi/plaid/containers/sample/index.html#plaid.containers.sample.Sample). -Mesh objects included in samples follow the [CGNS](https://cgns.github.io/) standard, and can be converted in -[Muscat.Containers.Mesh.Mesh](https://muscat.readthedocs.io/en/latest/_source/Muscat.Containers.Mesh.html#Muscat.Containers.Mesh.Mesh). - + lines.insert(count, f"pretty_name: {pretty_name}") + count += 1 + lines.insert(count, "tags:") + count += 1 + lines.insert(count, "- physics learning") + count += 1 + lines.insert(count, "- geometry learning") + count += 1 + + str__ = "\n".join(lines) + "\n" + + if illustration_urls: + str__ += "

\n" + for url in illustration_urls: + str__ += f"{url}\n" + str__ += "

\n\n" + + if infos: + str__ += ( + f"```yaml\n{yaml.dump(infos, sort_keys=False, allow_unicode=True)}\n```" + ) + str__ += """ Example of commands: ```python -import pickle from datasets import load_dataset -from plaid import Sample +from plaid.bridges import huggingface_bridge -# Load the dataset -dataset = load_dataset("chanel/dataset", split="all_samples") - -# Get the first sample of the first split -split_names = list(dataset.description["split"].keys()) -ids_split_0 = dataset.description["split"][split_names[0]] -sample_0_split_0 = dataset[ids_split_0[0]]["sample"] -plaid_sample = Sample.model_validate(pickle.loads(sample_0_split_0)) -print("type(plaid_sample) =", type(plaid_sample)) +repo_id = "chanel/dataset" +pb_def_name = "pb_def_name" #`pb_def_name` is to choose from the repo `problem_definitions` folder -print("plaid_sample =", plaid_sample) +# Load the dataset +hf_datasetdict = load_dataset(repo_id) -# Get a field from the sample -field_names = plaid_sample.get_field_names() -field = plaid_sample.get_field(field_names[0]) -print("field_names[0] =", field_names[0]) +# Load addition required data +flat_cst, key_mappings = huggingface_bridge.load_tree_struct_from_hub(repo_id) +pb_def = huggingface_bridge.load_problem_definition_from_hub(repo_id, pb_def_name) -print("field.shape =", field.shape) +# Efficient reconstruction of plaid samples +for split_name, hf_dataset in hf_datasetdict.items(): + for i in range(len(hf_dataset)): + sample = huggingface_bridge.to_plaid_sample( + hf_dataset, + i, + flat_cst[split_name], + key_mappings["cgns_types"], + ) -# Get the mesh and convert it to Muscat -from Muscat.Bridges import CGNSBridge -CGNS_tree = plaid_sample.get_mesh() -mesh = CGNSBridge.CGNSToMesh(CGNS_tree) -print(mesh) +# Extract input and output features from samples: +for t in sample.get_all_mesh_times(): + for path in pb_def.get_in_features_identifiers(): + sample.get_feature_by_path(path=path, time=t) + for path in pb_def.get_out_features_identifiers(): + sample.get_feature_by_path(path=path, time=t) ``` - -## Dataset Details - -### Dataset Description - """ + str__ += "This dataset was generated in [PLAID](https://plaid-lib.readthedocs.io/), we refer to this documentation for additional details on how to extract data from `sample` objects.\n" if dataset_long_description: - str__ += f"""{dataset_long_description} -""" - - str__ += f"""- **Language:** [PLAID](https://plaid-lib.readthedocs.io/) -- **License:** {license} -- **Owner:** {owner} + str__ += f""" +### Dataset Description +{dataset_long_description} """ - if zenodo_url or arxiv_paper_url: + if arxiv_paper_urls: str__ += """ ### Dataset Sources +- **Papers:** """ - - if zenodo_url: - str__ += f"""- **Repository:** [Zenodo]({zenodo_url}) -""" - - if arxiv_paper_url: - str__ += f"""- **Paper:** [arxiv]({arxiv_paper_url}) -""" + for url in arxiv_paper_urls: + str__ += f" - [arxiv]({url})\n" return str__ diff --git a/tests/bridges/test_huggingface_bridge.py b/tests/bridges/test_huggingface_bridge.py index 452aca6d..26e3b143 100644 --- a/tests/bridges/test_huggingface_bridge.py +++ b/tests/bridges/test_huggingface_bridge.py @@ -316,20 +316,15 @@ def test_huggingface_description_to_infos(self, infos): huggingface_bridge.huggingface_description_to_infos(hf_description) # ---- Deprecated ---- - def test_create_string_for_huggingface_dataset_card(self, hf_dataset): - huggingface_bridge.create_string_for_huggingface_dataset_card( - description=hf_dataset.description, - download_size_bytes=10, - dataset_size_bytes=10, - 
nb_samples=10, - owner="Safran", + def test_create_string_for_huggingface_dataset_card(self, infos): + dataset_card = "---\ndataset_name: my_dataset\n---" + + huggingface_bridge.update_dataset_card( + dataset_card=dataset_card, license="cc-by-sa-4.0", - zenodo_url="https://zenodo.org/records/10124594", - arxiv_paper_url="https://arxiv.org/pdf/2305.12871", + infos=infos, pretty_name="2D quasistatic non-linear structural mechanics solutions", - size_categories=["n<1K"], - task_categories=["graph-ml"], - tags=["physics learning", "geometry learning"], dataset_long_description="my long description", - url_illustration="url3", + illustration_urls=["url0", "url1"], + arxiv_paper_urls=["url2"], ) From 77226e1eb5b673a05e4c096bfe7f455804110c69 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Mon, 3 Nov 2025 07:21:35 +0100 Subject: [PATCH 03/17] continue --- src/plaid/bridges/huggingface_bridge.py | 9 ++++----- tests/bridges/test_huggingface_bridge.py | 1 - 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 19064ee2..9aedffcb 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -1634,8 +1634,7 @@ def huggingface_description_to_infos( def update_dataset_card( dataset_card: str, - license: str = "cc-by-sa-4.0", - infos: Optional[dict[str, dict[str, str]]] = None, + infos: dict[str, dict[str, str]] = None, pretty_name: Optional[str] = None, dataset_long_description: Optional[str] = None, illustration_urls: Optional[list[str]] = None, @@ -1645,8 +1644,7 @@ def update_dataset_card( Args: dataset_card (str): The original dataset card content to update. - license (str, optional): The dataset license identifier. Defaults to "cc-by-sa-4.0". - infos (dict[str, dict[str, str]], optional): Dictionary containing dataset information + infos (dict[str, dict[str, str]]): Dictionary containing dataset information with "legal" and "data_production" sections. Defaults to None. pretty_name (str, optional): A human-readable name for the dataset. Defaults to None. 
dataset_long_description (str, optional): Detailed description of the dataset's content, @@ -1681,6 +1679,7 @@ def update_dataset_card( ``` """ lines = dataset_card.splitlines() + lines = [s for s in lines if not s.startswith("license")] indices = [i for i, line in enumerate(lines) if line.strip() == "---"] @@ -1690,7 +1689,7 @@ def update_dataset_card( lines = lines[: indices[1] + 1] count = 1 - lines.insert(count, f"license: {license}") + lines.insert(count, f"license: {infos['legal']['license']}") count += 1 lines.insert(count, "task_categories:") count += 1 diff --git a/tests/bridges/test_huggingface_bridge.py b/tests/bridges/test_huggingface_bridge.py index 26e3b143..af485a72 100644 --- a/tests/bridges/test_huggingface_bridge.py +++ b/tests/bridges/test_huggingface_bridge.py @@ -321,7 +321,6 @@ def test_create_string_for_huggingface_dataset_card(self, infos): huggingface_bridge.update_dataset_card( dataset_card=dataset_card, - license="cc-by-sa-4.0", infos=infos, pretty_name="2D quasistatic non-linear structural mechanics solutions", dataset_long_description="my long description", From 1ac20e25ceb89fec241f3fee6515bf63d61917da Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Mon, 3 Nov 2025 14:33:10 +0100 Subject: [PATCH 04/17] continute --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 50be0b62..93480e70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixes +- (samples/features) add string support to globals +- (huggingface bridge) correct split_constant tree derivation, add heuristic for number of shards usage in push_to_dict, robustify infer_hf_features_from_value with respect to numpy arrays of strings, modernize update_dataset_card + ### Removed ## [0.1.10] - 2025-10-29 From 10181287349321276f6561b2e71939a8935e6171 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 6 Nov 2025 19:38:56 +0100 Subject: [PATCH 05/17] continue --- src/plaid/bridges/huggingface_bridge.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 9aedffcb..3e86736d 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -239,6 +239,7 @@ def build_hf_sample(sample: Sample) -> tuple[dict[str, Any], list[str], dict[str def _generator_prepare_for_huggingface( generators: dict[str, Callable], + gen_kwargs: dict, verbose: bool = True, ) -> tuple[dict[str, dict[str, Any]], dict[str, Any], Features]: """Inspect PLAID dataset generators and infer Hugging Face feature schema. @@ -267,7 +268,6 @@ def _generator_prepare_for_huggingface( Raises: ValueError: If inconsistent CGNS types or feature types are found for the same path. 
""" - def values_equal(v1, v2): if isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray): return np.array_equal(v1, v2) @@ -288,7 +288,7 @@ def values_equal(v1, v2): n_samples = 0 for sample in tqdm( - generator(), disable=not verbose, desc=f"Process split {split_name}" + generator(**gen_kwargs), disable=not verbose, desc=f"Process split {split_name}" ): # --- Build Hugging Face–compatible sample --- hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) @@ -563,6 +563,7 @@ def plaid_dataset_to_huggingface_datasetdict( main_splits: dict[str, IndexType], processes_number: int = 1, writer_batch_size: int = 1, + gen_kwargs: Optional[dict] = None, verbose: bool = False, ) -> tuple[datasets.DatasetDict, dict[str, Any], dict[str, Any]]: """Convert a PLAID dataset into a Hugging Face `datasets.DatasetDict`. @@ -617,7 +618,7 @@ def generator(dataset): } return plaid_generator_to_huggingface_datasetdict( - generators, processes_number, writer_batch_size, verbose + generators, processes_number, writer_batch_size, gen_kwargs, verbose ) @@ -625,6 +626,7 @@ def plaid_generator_to_huggingface_datasetdict( generators: dict[str, Callable], processes_number: int = 1, writer_batch_size: int = 1, + gen_kwargs: Optional[dict] = None, verbose: bool = False, ) -> tuple[datasets.DatasetDict, dict[str, Any], dict[str, Any]]: """Convert PLAID dataset generators into a Hugging Face `datasets.DatasetDict`. @@ -681,22 +683,25 @@ def plaid_generator_to_huggingface_datasetdict( >>> print(key_mappings["variable_features"][:3]) ['Zone1/FlowSolution/VelocityX', 'Zone1/FlowSolution/VelocityY', ...] """ + gen_kwargs = gen_kwargs or {} + flat_cst, key_mappings, hf_features = _generator_prepare_for_huggingface( - generators, verbose + generators, gen_kwargs, verbose ) all_features_keys = list(hf_features.keys()) - def generator_fn(gen_func, all_features_keys): - for sample in gen_func(): + def generator_fn(gen_func, all_features_keys, **kwargs): + for sample in gen_func(**kwargs): hf_sample, _, _ = build_hf_sample(sample) yield {path: hf_sample.get(path, None) for path in all_features_keys} _dict = {} for split_name, gen_func in generators.items(): - gen = partial(generator_fn, gen_func, all_features_keys) + gen = partial(generator_fn, all_features_keys=all_features_keys) _dict[split_name] = datasets.Dataset.from_generator( generator=gen, + gen_kwargs={"gen_func": gen_func, **gen_kwargs}, features=hf_features, num_proc=processes_number, writer_batch_size=writer_batch_size, From cb8c4b2c5a2b97e72c7887c9a6dba03f6e3a7e35 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 6 Nov 2025 19:40:16 +0100 Subject: [PATCH 06/17] continue --- CHANGELOG.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 93480e70..69620e8b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- (huggingface bridge) from_generator now compatible with `gen_kwargs` for parallel support. + ### Fixes -- (samples/features) add string support to globals -- (huggingface bridge) correct split_constant tree derivation, add heuristic for number of shards usage in push_to_dict, robustify infer_hf_features_from_value with respect to numpy arrays of strings, modernize update_dataset_card +- (samples/features) add string support to globals. 
+- (huggingface bridge) correct split_constant tree derivation, add heuristic for number of shards usage in push_to_dict, robustify infer_hf_features_from_value with respect to numpy arrays of strings, modernize update_dataset_card. ### Removed From da4a745ebc373d61e81392b695865f90cb756f65 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 6 Nov 2025 21:48:22 +0100 Subject: [PATCH 07/17] continue --- src/plaid/bridges/huggingface_bridge.py | 220 ++++++++++++++++++++++- tests/bridges/test_huggingface_bridge.py | 32 ++++ 2 files changed, 248 insertions(+), 4 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 3e86736d..2e6aef79 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -237,9 +237,202 @@ def build_hf_sample(sample: Sample) -> tuple[dict[str, Any], list[str], dict[str return hf_sample, all_paths, sample_cgns_types +# # -----PARALLEL TENTATIVE--------------------- +# import multiprocessing as mp +# import dill + +# def _values_equal(v1, v2): +# if isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray): +# return np.array_equal(v1, v2) +# return v1 == v2 + +# def _process_shard(generator_func_serialized, shard_ids): +# """Process one shard: iterate over samples and return list of HF tuples.""" +# try: +# generator_func = dill.loads(generator_func_serialized) +# results = [] +# for sample in generator_func(shard_ids): +# hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) +# results.append((hf_sample, all_paths, sample_cgns_types)) +# return results +# except Exception as e: +# raise RuntimeError(f"Error processing shard {shard_ids}") from e + +# # -------------------------- +# # Shard-safe generator fixture +# # -------------------------- +# def generator_split_parallel(dataset, gen_kwargs) -> dict[str, callable]: +# """ +# Returns a dictionary of split_name -> generator function. +# Each shard is always a list, even if containing only one ID. 
+# """ +# generators_ = {} +# for split_name in gen_kwargs.keys(): + +# def generator_(shards_ids): +# for ids in shards_ids: +# if isinstance(ids, int): +# ids = [ids] +# for id in ids: +# yield dataset[id] + +# generators_[split_name] = generator_ +# return generators_ + +# # -------------------------- +# # Parallel _generator_prepare_for_huggingface +# # -------------------------- +# def _generator_prepare_for_huggingface( +# generators: dict[str, callable], +# gen_kwargs: dict = None, +# processes_number: int = 1, +# verbose: bool = True, +# ): +# """Parallelized version of _generator_prepare_for_huggingface by shard.""" +# gen_kwargs = gen_kwargs or {split_name: {} for split_name in generators.keys()} + +# split_flat_cst = {} +# split_var_path = {} +# split_all_paths = {} +# global_cgns_types = {} +# global_feature_types = {} + +# ctx = mp.get_context("spawn") + +# for split_name, generator_func in generators.items(): +# kwargs = gen_kwargs.get(split_name, {}) +# shards = kwargs.get("shards_ids", None) + +# split_constant_leaves = {} +# split_all_paths[split_name] = set() +# n_samples = 0 + +# if shards and processes_number > 1: +# serialized_gen = dill.dumps(generator_func) +# pool = ctx.Pool(processes=processes_number) + +# try: +# shard_results = pool.starmap( +# _process_shard, +# [(serialized_gen, shard) for shard in shards] +# ) +# finally: +# pool.close() +# pool.join() + +# # Flatten results +# for shard_batch in shard_results: +# for hf_sample, all_paths, sample_cgns_types in shard_batch: +# split_all_paths[split_name].update(hf_sample.keys()) +# global_cgns_types.update(sample_cgns_types) + +# # --- feature type inference --- +# for path in all_paths: +# value = hf_sample[path] +# if value is None: +# continue +# inferred = infer_hf_features_from_value(value) +# if path not in global_feature_types: +# global_feature_types[path] = inferred +# elif repr(global_feature_types[path]) != repr(inferred): +# raise ValueError(f"Feature type mismatch for {path} in split {split_name}") + +# # --- constant feature detection --- +# for path, value in hf_sample.items(): +# if path not in split_constant_leaves: +# split_constant_leaves[path] = {"value": value, "constant": True, "count": 1} +# else: +# entry = split_constant_leaves[path] +# entry["count"] += 1 +# if entry["constant"] and not _values_equal(entry["value"], value): +# entry["constant"] = False + +# n_samples += 1 +# else: +# # Sequential fallback +# print("Sequential fallback") +# 1./0. 
+# all_samples = generator_func(**kwargs) +# for sample in tqdm(all_samples, disable=not verbose, desc=f"Process split {split_name}"): +# hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) +# split_all_paths[split_name].update(hf_sample.keys()) +# global_cgns_types.update(sample_cgns_types) + +# for path in all_paths: +# value = hf_sample[path] +# if value is None: +# continue +# inferred = infer_hf_features_from_value(value) +# if path not in global_feature_types: +# global_feature_types[path] = inferred +# elif repr(global_feature_types[path]) != repr(inferred): +# raise ValueError(f"Feature type mismatch for {path} in split {split_name}") + +# for path, value in hf_sample.items(): +# if path not in split_constant_leaves: +# split_constant_leaves[path] = {"value": value, "constant": True, "count": 1} +# else: +# entry = split_constant_leaves[path] +# entry["count"] += 1 +# if entry["constant"] and not _values_equal(entry["value"], value): +# entry["constant"] = False + +# n_samples += 1 + +# # --- finalize constants --- +# for p, e in split_constant_leaves.items(): +# if e["count"] < n_samples: +# split_constant_leaves[p]["constant"] = False + +# split_flat_cst[split_name] = dict( +# sorted(((p, e["value"]) for p, e in split_constant_leaves.items() if e["constant"]), +# key=lambda x: x[0]) +# ) +# split_var_path[split_name] = {p for p in split_all_paths[split_name] if p not in split_flat_cst[split_name]} + +# # --- build HF features --- +# var_features = sorted(list(set().union(*split_var_path.values()))) +# if len(var_features) == 0: +# raise ValueError("no variable feature found, is your dataset variable through samples?") + +# for split_name in split_flat_cst.keys(): +# for path in var_features: +# if not path.endswith("_times") and path not in split_all_paths[split_name]: +# split_flat_cst[split_name][path + "_times"] = None +# if path in split_flat_cst[split_name]: +# split_flat_cst[split_name].pop(path) + +# cst_features = {split_name: sorted(list(cst.keys())) for split_name, cst in split_flat_cst.items()} +# first_split, first_value = next(iter(cst_features.items()), (None, None)) +# for split, value in cst_features.items(): +# assert value == first_value, f"cst_features differ for split '{split}' (vs '{first_split}')" +# cst_features = first_value + +# hf_features_map = {} +# for k in var_features: +# if k.endswith("_times"): +# hf_features_map[k] = Sequence(Value("float64")) +# else: +# hf_features_map[k] = global_feature_types[k] +# hf_features = Features(hf_features_map) + +# var_features = [path for path in var_features if not path.endswith("_times")] +# cst_features = [path for path in cst_features if not path.endswith("_times")] + +# key_mappings = { +# "variable_features": var_features, +# "constant_features": cst_features, +# "cgns_types": global_cgns_types, +# } + +# return split_flat_cst, key_mappings, hf_features +# # ------------------------------------------- + + def _generator_prepare_for_huggingface( generators: dict[str, Callable], gen_kwargs: dict, + processes_number: int = 1, verbose: bool = True, ) -> tuple[dict[str, dict[str, Any]], dict[str, Any], Features]: """Inspect PLAID dataset generators and infer Hugging Face feature schema. @@ -254,6 +447,10 @@ def _generator_prepare_for_huggingface( generators (dict[str, Callable]): Mapping from split names to callables returning sample generators. Each sample must have `sample.features.data[0.0]` compatible with `flatten_cgns_tree`. 
+ gen_kwargs (dict, optional, default=None): + Optional mapping from split names to dictionaries of keyword arguments + to be passed to each generator function, used for parallelization. + processes_number (int, optional): Number of parallel processes to use. verbose (bool, optional): If True, displays progress bars while processing splits. Returns: @@ -268,6 +465,8 @@ def _generator_prepare_for_huggingface( Raises: ValueError: If inconsistent CGNS types or feature types are found for the same path. """ + processes_number + def values_equal(v1, v2): if isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray): return np.array_equal(v1, v2) @@ -288,7 +487,9 @@ def values_equal(v1, v2): n_samples = 0 for sample in tqdm( - generator(**gen_kwargs), disable=not verbose, desc=f"Process split {split_name}" + generator(**gen_kwargs[split_name]), + disable=not verbose, + desc=f"Process split {split_name}", ): # --- Build Hugging Face–compatible sample --- hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) @@ -583,6 +784,9 @@ def plaid_dataset_to_huggingface_datasetdict( Number of parallel processes to use when writing the Hugging Face dataset. writer_batch_size (int, optional, default=1): Batch size used when writing samples to disk in Hugging Face format. + gen_kwargs (dict, optional, default=None): + Optional mapping from split names to dictionaries of keyword arguments + to be passed to each generator function, used for parallelization. verbose (bool, optional, default=False): If True, print progress and debug information. @@ -646,6 +850,9 @@ def plaid_generator_to_huggingface_datasetdict( the dataset from the generators. writer_batch_size (int, optional, default=1): Batch size used when writing samples to disk in Hugging Face format. + gen_kwargs (dict, optional, default=None): + Optional mapping from split names to dictionaries of keyword arguments + to be passed to each generator function, used for parallelization. verbose (bool, optional, default=False): If True, displays progress bars and diagnostic messages. @@ -683,10 +890,15 @@ def plaid_generator_to_huggingface_datasetdict( >>> print(key_mappings["variable_features"][:3]) ['Zone1/FlowSolution/VelocityX', 'Zone1/FlowSolution/VelocityY', ...] """ - gen_kwargs = gen_kwargs or {} + if processes_number > 1: + assert gen_kwargs is not None, ( + "When using multiple processes, gen_kwargs must be provided." 
+ ) + + gen_kwargs = gen_kwargs or {split_name: {} for split_name in generators.keys()} flat_cst, key_mappings, hf_features = _generator_prepare_for_huggingface( - generators, gen_kwargs, verbose + generators, gen_kwargs, processes_number, verbose ) all_features_keys = list(hf_features.keys()) @@ -701,7 +913,7 @@ def generator_fn(gen_func, all_features_keys, **kwargs): gen = partial(generator_fn, all_features_keys=all_features_keys) _dict[split_name] = datasets.Dataset.from_generator( generator=gen, - gen_kwargs={"gen_func": gen_func, **gen_kwargs}, + gen_kwargs={"gen_func": gen_func, **gen_kwargs[split_name]}, features=hf_features, num_proc=processes_number, writer_batch_size=writer_batch_size, diff --git a/tests/bridges/test_huggingface_bridge.py b/tests/bridges/test_huggingface_bridge.py index af485a72..6b6b7d37 100644 --- a/tests/bridges/test_huggingface_bridge.py +++ b/tests/bridges/test_huggingface_bridge.py @@ -71,6 +71,33 @@ def generator_(ids=ids): return generators_ +@pytest.fixture() +def gen_kwargs(problem_definition) -> dict[str, dict]: + gen_kwargs = {} + for split_name, ids in problem_definition.get_split().items(): + mid = len(ids) // 2 + gen_kwargs[split_name] = {"shards_ids": [ids[:mid], ids[mid:]]} + return gen_kwargs + + +@pytest.fixture() +def generator_split_parallel(dataset, gen_kwargs) -> dict[str, Callable]: + generators_ = {} + + for split_name in gen_kwargs.keys(): + + def generator_(shards_ids): + for ids in shards_ids: + if isinstance(ids, int): + ids = [ids] + for id in ids: + yield dataset[id] + + generators_[split_name] = generator_ + + return generators_ + + @pytest.fixture() def generator_binary(dataset) -> Callable: def generator_(): @@ -171,6 +198,11 @@ def test_with_generator(self, generator_split): enforce_shapes=True, ) + def test_with_generator_parallel(self, gen_kwargs, generator_split_parallel): + huggingface_bridge.plaid_generator_to_huggingface_datasetdict( + generator_split_parallel, processes_number=2, gen_kwargs=gen_kwargs + ) + # ------------------------------------------------------------------------------ # HUGGING FACE INTERACTIONS ON DISK # ------------------------------------------------------------------------------ From dd7a54478ada23f0caf7bdd337778168b824f3a0 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Sat, 8 Nov 2025 16:53:01 +0100 Subject: [PATCH 08/17] continue --- src/plaid/bridges/huggingface_bridge.py | 599 +++++++++++++++-------- tests/bridges/test_huggingface_bridge.py | 28 +- 2 files changed, 401 insertions(+), 226 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 2e6aef79..746f165e 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -8,10 +8,12 @@ # import hashlib import io +import multiprocessing as mp import os import pickle import shutil import sys +import traceback from functools import partial from multiprocessing import Pool from pathlib import Path @@ -237,181 +239,355 @@ def build_hf_sample(sample: Sample) -> tuple[dict[str, Any], list[str], dict[str return hf_sample, all_paths, sample_cgns_types -# # -----PARALLEL TENTATIVE--------------------- -# import multiprocessing as mp -# import dill - -# def _values_equal(v1, v2): -# if isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray): -# return np.array_equal(v1, v2) -# return v1 == v2 - -# def _process_shard(generator_func_serialized, shard_ids): -# """Process one shard: iterate over samples and return list of HF tuples.""" -# try: -# 
generator_func = dill.loads(generator_func_serialized) -# results = [] -# for sample in generator_func(shard_ids): -# hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) -# results.append((hf_sample, all_paths, sample_cgns_types)) -# return results -# except Exception as e: -# raise RuntimeError(f"Error processing shard {shard_ids}") from e - -# # -------------------------- -# # Shard-safe generator fixture -# # -------------------------- -# def generator_split_parallel(dataset, gen_kwargs) -> dict[str, callable]: -# """ -# Returns a dictionary of split_name -> generator function. -# Each shard is always a list, even if containing only one ID. -# """ -# generators_ = {} -# for split_name in gen_kwargs.keys(): - -# def generator_(shards_ids): -# for ids in shards_ids: -# if isinstance(ids, int): -# ids = [ids] -# for id in ids: -# yield dataset[id] - -# generators_[split_name] = generator_ -# return generators_ - -# # -------------------------- -# # Parallel _generator_prepare_for_huggingface -# # -------------------------- +def _hash_value(value): + """Compute a hash for a value (np.ndarray or basic types).""" + if isinstance(value, np.ndarray): + return hashlib.md5(value.view(np.uint8)).hexdigest() + return hashlib.md5(str(value).encode("utf-8")).hexdigest() + + +def process_shard( + shard_ids: list[IndexType], + generator_fn: Callable[[list[list[IndexType]]], Any], + progress: Any, + n_proc: int, +) -> tuple[ + set[str], + dict[str, str], + dict[str, Union[Value, Sequence]], + dict[str, dict[str, Union[str, bool, int]]], + int, +]: + """Process a single shard of sample ids and collect per-shard metadata. + + This function drives a shard-level pass over samples produced by `generator_fn`. + For each sample it: + - flattens the sample into Hugging Face friendly arrays (build_hf_sample), + - collects observed flattened paths, + - aggregates CGNS type metadata, + - infers Hugging Face feature types for each path, + - detects per-path constants using a content hash, + - updates progress (either a multiprocessing.Queue or a tqdm progress bar). + + Args: + shard_ids (list[IndexType]): Sequence of sample ids (a single shard) to process. + generator_fn (Callable): Generator function accepting a list of shard id sequences + and yielding Sample objects for those ids. + progress (Any): Progress reporter; either a multiprocessing.Queue (for parallel + execution) or a tqdm progress bar object (for sequential execution). + n_proc (int): Number of worker processes used by the caller (used to decide + how to report progress). + + Returns: + tuple: + - split_all_paths (set[str]): Set of all flattened feature paths observed in the shard. + - shard_global_cgns_types (dict[str, str]): Mapping path -> CGNS node type observed in the shard. + - shard_global_feature_types (dict[str, Union[Value, Sequence]]): Inferred HF feature types per path. + - split_constant_leaves (dict[str, dict]): Per-path metadata for constant detection. Each entry + is a dict with keys "hash" (str), "constant" (bool) and "count" (int). + - n_samples_processed (int): Number of samples processed in this shard. + + Raises: + ValueError: If inconsistent feature types are detected for the same path within the shard. 
+ """ + split_constant_leaves = {} + split_all_paths = set() + shard_global_cgns_types = {} + shard_global_feature_types = {} + shards_to_process = [shard_ids] + + for sample in generator_fn(shards_to_process): + hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) + + split_all_paths.update(hf_sample.keys()) + shard_global_cgns_types.update(sample_cgns_types) + + # Feature type inference + for path in all_paths: + value = hf_sample[path] + if value is None: + continue + inferred = infer_hf_features_from_value(value) + if path not in shard_global_feature_types: + shard_global_feature_types[path] = inferred + elif repr(shard_global_feature_types[path]) != repr( + inferred + ): # pragma: no cover + raise ValueError(f"Feature type mismatch for {path} in shard") + + # Constant detection using hashes + for path, value in hf_sample.items(): + h = _hash_value(value) + if path not in split_constant_leaves: + split_constant_leaves[path] = {"hash": h, "constant": True, "count": 1} + else: + entry = split_constant_leaves[path] + entry["count"] += 1 + if entry["constant"] and entry["hash"] != h: + entry["constant"] = False + + # --- Update progress --- + if n_proc > 1: + progress.put(1) # pragma: no cover + else: + progress.update(1) + + return ( + split_all_paths, + shard_global_cgns_types, + shard_global_feature_types, + split_constant_leaves, + len(shard_ids), + ) + + +def preprocess_splits( + generators: dict[str, Callable], + gen_kwargs: dict[str, dict[str, list[IndexType]]], + processes_number: int = 1, + verbose: bool = True, +) -> tuple[ + dict[str, set[str]], + dict[str, dict[str, Any]], + dict[str, set[str]], + dict[str, str], + dict[str, Union[Value, Sequence]], +]: + """Pre-process dataset splits: inspect samples to infer features, constants and CGNS metadata. + + This function iterates over the provided split generators (optionally in parallel), + flattens each PLAID sample into Hugging Face friendly arrays, detects constant + CGNS leaves (features identical across all samples in a split), infers global + Hugging Face feature types, and aggregates CGNS type metadata. + + The work is sharded per-split and each shard is processed by `process_shard`. + In parallel mode, progress is updated via a multiprocessing.Queue; otherwise a + tqdm progress bar is used. + + Args: + generators (dict[str, Callable]): + Mapping from split name to a generator function. Each generator must + accept a single argument (a sequence of shard ids) and yield PLAID samples. + gen_kwargs (dict[str, dict[str, list[IndexType]]]): + Per-split kwargs used to drive generator invocation (e.g. {"train": {"shards_ids": [...]}}). + processes_number (int, optional): + Number of worker processes to use for shard-level parallelism. Defaults to 1. + verbose (bool, optional): + If True, displays progress bars. Defaults to True. + + Returns: + tuple: + - split_all_paths (dict[str, set[str]]): + For each split, the set of all observed flattened feature paths (including "_times" keys). + - split_flat_cst (dict[str, dict[str, Any]]): + For each split, a mapping of constant feature path -> value (constant parts of the tree). + - split_var_path (dict[str, set[str]]): + For each split, the set of variable feature paths (non-constant). + - global_cgns_types (dict[str, str]): + Aggregated mapping from flattened path -> CGNS node type. + - global_feature_types (dict[str, Union[Value, Sequence]]): + Aggregated inferred Hugging Face feature types for each variable path. 
+ + Raises: + ValueError: If inconsistent feature types or CGNS types are detected across shards/splits. + """ + global_cgns_types = {} + global_feature_types = {} + split_flat_cst = {} + split_var_path = {} + split_all_paths = {} + + for split_name, generator_fn in generators.items(): + shards_ids_list = gen_kwargs[split_name].get("shards_ids", [None]) + n_proc = processes_number or len(shards_ids_list) + n_proc = max(1, n_proc) + + shards_data = [] + + if n_proc == 1: + progress_total = sum(len(shard) for shard in shards_ids_list) + with tqdm( + total=progress_total, + disable=not verbose, + desc=f"Pre-process split {split_name}", + ) as pbar: + for shard_ids in shards_ids_list: + shards_data.append( + process_shard( + shard_ids, + generator_fn, + pbar, + n_proc, + ) + ) + + else: # pragma: no cover (pytest not working with parallel mode) + # --- Parallel execution --- + manager = None + pool = None + + try: + manager = mp.Manager() + progress_queue = manager.Queue() + + # --- Run shards in parallel --- + with mp.Pool(n_proc) as pool: + results = [ + pool.apply_async( + process_shard, + args=( + shard_ids, + generator_fn, + progress_queue, + n_proc, + ), + ) + for shard_ids in shards_ids_list + ] + + total_samples = sum(len(shard) for shard in shards_ids_list) + with tqdm( + total=total_samples, + disable=not verbose, + desc=f"Pre-process split {split_name}", + ) as pbar: + completed = 0 + while completed < total_samples: + increment = progress_queue.get() + pbar.update(increment) + completed += increment + + for r in results: + try: + data = r.get() # this raises if the worker crashed + shards_data.append(data) + except Exception: + traceback.print_exc() + # Optional: terminate pool early if one shard fails + pool.terminate() + raise + + except Exception: + traceback.print_exc() + raise + finally: + # Always clean up multiprocessing objects + if pool is not None: + pool.terminate() + pool.join() + if manager is not None: + manager.shutdown() + + # --- Merge shard results --- + split_all_paths[split_name] = set() + split_constant_hashes = {} + n_samples_total = 0 + + for ( + all_paths, + shard_cgns, + shard_features, + shard_constants, + n_samples, + ) in shards_data: + split_all_paths[split_name].update(all_paths) + global_cgns_types.update(shard_cgns) + + for path, inferred in shard_features.items(): + if path not in global_feature_types: + global_feature_types[path] = inferred + elif repr(global_feature_types[path]) != repr( + inferred + ): # pragma: no cover + raise ValueError( + f"Feature type mismatch for {path} in split {split_name}" + ) + + for path, entry in shard_constants.items(): + if path not in split_constant_hashes: + split_constant_hashes[path] = entry + else: + existing = split_constant_hashes[path] + existing["constant"] = existing["constant"] and entry["constant"] + existing["count"] += entry["count"] + + n_samples_total += n_samples + + # --- Finalize constants by inspecting first sample --- + # Only paths marked constant across all samples + constant_paths = [ + p + for p, e in split_constant_hashes.items() + if e["constant"] and e["count"] == n_samples_total + ] + + # Inspect first sample to get actual values + first_sample = next(generator_fn([shards_ids_list[0]])) + hf_sample, _, _ = build_hf_sample(first_sample) + + split_flat_cst[split_name] = {p: hf_sample[p] for p in sorted(constant_paths)} + split_var_path[split_name] = { + p + for p in split_all_paths[split_name] + if p not in split_flat_cst[split_name] + } + + global_feature_types = { + p: 
global_feature_types[p] for p in sorted(global_feature_types) + } + + return ( + split_all_paths, + split_flat_cst, + split_var_path, + global_cgns_types, + global_feature_types, + ) + + # def _generator_prepare_for_huggingface( -# generators: dict[str, callable], -# gen_kwargs: dict = None, +# generators: dict[str, Callable], +# gen_kwargs: dict, # processes_number: int = 1, # verbose: bool = True, # ): -# """Parallelized version of _generator_prepare_for_huggingface by shard.""" -# gen_kwargs = gen_kwargs or {split_name: {} for split_name in generators.keys()} - -# split_flat_cst = {} -# split_var_path = {} -# split_all_paths = {} -# global_cgns_types = {} -# global_feature_types = {} - -# ctx = mp.get_context("spawn") - -# for split_name, generator_func in generators.items(): -# kwargs = gen_kwargs.get(split_name, {}) -# shards = kwargs.get("shards_ids", None) - -# split_constant_leaves = {} -# split_all_paths[split_name] = set() -# n_samples = 0 - -# if shards and processes_number > 1: -# serialized_gen = dill.dumps(generator_func) -# pool = ctx.Pool(processes=processes_number) - -# try: -# shard_results = pool.starmap( -# _process_shard, -# [(serialized_gen, shard) for shard in shards] -# ) -# finally: -# pool.close() -# pool.join() - -# # Flatten results -# for shard_batch in shard_results: -# for hf_sample, all_paths, sample_cgns_types in shard_batch: -# split_all_paths[split_name].update(hf_sample.keys()) -# global_cgns_types.update(sample_cgns_types) - -# # --- feature type inference --- -# for path in all_paths: -# value = hf_sample[path] -# if value is None: -# continue -# inferred = infer_hf_features_from_value(value) -# if path not in global_feature_types: -# global_feature_types[path] = inferred -# elif repr(global_feature_types[path]) != repr(inferred): -# raise ValueError(f"Feature type mismatch for {path} in split {split_name}") - -# # --- constant feature detection --- -# for path, value in hf_sample.items(): -# if path not in split_constant_leaves: -# split_constant_leaves[path] = {"value": value, "constant": True, "count": 1} -# else: -# entry = split_constant_leaves[path] -# entry["count"] += 1 -# if entry["constant"] and not _values_equal(entry["value"], value): -# entry["constant"] = False - -# n_samples += 1 -# else: -# # Sequential fallback -# print("Sequential fallback") -# 1./0. 
-# all_samples = generator_func(**kwargs) -# for sample in tqdm(all_samples, disable=not verbose, desc=f"Process split {split_name}"): -# hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) -# split_all_paths[split_name].update(hf_sample.keys()) -# global_cgns_types.update(sample_cgns_types) - -# for path in all_paths: -# value = hf_sample[path] -# if value is None: -# continue -# inferred = infer_hf_features_from_value(value) -# if path not in global_feature_types: -# global_feature_types[path] = inferred -# elif repr(global_feature_types[path]) != repr(inferred): -# raise ValueError(f"Feature type mismatch for {path} in split {split_name}") - -# for path, value in hf_sample.items(): -# if path not in split_constant_leaves: -# split_constant_leaves[path] = {"value": value, "constant": True, "count": 1} -# else: -# entry = split_constant_leaves[path] -# entry["count"] += 1 -# if entry["constant"] and not _values_equal(entry["value"], value): -# entry["constant"] = False - -# n_samples += 1 - -# # --- finalize constants --- -# for p, e in split_constant_leaves.items(): -# if e["count"] < n_samples: -# split_constant_leaves[p]["constant"] = False - -# split_flat_cst[split_name] = dict( -# sorted(((p, e["value"]) for p, e in split_constant_leaves.items() if e["constant"]), -# key=lambda x: x[0]) -# ) -# split_var_path[split_name] = {p for p in split_all_paths[split_name] if p not in split_flat_cst[split_name]} +# ( +# split_all_paths, +# split_flat_cst, +# split_var_path, +# global_cgns_types, +# global_feature_types, +# ) = preprocess_splits(generators, gen_kwargs, processes_number, verbose) # # --- build HF features --- # var_features = sorted(list(set().union(*split_var_path.values()))) -# if len(var_features) == 0: -# raise ValueError("no variable feature found, is your dataset variable through samples?") +# if len(var_features) == 0: # pragma: no cover +# raise ValueError( +# "no variable feature found, is your dataset variable through samples?" 
+# ) # for split_name in split_flat_cst.keys(): # for path in var_features: # if not path.endswith("_times") and path not in split_all_paths[split_name]: -# split_flat_cst[split_name][path + "_times"] = None +# split_flat_cst[split_name][path + "_times"] = None # pragma: no cover # if path in split_flat_cst[split_name]: -# split_flat_cst[split_name].pop(path) +# split_flat_cst[split_name].pop(path) # pragma: no cover -# cst_features = {split_name: sorted(list(cst.keys())) for split_name, cst in split_flat_cst.items()} +# cst_features = { +# split_name: sorted(list(cst.keys())) +# for split_name, cst in split_flat_cst.items() +# } # first_split, first_value = next(iter(cst_features.items()), (None, None)) # for split, value in cst_features.items(): -# assert value == first_value, f"cst_features differ for split '{split}' (vs '{first_split}')" +# assert value == first_value, ( +# f"cst_features differ for split '{split}' (vs '{first_split}')" +# ) # cst_features = first_value # hf_features_map = {} # for k in var_features: # if k.endswith("_times"): -# hf_features_map[k] = Sequence(Value("float64")) +# hf_features_map[k] = Sequence(Value("float64")) # pragma: no cover # else: # hf_features_map[k] = global_feature_types[k] # hf_features = Features(hf_features_map) @@ -426,9 +602,10 @@ def build_hf_sample(sample: Sample) -> tuple[dict[str, Any], list[str], dict[str # } # return split_flat_cst, key_mappings, hf_features -# # ------------------------------------------- +# ------------------------------------------- +# --------- Sequential version def _generator_prepare_for_huggingface( generators: dict[str, Callable], gen_kwargs: dict, @@ -489,7 +666,7 @@ def values_equal(v1, v2): for sample in tqdm( generator(**gen_kwargs[split_name]), disable=not verbose, - desc=f"Process split {split_name}", + desc=f"Pre-process split {split_name}", ): # --- Build Hugging Face–compatible sample --- hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) @@ -764,7 +941,6 @@ def plaid_dataset_to_huggingface_datasetdict( main_splits: dict[str, IndexType], processes_number: int = 1, writer_batch_size: int = 1, - gen_kwargs: Optional[dict] = None, verbose: bool = False, ) -> tuple[datasets.DatasetDict, dict[str, Any], dict[str, Any]]: """Convert a PLAID dataset into a Hugging Face `datasets.DatasetDict`. @@ -784,9 +960,6 @@ def plaid_dataset_to_huggingface_datasetdict( Number of parallel processes to use when writing the Hugging Face dataset. writer_batch_size (int, optional, default=1): Batch size used when writing samples to disk in Hugging Face format. - gen_kwargs (dict, optional, default=None): - Optional mapping from split names to dictionaries of keyword arguments - to be passed to each generator function, used for parallelization. verbose (bool, optional, default=False): If True, print progress and debug information. 
@@ -812,25 +985,29 @@ def plaid_dataset_to_huggingface_datasetdict( }) """ - def generator(dataset): - for sample in dataset: - yield sample + def generator_(shards_ids): + for ids in shards_ids: + if isinstance(ids, int): + ids = [ids] # pragma: no cover + for id in ids: + yield dataset[id] + + generators = {split_name: generator_ for split_name in main_splits.keys()} - generators = { - split_name: partial(generator, dataset[ids]) - for split_name, ids in main_splits.items() + gen_kwargs = { + split_name: {"shards_ids": [ids]} for split_name, ids in main_splits.items() } return plaid_generator_to_huggingface_datasetdict( - generators, processes_number, writer_batch_size, gen_kwargs, verbose + generators, gen_kwargs, processes_number, writer_batch_size, verbose ) def plaid_generator_to_huggingface_datasetdict( generators: dict[str, Callable], + gen_kwargs: dict[str, dict[str, list[IndexType]]], processes_number: int = 1, writer_batch_size: int = 1, - gen_kwargs: Optional[dict] = None, verbose: bool = False, ) -> tuple[datasets.DatasetDict, dict[str, Any], dict[str, Any]]: """Convert PLAID dataset generators into a Hugging Face `datasets.DatasetDict`. @@ -890,13 +1067,6 @@ def plaid_generator_to_huggingface_datasetdict( >>> print(key_mappings["variable_features"][:3]) ['Zone1/FlowSolution/VelocityX', 'Zone1/FlowSolution/VelocityY', ...] """ - if processes_number > 1: - assert gen_kwargs is not None, ( - "When using multiple processes, gen_kwargs must be provided." - ) - - gen_kwargs = gen_kwargs or {split_name: {} for split_name in generators.keys()} - flat_cst, key_mappings, hf_features = _generator_prepare_for_huggingface( generators, gen_kwargs, processes_number, verbose ) @@ -928,6 +1098,26 @@ def generator_fn(gen_func, all_features_keys, **kwargs): # ------------------------------------------------------------------------------ +def _compute_num_shards(hf_dataset_dict: datasets.DatasetDict) -> dict[str, int]: + target_shard_size_mb = 500 + + num_shards = {} + for split_name, ds in hf_dataset_dict.items(): + n_samples = len(ds) + assert n_samples > 0, f"split {split_name} has no sample" + + dataset_size_bytes = ds.data.nbytes + target_shard_size_bytes = target_shard_size_mb * 1024 * 1024 + + n_shards = max( + 1, + (dataset_size_bytes + target_shard_size_bytes - 1) + // target_shard_size_bytes, + ) + num_shards[split_name] = min(n_samples, int(n_shards)) + return num_shards + + def instantiate_plaid_datasetdict_from_hub( repo_id: str, enforce_shapes: bool = True, @@ -1144,7 +1334,7 @@ def load_tree_struct_from_hub( def push_dataset_dict_to_hub( - repo_id: str, hf_dataset_dict: datasets.DatasetDict, *args, **kwargs + repo_id: str, hf_dataset_dict: datasets.DatasetDict, **kwargs ) -> None: # pragma: no cover (not tested in unit tests) """Push a Hugging Face `DatasetDict` to the Hugging Face Hub. @@ -1165,9 +1355,6 @@ def push_dataset_dict_to_hub( (e.g. `"username/dataset_name"`). hf_dataset_dict (datasets.DatasetDict): The Hugging Face dataset dictionary to push. - *args: - Positional arguments forwarded to - [`DatasetDict.push_to_hub`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.push_to_hub). **kwargs: Keyword arguments forwarded to [`DatasetDict.push_to_hub`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.push_to_hub). 
@@ -1175,24 +1362,20 @@ def push_dataset_dict_to_hub(
     Returns:
         None
     """
-    target_shard_size_mb = 500
-
-    num_shards = {}
-    for split_name, ds in hf_dataset_dict.items():
-        n_samples = len(ds)
-        assert n_samples > 0, f"split {split_name} has no sample"
-
-        dataset_size_bytes = ds.data.nbytes
-        target_shard_size_bytes = target_shard_size_mb * 1024 * 1024
-
-        n_shards = max(
-            1,
-            (dataset_size_bytes + target_shard_size_bytes - 1)
-            // target_shard_size_bytes,
-        )
-        num_shards[split_name] = min(n_samples, int(n_shards))
+    num_shards = _compute_num_shards(hf_dataset_dict)
 
+    num_proc = kwargs.get("num_proc", None)
+    if num_proc is not None:  # pragma: no cover
+        min_num_shards = min(num_shards.values())
+        if min_num_shards < num_proc:
+            logger.warning(
+                f"num_proc changed from {num_proc} to 1 to safely adapt to num_shards={num_shards}"
+            )
+            num_proc = 1
+            del kwargs["num_proc"]
 
-    hf_dataset_dict.push_to_hub(repo_id, num_shards=num_shards, *args, **kwargs)
+    hf_dataset_dict.push_to_hub(
+        repo_id, num_shards=num_shards, num_proc=num_proc, **kwargs
+    )
 
 
 def push_infos_to_hub(
@@ -1398,7 +1581,7 @@ def load_tree_struct_from_disk(
 
 
 def save_dataset_dict_to_disk(
-    path: Union[str, Path], hf_dataset_dict: datasets.DatasetDict, *args, **kwargs
+    path: Union[str, Path], hf_dataset_dict: datasets.DatasetDict, **kwargs
 ) -> None:
     """Save a Hugging Face DatasetDict to disk.
 
@@ -1408,9 +1591,6 @@ def save_dataset_dict_to_disk(
     Args:
         path (Union[str, Path]): Directory path where the DatasetDict will be saved.
         hf_dataset_dict (datasets.DatasetDict): The Hugging Face DatasetDict to save.
-        *args:
-            Positional arguments forwarded to
-            [`DatasetDict.save_to_disk`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.save_to_disk).
         **kwargs:
             Keyword arguments forwarded to
             [`DatasetDict.save_to_disk`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.save_to_disk). 
@@ -1418,7 +1598,20 @@ def save_dataset_dict_to_disk(
     Returns:
         None
     """
-    hf_dataset_dict.save_to_disk(str(path), *args, **kwargs)
+    num_shards = _compute_num_shards(hf_dataset_dict)
+    num_proc = kwargs.get("num_proc", None)
+    if num_proc is not None:  # pragma: no cover
+        min_num_shards = min(num_shards.values())
+        if min_num_shards < num_proc:
+            logger.warning(
+                f"num_proc changed from {num_proc} to 1 to safely adapt to num_shards={num_shards}"
+            )
+            num_proc = 1
+            del kwargs["num_proc"]
+
+    hf_dataset_dict.save_to_disk(
+        str(path), num_shards=num_shards, num_proc=num_proc, **kwargs
+    )
 
 
 def save_infos_to_disk(
@@ -1821,7 +2014,7 @@ def huggingface_description_to_problem_definition(
         try:
             func(description[key])
         except KeyError:
-            logger.info(f"Could not retrieve key:'{key}' from description")
+            logger.error(f"Could not retrieve key:'{key}' from description")
             pass
 
     return problem_definition
diff --git a/tests/bridges/test_huggingface_bridge.py b/tests/bridges/test_huggingface_bridge.py
index 6b6b7d37..aab5b0dc 100644
--- a/tests/bridges/test_huggingface_bridge.py
+++ b/tests/bridges/test_huggingface_bridge.py
@@ -58,19 +58,6 @@ def generator_():
     return generator_
 
 
-@pytest.fixture()
-def generator_split(dataset, problem_definition) -> dict[str, Callable]:
-    generators_ = {}
-    for split_name, ids in problem_definition.get_split().items():
-
-        def generator_(ids=ids):
-            for id in ids:
-                yield dataset[id]
-
-        generators_[split_name] = generator_
-    return generators_
-
-
 @pytest.fixture()
 def gen_kwargs(problem_definition) -> dict[str, dict]:
     gen_kwargs = {}
@@ -81,7 +68,7 @@ def gen_kwargs(problem_definition) -> dict[str, dict]:
 
 
 @pytest.fixture()
-def generator_split_parallel(dataset, gen_kwargs) -> dict[str, Callable]:
+def generator_split(dataset, gen_kwargs) -> dict[str, Callable]:
     generators_ = {}
     for split_name in gen_kwargs.keys():
 
@@ -181,10 +168,10 @@ def test_with_datasetdict(self, dataset, problem_definition):
             dataset[0].get_mesh(), dataset[0].get_mesh()
         )
 
-    def test_with_generator(self, generator_split):
+    def test_with_generator(self, generator_split, gen_kwargs):
         hf_dataset_dict, flat_cst, key_mappings = (
             huggingface_bridge.plaid_generator_to_huggingface_datasetdict(
-                generator_split
+                generator_split, gen_kwargs
             )
         )
         huggingface_bridge.to_plaid_sample(
@@ -198,21 +185,16 @@ def test_with_generator(self, generator_split):
             enforce_shapes=True,
         )
 
-    def test_with_generator_parallel(self, gen_kwargs, generator_split_parallel):
-        huggingface_bridge.plaid_generator_to_huggingface_datasetdict(
-            generator_split_parallel, processes_number=2, gen_kwargs=gen_kwargs
-        )
-
     # ------------------------------------------------------------------------------
     # HUGGING FACE INTERACTIONS ON DISK
     # ------------------------------------------------------------------------------
 
     def test_save_load_to_disk(
-        self, current_directory, generator_split, infos, problem_definition
+        self, current_directory, generator_split, infos, problem_definition, gen_kwargs
     ):
         hf_dataset_dict, flat_cst, key_mappings = (
             huggingface_bridge.plaid_generator_to_huggingface_datasetdict(
-                generator_split
+                generator_split, gen_kwargs
             )
         )
 
From 9fbd75e4261e96f5ea4f30cde7daa9543533832c Mon Sep 17 00:00:00 2001
From: Fabien Casenave
Date: Sat, 8 Nov 2025 16:55:12 +0100
Subject: [PATCH 09/17] continue

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 69620e8b..e9263609 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to 
[Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- (huggingface bridge) from_generator now compatible with `gen_kwargs` for parallel support. +- (huggingface bridge) full parallel support in `from_generator`, with optimization of constant leaf detection (no large data communicated between processes). ### Fixes From 062d51e1d75f856da1a350e71cb27384f892f2ac Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Sat, 8 Nov 2025 16:59:52 +0100 Subject: [PATCH 10/17] continue --- src/plaid/bridges/huggingface_bridge.py | 398 ++++++++++++------------ 1 file changed, 199 insertions(+), 199 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 746f165e..be16e7c1 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -545,237 +545,51 @@ def preprocess_splits( ) -# def _generator_prepare_for_huggingface( -# generators: dict[str, Callable], -# gen_kwargs: dict, -# processes_number: int = 1, -# verbose: bool = True, -# ): -# ( -# split_all_paths, -# split_flat_cst, -# split_var_path, -# global_cgns_types, -# global_feature_types, -# ) = preprocess_splits(generators, gen_kwargs, processes_number, verbose) - -# # --- build HF features --- -# var_features = sorted(list(set().union(*split_var_path.values()))) -# if len(var_features) == 0: # pragma: no cover -# raise ValueError( -# "no variable feature found, is your dataset variable through samples?" -# ) - -# for split_name in split_flat_cst.keys(): -# for path in var_features: -# if not path.endswith("_times") and path not in split_all_paths[split_name]: -# split_flat_cst[split_name][path + "_times"] = None # pragma: no cover -# if path in split_flat_cst[split_name]: -# split_flat_cst[split_name].pop(path) # pragma: no cover - -# cst_features = { -# split_name: sorted(list(cst.keys())) -# for split_name, cst in split_flat_cst.items() -# } -# first_split, first_value = next(iter(cst_features.items()), (None, None)) -# for split, value in cst_features.items(): -# assert value == first_value, ( -# f"cst_features differ for split '{split}' (vs '{first_split}')" -# ) -# cst_features = first_value - -# hf_features_map = {} -# for k in var_features: -# if k.endswith("_times"): -# hf_features_map[k] = Sequence(Value("float64")) # pragma: no cover -# else: -# hf_features_map[k] = global_feature_types[k] -# hf_features = Features(hf_features_map) - -# var_features = [path for path in var_features if not path.endswith("_times")] -# cst_features = [path for path in cst_features if not path.endswith("_times")] - -# key_mappings = { -# "variable_features": var_features, -# "constant_features": cst_features, -# "cgns_types": global_cgns_types, -# } - -# return split_flat_cst, key_mappings, hf_features - - -# ------------------------------------------- -# --------- Sequential version def _generator_prepare_for_huggingface( generators: dict[str, Callable], gen_kwargs: dict, processes_number: int = 1, verbose: bool = True, -) -> tuple[dict[str, dict[str, Any]], dict[str, Any], Features]: - """Inspect PLAID dataset generators and infer Hugging Face feature schema. - - Iterates over all samples in all provided split generators to: - 1. Flatten each CGNS tree into a dictionary of paths → values. - 2. Infer Hugging Face `Features` types for all variable leaves. - 3. Detect constant leaves (values that never change across all samples). - 4. Collect global CGNS type metadata. 
- - Args: - generators (dict[str, Callable]): - Mapping from split names to callables returning sample generators. - Each sample must have `sample.features.data[0.0]` compatible with `flatten_cgns_tree`. - gen_kwargs (dict, optional, default=None): - Optional mapping from split names to dictionaries of keyword arguments - to be passed to each generator function, used for parallelization. - processes_number (int, optional): Number of parallel processes to use. - verbose (bool, optional): If True, displays progress bars while processing splits. - - Returns: - tuple: - - flat_cst (dict[str, Any]): Mapping from feature path to constant values detected across all splits. - - key_mappings (dict[str, Any]): Metadata dictionary with: - - "variable_features" (list[str]): paths of non-constant features. - - "constant_features" (list[str]): paths of constant features. - - "cgns_types" (dict[str, Any]): CGNS type information for all paths. - - hf_features (datasets.Features): Hugging Face feature specification for variable features. - - Raises: - ValueError: If inconsistent CGNS types or feature types are found for the same path. - """ - processes_number - - def values_equal(v1, v2): - if isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray): - return np.array_equal(v1, v2) - return v1 == v2 - - global_cgns_types = {} - global_feature_types = {} - - split_flat_cst = {} - split_var_path = {} - split_all_paths = {} - - # ---- Single pass over all splits and samples ---- - for split_name, generator in generators.items(): - split_constant_leaves = {} - - split_all_paths[split_name] = set() - - n_samples = 0 - for sample in tqdm( - generator(**gen_kwargs[split_name]), - disable=not verbose, - desc=f"Pre-process split {split_name}", - ): - # --- Build Hugging Face–compatible sample --- - hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) - - split_all_paths[split_name].update(hf_sample.keys()) - # split_all_paths[split_name].update(all_paths) - global_cgns_types.update(sample_cgns_types) - - # --- Infer global HF feature types --- - for path in all_paths: - value = hf_sample[path] - if value is None: - continue - - # if isinstance(value, np.ndarray) and value.dtype.type is np.str_: - # inferred = Value("string") - # else: - # inferred = infer_hf_features_from_value(value) - - inferred = infer_hf_features_from_value(value) - - if path not in global_feature_types: - global_feature_types[path] = inferred - elif repr(global_feature_types[path]) != repr(inferred): - raise ValueError( # pragma: no cover - f"Feature type mismatch for {path} in split {split_name}" - ) - - # --- Update per-split constant detection --- - for path, value in hf_sample.items(): - if path not in split_constant_leaves: - split_constant_leaves[path] = { - "value": value, - "constant": True, - "count": 1, - } - else: - entry = split_constant_leaves[path] - entry["count"] += 1 - if entry["constant"] and not values_equal(entry["value"], value): - entry["constant"] = False - - n_samples += 1 - - # --- Record per-split constants --- - for p, e in split_constant_leaves.items(): - if e["count"] < n_samples: - split_constant_leaves[p]["constant"] = False - - split_flat_cst[split_name] = dict( - sorted( - ( - (p, e["value"]) - for p, e in split_constant_leaves.items() - if e["constant"] - ), - key=lambda x: x[0], - ) - ) - - split_var_path[split_name] = { - p - for p in split_all_paths[split_name] - if p not in split_flat_cst[split_name] - } +): + ( + split_all_paths, + split_flat_cst, + split_var_path, + global_cgns_types, + 
global_feature_types, + ) = preprocess_splits(generators, gen_kwargs, processes_number, verbose) - global_feature_types = { - p: global_feature_types[p] for p in sorted(global_feature_types) - } + # --- build HF features --- var_features = sorted(list(set().union(*split_var_path.values()))) - - if len(var_features) == 0: - raise ValueError( # pragma: no cover + if len(var_features) == 0: # pragma: no cover + raise ValueError( "no variable feature found, is your dataset variable through samples?" ) - # --------------------------------------------------- - # for test-like splits, some var_features are all None (e.g.: outputs): need to add '_times' counterparts to corresponding constant trees for split_name in split_flat_cst.keys(): for path in var_features: if not path.endswith("_times") and path not in split_all_paths[split_name]: split_flat_cst[split_name][path + "_times"] = None # pragma: no cover - if ( - path in split_flat_cst[split_name] - ): # remove for flat_cst the path that will be forcely included in the arrow tables + if path in split_flat_cst[split_name]: split_flat_cst[split_name].pop(path) # pragma: no cover - # ---- Constant features sanity check cst_features = { split_name: sorted(list(cst.keys())) for split_name, cst in split_flat_cst.items() } - first_split, first_value = next(iter(cst_features.items()), (None, None)) for split, value in cst_features.items(): assert value == first_value, ( - f"cst_features differ for split '{split}' (vs '{first_split}'): something went wrong in _generator_prepare_for_huggingface." + f"cst_features differ for split '{split}' (vs '{first_split}')" ) - cst_features = first_value - # ---- Build global HF Features (only variable) ---- hf_features_map = {} for k in var_features: if k.endswith("_times"): hf_features_map[k] = Sequence(Value("float64")) # pragma: no cover else: hf_features_map[k] = global_feature_types[k] - hf_features = Features(hf_features_map) var_features = [path for path in var_features if not path.endswith("_times")] @@ -790,6 +604,192 @@ def values_equal(v1, v2): return split_flat_cst, key_mappings, hf_features +# # ------------------------------------------- +# # --------- Sequential version +# def _generator_prepare_for_huggingface( +# generators: dict[str, Callable], +# gen_kwargs: dict, +# processes_number: int = 1, +# verbose: bool = True, +# ) -> tuple[dict[str, dict[str, Any]], dict[str, Any], Features]: +# """Inspect PLAID dataset generators and infer Hugging Face feature schema. + +# Iterates over all samples in all provided split generators to: +# 1. Flatten each CGNS tree into a dictionary of paths → values. +# 2. Infer Hugging Face `Features` types for all variable leaves. +# 3. Detect constant leaves (values that never change across all samples). +# 4. Collect global CGNS type metadata. + +# Args: +# generators (dict[str, Callable]): +# Mapping from split names to callables returning sample generators. +# Each sample must have `sample.features.data[0.0]` compatible with `flatten_cgns_tree`. +# gen_kwargs (dict, optional, default=None): +# Optional mapping from split names to dictionaries of keyword arguments +# to be passed to each generator function, used for parallelization. +# processes_number (int, optional): Number of parallel processes to use. +# verbose (bool, optional): If True, displays progress bars while processing splits. + +# Returns: +# tuple: +# - flat_cst (dict[str, Any]): Mapping from feature path to constant values detected across all splits. 
+# - key_mappings (dict[str, Any]): Metadata dictionary with: +# - "variable_features" (list[str]): paths of non-constant features. +# - "constant_features" (list[str]): paths of constant features. +# - "cgns_types" (dict[str, Any]): CGNS type information for all paths. +# - hf_features (datasets.Features): Hugging Face feature specification for variable features. + +# Raises: +# ValueError: If inconsistent CGNS types or feature types are found for the same path. +# """ +# processes_number + +# def values_equal(v1, v2): +# if isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray): +# return np.array_equal(v1, v2) +# return v1 == v2 + +# global_cgns_types = {} +# global_feature_types = {} + +# split_flat_cst = {} +# split_var_path = {} +# split_all_paths = {} + +# # ---- Single pass over all splits and samples ---- +# for split_name, generator in generators.items(): +# split_constant_leaves = {} + +# split_all_paths[split_name] = set() + +# n_samples = 0 +# for sample in tqdm( +# generator(**gen_kwargs[split_name]), +# disable=not verbose, +# desc=f"Pre-process split {split_name}", +# ): +# # --- Build Hugging Face–compatible sample --- +# hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) + +# split_all_paths[split_name].update(hf_sample.keys()) +# # split_all_paths[split_name].update(all_paths) +# global_cgns_types.update(sample_cgns_types) + +# # --- Infer global HF feature types --- +# for path in all_paths: +# value = hf_sample[path] +# if value is None: +# continue + +# # if isinstance(value, np.ndarray) and value.dtype.type is np.str_: +# # inferred = Value("string") +# # else: +# # inferred = infer_hf_features_from_value(value) + +# inferred = infer_hf_features_from_value(value) + +# if path not in global_feature_types: +# global_feature_types[path] = inferred +# elif repr(global_feature_types[path]) != repr(inferred): +# raise ValueError( # pragma: no cover +# f"Feature type mismatch for {path} in split {split_name}" +# ) + +# # --- Update per-split constant detection --- +# for path, value in hf_sample.items(): +# if path not in split_constant_leaves: +# split_constant_leaves[path] = { +# "value": value, +# "constant": True, +# "count": 1, +# } +# else: +# entry = split_constant_leaves[path] +# entry["count"] += 1 +# if entry["constant"] and not values_equal(entry["value"], value): +# entry["constant"] = False + +# n_samples += 1 + +# # --- Record per-split constants --- +# for p, e in split_constant_leaves.items(): +# if e["count"] < n_samples: +# split_constant_leaves[p]["constant"] = False + +# split_flat_cst[split_name] = dict( +# sorted( +# ( +# (p, e["value"]) +# for p, e in split_constant_leaves.items() +# if e["constant"] +# ), +# key=lambda x: x[0], +# ) +# ) + +# split_var_path[split_name] = { +# p +# for p in split_all_paths[split_name] +# if p not in split_flat_cst[split_name] +# } + +# global_feature_types = { +# p: global_feature_types[p] for p in sorted(global_feature_types) +# } +# var_features = sorted(list(set().union(*split_var_path.values()))) + +# if len(var_features) == 0: +# raise ValueError( # pragma: no cover +# "no variable feature found, is your dataset variable through samples?" 
+# ) + +# # --------------------------------------------------- +# # for test-like splits, some var_features are all None (e.g.: outputs): need to add '_times' counterparts to corresponding constant trees +# for split_name in split_flat_cst.keys(): +# for path in var_features: +# if not path.endswith("_times") and path not in split_all_paths[split_name]: +# split_flat_cst[split_name][path + "_times"] = None # pragma: no cover +# if ( +# path in split_flat_cst[split_name] +# ): # remove for flat_cst the path that will be forcely included in the arrow tables +# split_flat_cst[split_name].pop(path) # pragma: no cover + +# # ---- Constant features sanity check +# cst_features = { +# split_name: sorted(list(cst.keys())) +# for split_name, cst in split_flat_cst.items() +# } + +# first_split, first_value = next(iter(cst_features.items()), (None, None)) +# for split, value in cst_features.items(): +# assert value == first_value, ( +# f"cst_features differ for split '{split}' (vs '{first_split}'): something went wrong in _generator_prepare_for_huggingface." +# ) + +# cst_features = first_value + +# # ---- Build global HF Features (only variable) ---- +# hf_features_map = {} +# for k in var_features: +# if k.endswith("_times"): +# hf_features_map[k] = Sequence(Value("float64")) # pragma: no cover +# else: +# hf_features_map[k] = global_feature_types[k] + +# hf_features = Features(hf_features_map) + +# var_features = [path for path in var_features if not path.endswith("_times")] +# cst_features = [path for path in cst_features if not path.endswith("_times")] + +# key_mappings = { +# "variable_features": var_features, +# "constant_features": cst_features, +# "cgns_types": global_cgns_types, +# } + +# return split_flat_cst, key_mappings, hf_features + + def to_plaid_dataset( hf_dataset: datasets.Dataset, flat_cst: dict[str, Any], From 93d2a802363e294bee3c3ac1c5cd8b16aafa57c5 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Sat, 8 Nov 2025 18:28:31 +0100 Subject: [PATCH 11/17] continue --- examples/bridges/huggingface_example.py | 19 ++++-- src/plaid/bridges/huggingface_bridge.py | 86 +++++++------------------ src/plaid/containers/dataset.py | 4 +- 3 files changed, 41 insertions(+), 68 deletions(-) diff --git a/examples/bridges/huggingface_example.py b/examples/bridges/huggingface_example.py index 1c51dbb6..722f7174 100644 --- a/examples/bridges/huggingface_example.py +++ b/examples/bridges/huggingface_example.py @@ -153,16 +153,25 @@ def get_mem(): # Ganarators are used to handle large datasets that do not fit in memory: # %% +gen_kwargs = {} +gen_kwargs["train"] = {"shards_ids": [[0, 1]]} +gen_kwargs["test"] = {"shards_ids": [[2]]} + generators = {} -for split_name, ids in main_splits.items(): - def generator_(ids=ids): - for id in ids: - yield dataset[id] +for split_name in gen_kwargs.keys(): + + def generator_(shards_ids): + for ids in shards_ids: + if isinstance(ids, int): + ids = [ids] + for id in ids: + yield dataset[id] + generators[split_name] = generator_ hf_datasetdict, flat_cst, key_mappings = ( huggingface_bridge.plaid_generator_to_huggingface_datasetdict( - generators + generators, gen_kwargs ) ) print(f"{hf_datasetdict = }") diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index be16e7c1..89b3cb6d 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -13,7 +13,6 @@ import pickle import shutil import sys -import traceback from functools import partial from multiprocessing import Pool from 
pathlib import Path @@ -294,6 +293,7 @@ def process_shard( split_all_paths = set() shard_global_cgns_types = {} shard_global_feature_types = {} + shards_to_process = [shard_ids] for sample in generator_fn(shards_to_process): @@ -310,25 +310,22 @@ def process_shard( inferred = infer_hf_features_from_value(value) if path not in shard_global_feature_types: shard_global_feature_types[path] = inferred - elif repr(shard_global_feature_types[path]) != repr( - inferred - ): # pragma: no cover + elif repr(shard_global_feature_types[path]) != repr(inferred): raise ValueError(f"Feature type mismatch for {path} in shard") - # Constant detection using hashes + # Constant detection using **hash only** for path, value in hf_sample.items(): h = _hash_value(value) if path not in split_constant_leaves: - split_constant_leaves[path] = {"hash": h, "constant": True, "count": 1} + split_constant_leaves[path] = {"hashes": {h}, "count": 1} else: entry = split_constant_leaves[path] + entry["hashes"].add(h) entry["count"] += 1 - if entry["constant"] and entry["hash"] != h: - entry["constant"] = False - # --- Update progress --- + # Progress if n_proc > 1: - progress.put(1) # pragma: no cover + progress.put(1) else: progress.update(1) @@ -399,8 +396,7 @@ def preprocess_splits( for split_name, generator_fn in generators.items(): shards_ids_list = gen_kwargs[split_name].get("shards_ids", [None]) - n_proc = processes_number or len(shards_ids_list) - n_proc = max(1, n_proc) + n_proc = max(1, processes_number or len(shards_ids_list)) shards_data = [] @@ -413,34 +409,20 @@ def preprocess_splits( ) as pbar: for shard_ids in shards_ids_list: shards_data.append( - process_shard( - shard_ids, - generator_fn, - pbar, - n_proc, - ) + process_shard(shard_ids, generator_fn, pbar, n_proc) ) - else: # pragma: no cover (pytest not working with parallel mode) - # --- Parallel execution --- - manager = None - pool = None + else: + # Parallel execution + manager = mp.Manager() + progress_queue = manager.Queue() try: - manager = mp.Manager() - progress_queue = manager.Queue() - - # --- Run shards in parallel --- with mp.Pool(n_proc) as pool: results = [ pool.apply_async( process_shard, - args=( - shard_ids, - generator_fn, - progress_queue, - n_proc, - ), + args=(shard_ids, generator_fn, progress_queue, n_proc), ) for shard_ids in shards_ids_list ] @@ -458,27 +440,12 @@ def preprocess_splits( completed += increment for r in results: - try: - data = r.get() # this raises if the worker crashed - shards_data.append(data) - except Exception: - traceback.print_exc() - # Optional: terminate pool early if one shard fails - pool.terminate() - raise - - except Exception: - traceback.print_exc() - raise + shards_data.append(r.get()) + finally: - # Always clean up multiprocessing objects - if pool is not None: - pool.terminate() - pool.join() - if manager is not None: - manager.shutdown() - - # --- Merge shard results --- + manager.shutdown() + + # Merge shard results split_all_paths[split_name] = set() split_constant_hashes = {} n_samples_total = 0 @@ -496,9 +463,7 @@ def preprocess_splits( for path, inferred in shard_features.items(): if path not in global_feature_types: global_feature_types[path] = inferred - elif repr(global_feature_types[path]) != repr( - inferred - ): # pragma: no cover + elif repr(global_feature_types[path]) != repr(inferred): raise ValueError( f"Feature type mismatch for {path} in split {split_name}" ) @@ -508,20 +473,19 @@ def preprocess_splits( split_constant_hashes[path] = entry else: existing = 
split_constant_hashes[path] - existing["constant"] = existing["constant"] and entry["constant"] + existing["hashes"].update(entry["hashes"]) existing["count"] += entry["count"] n_samples_total += n_samples - # --- Finalize constants by inspecting first sample --- - # Only paths marked constant across all samples + # Determine truly constant paths (same hash across all samples) constant_paths = [ p - for p, e in split_constant_hashes.items() - if e["constant"] and e["count"] == n_samples_total + for p, entry in split_constant_hashes.items() + if len(entry["hashes"]) == 1 and entry["count"] == n_samples_total ] - # Inspect first sample to get actual values + # Retrieve **values** only for constant paths from first sample first_sample = next(generator_fn([shards_ids_list[0]])) hf_sample, _, _ = build_hf_sample(first_sample) diff --git a/src/plaid/containers/dataset.py b/src/plaid/containers/dataset.py index 45f80f2a..c0dfc2bb 100644 --- a/src/plaid/containers/dataset.py +++ b/src/plaid/containers/dataset.py @@ -963,8 +963,8 @@ def set_infos(self, infos: dict[str, dict[str, str]]) -> None: f"{info_key=} not among authorized keys. Maybe you want to try among these keys {AUTHORIZED_INFO_KEYS[cat_key]}" ) - if len(self._infos) > 0: - logger.warning("infos not empty, replacing it anyway") + # if len(self._infos) > 0: + # logger.warning("infos not empty, replacing it anyway") self._infos = copy.deepcopy(infos) if "plaid" not in self._infos: From cbf44c41059f2b83606e6535cec960646f2ddedc Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Sat, 8 Nov 2025 20:30:56 +0100 Subject: [PATCH 12/17] continue --- src/plaid/bridges/huggingface_bridge.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 89b3cb6d..4ea29d4e 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -311,7 +311,9 @@ def process_shard( if path not in shard_global_feature_types: shard_global_feature_types[path] = inferred elif repr(shard_global_feature_types[path]) != repr(inferred): - raise ValueError(f"Feature type mismatch for {path} in shard") + raise ValueError( + f"Feature type mismatch for {path} in shard" + ) # pragma: no cover # Constant detection using **hash only** for path, value in hf_sample.items(): @@ -325,7 +327,7 @@ def process_shard( # Progress if n_proc > 1: - progress.put(1) + progress.put(1) # pragma: no cover else: progress.update(1) @@ -412,7 +414,7 @@ def preprocess_splits( process_shard(shard_ids, generator_fn, pbar, n_proc) ) - else: + else: # pragma: no cover # Parallel execution manager = mp.Manager() progress_queue = manager.Queue() @@ -464,7 +466,7 @@ def preprocess_splits( if path not in global_feature_types: global_feature_types[path] = inferred elif repr(global_feature_types[path]) != repr(inferred): - raise ValueError( + raise ValueError( # pragma: no cover f"Feature type mismatch for {path} in split {split_name}" ) From 28f0a41ac51c97f1240ca2e8385cc7f10cbfe0b6 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Sun, 9 Nov 2025 10:09:25 +0100 Subject: [PATCH 13/17] continue --- src/plaid/bridges/huggingface_bridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 89b3cb6d..2a39515d 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -47,7 +47,7 @@ ) logger = logging.getLogger(__name__) - 
+pa.set_memory_pool(pa.system_memory_pool()) # ------------------------------------------------------------------------------ # HUGGING FACE BRIDGE (with tree flattening and pyarrow tables) From c4fc1c2d3b4a55e04184a6a3c4be93989f48ad57 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Fri, 14 Nov 2025 18:14:29 +0100 Subject: [PATCH 14/17] continue --- src/plaid/bridges/huggingface_bridge.py | 68 +++++++++++++++--------- tests/bridges/test_huggingface_bridge.py | 37 ++++++++++--- 2 files changed, 73 insertions(+), 32 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 42c05935..ac997002 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -246,10 +246,10 @@ def _hash_value(value): def process_shard( - shard_ids: list[IndexType], - generator_fn: Callable[[list[list[IndexType]]], Any], + generator_fn: Callable[..., Any], progress: Any, n_proc: int, + shard_ids: Optional[list[IndexType]] = None, ) -> tuple[ set[str], dict[str, str], @@ -294,9 +294,15 @@ def process_shard( shard_global_cgns_types = {} shard_global_feature_types = {} - shards_to_process = [shard_ids] + if shard_ids: + generator = generator_fn([shard_ids]) + else: + generator = generator_fn() - for sample in generator_fn(shards_to_process): + n_samples = 0 + for sample in generator: + print("test_field_2785 =", sample.get_field("test_field_2785")) + print("shard_ids =", sample.get_field("test_field_2785")) hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) split_all_paths.update(hf_sample.keys()) @@ -331,18 +337,20 @@ def process_shard( else: progress.update(1) + n_samples += 1 + return ( split_all_paths, shard_global_cgns_types, shard_global_feature_types, split_constant_leaves, - len(shard_ids), + n_samples, ) def preprocess_splits( generators: dict[str, Callable], - gen_kwargs: dict[str, dict[str, list[IndexType]]], + gen_kwargs: Optional[dict[str, dict[str, list[IndexType]]]] = None, processes_number: int = 1, verbose: bool = True, ) -> tuple[ @@ -396,22 +404,23 @@ def preprocess_splits( split_var_path = {} split_all_paths = {} + gen_kwargs_ = gen_kwargs or {split_name: {} for split_name in generators.keys()} + for split_name, generator_fn in generators.items(): - shards_ids_list = gen_kwargs[split_name].get("shards_ids", [None]) + shards_ids_list = gen_kwargs_[split_name].get("shards_ids", [None]) n_proc = max(1, processes_number or len(shards_ids_list)) + print("shards_ids_list =", shards_ids_list) shards_data = [] if n_proc == 1: - progress_total = sum(len(shard) for shard in shards_ids_list) with tqdm( - total=progress_total, disable=not verbose, desc=f"Pre-process split {split_name}", ) as pbar: for shard_ids in shards_ids_list: shards_data.append( - process_shard(shard_ids, generator_fn, pbar, n_proc) + process_shard(generator_fn, pbar, n_proc=1, shard_ids=shard_ids) ) else: # pragma: no cover @@ -424,7 +433,7 @@ def preprocess_splits( results = [ pool.apply_async( process_shard, - args=(shard_ids, generator_fn, progress_queue, n_proc), + args=(generator_fn, progress_queue, n_proc, shard_ids), ) for shard_ids in shards_ids_list ] @@ -488,7 +497,10 @@ def preprocess_splits( ] # Retrieve **values** only for constant paths from first sample - first_sample = next(generator_fn([shards_ids_list[0]])) + if gen_kwargs: + first_sample = next(generator_fn([shards_ids_list[0]])) + else: + first_sample = next(generator_fn()) hf_sample, _, _ = build_hf_sample(first_sample) split_flat_cst[split_name] = {p: 
hf_sample[p] for p in sorted(constant_paths)} @@ -513,7 +525,7 @@ def preprocess_splits( def _generator_prepare_for_huggingface( generators: dict[str, Callable], - gen_kwargs: dict, + gen_kwargs: Optional[dict[str, dict[str, list[IndexType]]]] = None, processes_number: int = 1, verbose: bool = True, ): @@ -951,27 +963,30 @@ def plaid_dataset_to_huggingface_datasetdict( }) """ - def generator_(shards_ids): - for ids in shards_ids: - if isinstance(ids, int): - ids = [ids] # pragma: no cover - for id in ids: - yield dataset[id] + def generator(dataset): + for sample in dataset: + yield sample - generators = {split_name: generator_ for split_name in main_splits.keys()} - - gen_kwargs = { - split_name: {"shards_ids": [ids]} for split_name, ids in main_splits.items() + generators = { + split_name: partial(generator, dataset[ids]) + for split_name, ids in main_splits.items() } + # gen_kwargs = { + # split_name: {"shards_ids": [ids]} for split_name, ids in main_splits.items() + # } + return plaid_generator_to_huggingface_datasetdict( - generators, gen_kwargs, processes_number, writer_batch_size, verbose + generators, + processes_number=processes_number, + writer_batch_size=writer_batch_size, + verbose=verbose, ) def plaid_generator_to_huggingface_datasetdict( generators: dict[str, Callable], - gen_kwargs: dict[str, dict[str, list[IndexType]]], + gen_kwargs: Optional[dict[str, dict[str, list[IndexType]]]] = None, processes_number: int = 1, writer_batch_size: int = 1, verbose: bool = False, @@ -1047,9 +1062,10 @@ def generator_fn(gen_func, all_features_keys, **kwargs): _dict = {} for split_name, gen_func in generators.items(): gen = partial(generator_fn, all_features_keys=all_features_keys) + gen_kwargs_ = gen_kwargs or {split_name: {} for split_name in generators.keys()} _dict[split_name] = datasets.Dataset.from_generator( generator=gen, - gen_kwargs={"gen_func": gen_func, **gen_kwargs[split_name]}, + gen_kwargs={"gen_func": gen_func, **gen_kwargs_[split_name]}, features=hf_features, num_proc=processes_number, writer_batch_size=writer_batch_size, diff --git a/tests/bridges/test_huggingface_bridge.py b/tests/bridges/test_huggingface_bridge.py index aab5b0dc..2e67f09e 100644 --- a/tests/bridges/test_huggingface_bridge.py +++ b/tests/bridges/test_huggingface_bridge.py @@ -68,7 +68,24 @@ def gen_kwargs(problem_definition) -> dict[str, dict]: @pytest.fixture() -def generator_split(dataset, gen_kwargs) -> dict[str, Callable]: +def generator_split(dataset, problem_definition) -> dict[str, Callable]: + generators_ = {} + + main_splits = problem_definition.get_split() + + for split_name, ids in main_splits.items(): + + def generator_(): + for id in ids: + yield dataset[id] + + generators_[split_name] = generator_ + + return generators_ + + +@pytest.fixture() +def generator_split_with_kwargs(dataset, gen_kwargs) -> dict[str, Callable]: generators_ = {} for split_name in gen_kwargs.keys(): @@ -101,7 +118,7 @@ def generator_split_binary(dataset, problem_definition) -> dict[str, Callable]: generators_ = {} for split_name, ids in problem_definition.get_split().items(): - def generator_(ids=ids): + def generator_(): for id in ids: yield {"sample": pickle.dumps(dataset[id])} @@ -144,6 +161,7 @@ def test_with_datasetdict(self, dataset, problem_definition): dataset, main_splits ) ) + huggingface_bridge.to_plaid_sample( hf_dataset_dict["train"], 0, flat_cst["train"], key_mappings["cgns_types"] ) @@ -168,10 +186,17 @@ def test_with_datasetdict(self, dataset, problem_definition): dataset[0].get_mesh(), 
dataset[0].get_mesh() ) - def test_with_generator(self, generator_split, gen_kwargs): + def test_with_generator( + self, generator_split_with_kwargs, generator_split, gen_kwargs + ): + hf_dataset_dict, flat_cst, key_mappings = ( + huggingface_bridge.plaid_generator_to_huggingface_datasetdict( + generator_split_with_kwargs, gen_kwargs + ) + ) hf_dataset_dict, flat_cst, key_mappings = ( huggingface_bridge.plaid_generator_to_huggingface_datasetdict( - generator_split, gen_kwargs + generator_split ) ) huggingface_bridge.to_plaid_sample( @@ -190,11 +215,11 @@ def test_with_generator(self, generator_split, gen_kwargs): # ------------------------------------------------------------------------------ def test_save_load_to_disk( - self, current_directory, generator_split, infos, problem_definition, gen_kwargs + self, current_directory, generator_split, infos, problem_definition ): hf_dataset_dict, flat_cst, key_mappings = ( huggingface_bridge.plaid_generator_to_huggingface_datasetdict( - generator_split, gen_kwargs + generator_split ) ) From 43bc647f65aa35a8e19df41f21c96ae6cf7bf432 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 20 Nov 2025 20:59:21 +0100 Subject: [PATCH 15/17] continue --- src/plaid/bridges/huggingface_bridge.py | 45 +++++++++++++++++-------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index ac997002..81b6b00c 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -17,6 +17,8 @@ from multiprocessing import Pool from pathlib import Path from typing import Callable, Optional +from queue import Empty +import traceback import numpy as np import pyarrow as pa @@ -294,15 +296,13 @@ def process_shard( shard_global_cgns_types = {} shard_global_feature_types = {} - if shard_ids: + if shard_ids is not None: generator = generator_fn([shard_ids]) else: generator = generator_fn() n_samples = 0 for sample in generator: - print("test_field_2785 =", sample.get_field("test_field_2785")) - print("shard_ids =", sample.get_field("test_field_2785")) hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) split_all_paths.update(hf_sample.keys()) @@ -348,6 +348,15 @@ def process_shard( ) +def _process_shard_debug(generator_fn, progress_queue, n_proc, shard_ids): + try: + return process_shard(generator_fn, progress_queue, n_proc, shard_ids) + except Exception as e: + print(f"Exception in worker for shards {shard_ids}: {e}", file=sys.stderr) + traceback.print_exc() + raise # re-raise to propagate to main process + + def preprocess_splits( generators: dict[str, Callable], gen_kwargs: Optional[dict[str, dict[str, list[IndexType]]]] = None, @@ -427,29 +436,37 @@ def preprocess_splits( # Parallel execution manager = mp.Manager() progress_queue = manager.Queue() + shards_data = [] try: with mp.Pool(n_proc) as pool: results = [ pool.apply_async( - process_shard, + _process_shard_debug, args=(generator_fn, progress_queue, n_proc, shard_ids), ) for shard_ids in shards_ids_list ] total_samples = sum(len(shard) for shard in shards_ids_list) - with tqdm( - total=total_samples, - disable=not verbose, - desc=f"Pre-process split {split_name}", - ) as pbar: - completed = 0 - while completed < total_samples: - increment = progress_queue.get() - pbar.update(increment) - completed += increment + completed = 0 + with tqdm(total=total_samples, disable=not verbose, desc=f"Pre-process split {split_name}") as pbar: + while completed < total_samples: + try: + 
increment = progress_queue.get(timeout=0.5) + pbar.update(increment) + completed += increment + except Empty: + # Check for any crashed workers + for r in results: + if r.ready(): + try: + r.get(timeout=0.1) # will raise worker exception if any + except Exception as e: + raise RuntimeError(f"Worker crashed: {e}") + + # Collect all results for r in results: shards_data.append(r.get()) From 06c246113be35b9898594c3fc35e4ed09832b457 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 20 Nov 2025 21:02:13 +0100 Subject: [PATCH 16/17] continue --- src/plaid/bridges/huggingface_bridge.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 81b6b00c..4e613577 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -13,12 +13,12 @@ import pickle import shutil import sys +import traceback from functools import partial from multiprocessing import Pool from pathlib import Path -from typing import Callable, Optional from queue import Empty -import traceback +from typing import Callable, Optional import numpy as np import pyarrow as pa @@ -451,7 +451,11 @@ def preprocess_splits( total_samples = sum(len(shard) for shard in shards_ids_list) completed = 0 - with tqdm(total=total_samples, disable=not verbose, desc=f"Pre-process split {split_name}") as pbar: + with tqdm( + total=total_samples, + disable=not verbose, + desc=f"Pre-process split {split_name}", + ) as pbar: while completed < total_samples: try: increment = progress_queue.get(timeout=0.5) @@ -462,7 +466,9 @@ def preprocess_splits( for r in results: if r.ready(): try: - r.get(timeout=0.1) # will raise worker exception if any + r.get( + timeout=0.1 + ) # will raise worker exception if any except Exception as e: raise RuntimeError(f"Worker crashed: {e}") From 4f5f793bde4b6c464fc1715d27b8503473067129 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 20 Nov 2025 21:10:05 +0100 Subject: [PATCH 17/17] continue --- src/plaid/bridges/huggingface_bridge.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 4e613577..41e87268 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -348,7 +348,9 @@ def process_shard( ) -def _process_shard_debug(generator_fn, progress_queue, n_proc, shard_ids): +def _process_shard_debug( + generator_fn, progress_queue, n_proc, shard_ids +): # pragma: no cover try: return process_shard(generator_fn, progress_queue, n_proc, shard_ids) except Exception as e: @@ -418,7 +420,6 @@ def preprocess_splits( for split_name, generator_fn in generators.items(): shards_ids_list = gen_kwargs_[split_name].get("shards_ids", [None]) n_proc = max(1, processes_number or len(shards_ids_list)) - print("shards_ids_list =", shards_ids_list) shards_data = []