diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fd4fcd4..be6df207 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,8 +13,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- (huggingface bridge) full parallel support in `from_generator`, with optimization of constant leaf detection (no large data communicated between processes). + ### Fixes +- (samples/features) add string support to globals. +- (huggingface bridge) correct split_constant tree derivation, add heuristic for number of shards usage in push_to_dict, robustify infer_hf_features_from_value with respect to numpy arrays of strings, modernize update_dataset_card. + ### Removed ## [0.1.10] - 2025-10-29 diff --git a/examples/bridges/huggingface_example.py b/examples/bridges/huggingface_example.py index 1c51dbb6..722f7174 100644 --- a/examples/bridges/huggingface_example.py +++ b/examples/bridges/huggingface_example.py @@ -153,16 +153,25 @@ def get_mem(): # Ganarators are used to handle large datasets that do not fit in memory: # %% +gen_kwargs = {} +gen_kwargs["train"] = {"shards_ids": [[0, 1]]} +gen_kwargs["test"] = {"shards_ids": [[2]]} + generators = {} -for split_name, ids in main_splits.items(): - def generator_(ids=ids): - for id in ids: - yield dataset[id] +for split_name in gen_kwargs.keys(): + + def generator_(shards_ids): + for ids in shards_ids: + if isinstance(ids, int): + ids = [ids] + for id in ids: + yield dataset[id] + generators[split_name] = generator_ hf_datasetdict, flat_cst, key_mappings = ( huggingface_bridge.plaid_generator_to_huggingface_datasetdict( - generators + generators, gen_kwargs ) ) print(f"{hf_datasetdict = }") diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index dd1e71e6..42c05935 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -8,6 +8,7 @@ # import hashlib import io +import multiprocessing as mp import os import pickle import shutil @@ -44,10 +45,9 @@ flatten_cgns_tree, unflatten_cgns_tree, ) -from plaid.utils.deprecation import deprecated logger = logging.getLogger(__name__) - +pa.set_memory_pool(pa.system_memory_pool()) # ------------------------------------------------------------------------------ # HUGGING FACE BRIDGE (with tree flattening and pyarrow tables) @@ -98,7 +98,9 @@ def infer_hf_features_from_value(value: Any) -> Union[Value, Sequence]: dtype, np.int64 ): # very important to satisfy the CGNS standard return Value("int64") - elif np.issubdtype(dtype, np.dtype("|S1")): # pragma: no cover + elif np.issubdtype(dtype, np.dtype("|S1")) or np.issubdtype( + dtype, np.dtype(" tuple[dict[str, Any], list[str], dict[str return hf_sample, all_paths, sample_cgns_types -def _generator_prepare_for_huggingface( +def _hash_value(value): + """Compute a hash for a value (np.ndarray or basic types).""" + if isinstance(value, np.ndarray): + return hashlib.md5(value.view(np.uint8)).hexdigest() + return hashlib.md5(str(value).encode("utf-8")).hexdigest() + + +def process_shard( + shard_ids: list[IndexType], + generator_fn: Callable[[list[list[IndexType]]], Any], + progress: Any, + n_proc: int, +) -> tuple[ + set[str], + dict[str, str], + dict[str, Union[Value, Sequence]], + dict[str, dict[str, Union[str, bool, int]]], + int, +]: + """Process a single shard of sample ids and collect per-shard metadata. + + This function drives a shard-level pass over samples produced by `generator_fn`. 
+ For each sample it: + - flattens the sample into Hugging Face friendly arrays (build_hf_sample), + - collects observed flattened paths, + - aggregates CGNS type metadata, + - infers Hugging Face feature types for each path, + - detects per-path constants using a content hash, + - updates progress (either a multiprocessing.Queue or a tqdm progress bar). + + Args: + shard_ids (list[IndexType]): Sequence of sample ids (a single shard) to process. + generator_fn (Callable): Generator function accepting a list of shard id sequences + and yielding Sample objects for those ids. + progress (Any): Progress reporter; either a multiprocessing.Queue (for parallel + execution) or a tqdm progress bar object (for sequential execution). + n_proc (int): Number of worker processes used by the caller (used to decide + how to report progress). + + Returns: + tuple: + - split_all_paths (set[str]): Set of all flattened feature paths observed in the shard. + - shard_global_cgns_types (dict[str, str]): Mapping path -> CGNS node type observed in the shard. + - shard_global_feature_types (dict[str, Union[Value, Sequence]]): Inferred HF feature types per path. + - split_constant_leaves (dict[str, dict]): Per-path metadata for constant detection. Each entry + is a dict with keys "hash" (str), "constant" (bool) and "count" (int). + - n_samples_processed (int): Number of samples processed in this shard. + + Raises: + ValueError: If inconsistent feature types are detected for the same path within the shard. + """ + split_constant_leaves = {} + split_all_paths = set() + shard_global_cgns_types = {} + shard_global_feature_types = {} + + shards_to_process = [shard_ids] + + for sample in generator_fn(shards_to_process): + hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) + + split_all_paths.update(hf_sample.keys()) + shard_global_cgns_types.update(sample_cgns_types) + + # Feature type inference + for path in all_paths: + value = hf_sample[path] + if value is None: + continue + inferred = infer_hf_features_from_value(value) + if path not in shard_global_feature_types: + shard_global_feature_types[path] = inferred + elif repr(shard_global_feature_types[path]) != repr(inferred): + raise ValueError( + f"Feature type mismatch for {path} in shard" + ) # pragma: no cover + + # Constant detection using **hash only** + for path, value in hf_sample.items(): + h = _hash_value(value) + if path not in split_constant_leaves: + split_constant_leaves[path] = {"hashes": {h}, "count": 1} + else: + entry = split_constant_leaves[path] + entry["hashes"].add(h) + entry["count"] += 1 + + # Progress + if n_proc > 1: + progress.put(1) # pragma: no cover + else: + progress.update(1) + + return ( + split_all_paths, + shard_global_cgns_types, + shard_global_feature_types, + split_constant_leaves, + len(shard_ids), + ) + + +def preprocess_splits( generators: dict[str, Callable], + gen_kwargs: dict[str, dict[str, list[IndexType]]], + processes_number: int = 1, verbose: bool = True, -) -> tuple[dict[str, dict[str, Any]], dict[str, Any], Features]: - """Inspect PLAID dataset generators and infer Hugging Face feature schema. - - Iterates over all samples in all provided split generators to: - 1. Flatten each CGNS tree into a dictionary of paths → values. - 2. Infer Hugging Face `Features` types for all variable leaves. - 3. Detect constant leaves (values that never change across all samples). - 4. Collect global CGNS type metadata. 
+) -> tuple[ + dict[str, set[str]], + dict[str, dict[str, Any]], + dict[str, set[str]], + dict[str, str], + dict[str, Union[Value, Sequence]], +]: + """Pre-process dataset splits: inspect samples to infer features, constants and CGNS metadata. + + This function iterates over the provided split generators (optionally in parallel), + flattens each PLAID sample into Hugging Face friendly arrays, detects constant + CGNS leaves (features identical across all samples in a split), infers global + Hugging Face feature types, and aggregates CGNS type metadata. + + The work is sharded per-split and each shard is processed by `process_shard`. + In parallel mode, progress is updated via a multiprocessing.Queue; otherwise a + tqdm progress bar is used. Args: generators (dict[str, Callable]): - Mapping from split names to callables returning sample generators. - Each sample must have `sample.features.data[0.0]` compatible with `flatten_cgns_tree`. - verbose (bool, optional): If True, displays progress bars while processing splits. + Mapping from split name to a generator function. Each generator must + accept a single argument (a sequence of shard ids) and yield PLAID samples. + gen_kwargs (dict[str, dict[str, list[IndexType]]]): + Per-split kwargs used to drive generator invocation (e.g. {"train": {"shards_ids": [...]}}). + processes_number (int, optional): + Number of worker processes to use for shard-level parallelism. Defaults to 1. + verbose (bool, optional): + If True, displays progress bars. Defaults to True. Returns: tuple: - - flat_cst (dict[str, Any]): Mapping from feature path to constant values detected across all splits. - - key_mappings (dict[str, Any]): Metadata dictionary with: - - "variable_features" (list[str]): paths of non-constant features. - - "constant_features" (list[str]): paths of constant features. - - "cgns_types" (dict[str, Any]): CGNS type information for all paths. - - hf_features (datasets.Features): Hugging Face feature specification for variable features. + - split_all_paths (dict[str, set[str]]): + For each split, the set of all observed flattened feature paths (including "_times" keys). + - split_flat_cst (dict[str, dict[str, Any]]): + For each split, a mapping of constant feature path -> value (constant parts of the tree). + - split_var_path (dict[str, set[str]]): + For each split, the set of variable feature paths (non-constant). + - global_cgns_types (dict[str, str]): + Aggregated mapping from flattened path -> CGNS node type. + - global_feature_types (dict[str, Union[Value, Sequence]]): + Aggregated inferred Hugging Face feature types for each variable path. Raises: - ValueError: If inconsistent CGNS types or feature types are found for the same path. + ValueError: If inconsistent feature types or CGNS types are detected across shards/splits. 
""" - - def values_equal(v1, v2): - if isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray): - return np.array_equal(v1, v2) - return v1 == v2 - global_cgns_types = {} global_feature_types = {} - split_flat_cst = {} split_var_path = {} split_all_paths = {} - # ---- Single pass over all splits and samples ---- - for split_name, generator in generators.items(): - split_constant_leaves = {} + for split_name, generator_fn in generators.items(): + shards_ids_list = gen_kwargs[split_name].get("shards_ids", [None]) + n_proc = max(1, processes_number or len(shards_ids_list)) + + shards_data = [] + + if n_proc == 1: + progress_total = sum(len(shard) for shard in shards_ids_list) + with tqdm( + total=progress_total, + disable=not verbose, + desc=f"Pre-process split {split_name}", + ) as pbar: + for shard_ids in shards_ids_list: + shards_data.append( + process_shard(shard_ids, generator_fn, pbar, n_proc) + ) - split_all_paths[split_name] = set() + else: # pragma: no cover + # Parallel execution + manager = mp.Manager() + progress_queue = manager.Queue() - for sample in tqdm( - generator(), disable=not verbose, desc=f"Process split {split_name}" - ): - # --- Build Hugging Face–compatible sample --- - hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) + try: + with mp.Pool(n_proc) as pool: + results = [ + pool.apply_async( + process_shard, + args=(shard_ids, generator_fn, progress_queue, n_proc), + ) + for shard_ids in shards_ids_list + ] - split_all_paths[split_name].update(hf_sample.keys()) - # split_all_paths[split_name].update(all_paths) - global_cgns_types.update(sample_cgns_types) + total_samples = sum(len(shard) for shard in shards_ids_list) + with tqdm( + total=total_samples, + disable=not verbose, + desc=f"Pre-process split {split_name}", + ) as pbar: + completed = 0 + while completed < total_samples: + increment = progress_queue.get() + pbar.update(increment) + completed += increment - # --- Infer global HF feature types --- - for path in all_paths: - value = hf_sample[path] - if value is None: - continue + for r in results: + shards_data.append(r.get()) - if isinstance(value, np.ndarray) and value.dtype.type is np.str_: - inferred = Value("string") - else: - inferred = infer_hf_features_from_value(value) + finally: + manager.shutdown() + # Merge shard results + split_all_paths[split_name] = set() + split_constant_hashes = {} + n_samples_total = 0 + + for ( + all_paths, + shard_cgns, + shard_features, + shard_constants, + n_samples, + ) in shards_data: + split_all_paths[split_name].update(all_paths) + global_cgns_types.update(shard_cgns) + + for path, inferred in shard_features.items(): if path not in global_feature_types: global_feature_types[path] = inferred elif repr(global_feature_types[path]) != repr(inferred): @@ -313,32 +470,28 @@ def values_equal(v1, v2): f"Feature type mismatch for {path} in split {split_name}" ) - # --- Update per-split constant detection --- - for path, value in hf_sample.items(): - if path not in split_constant_leaves: - split_constant_leaves[path] = { - "value": value, - "constant": True, - "count": 1, - } + for path, entry in shard_constants.items(): + if path not in split_constant_hashes: + split_constant_hashes[path] = entry else: - entry = split_constant_leaves[path] - entry["count"] += 1 - if entry["constant"] and not values_equal(entry["value"], value): - entry["constant"] = False - - # --- Record per-split constants --- - split_flat_cst[split_name] = dict( - sorted( - ( - (p, e["value"]) - for p, e in split_constant_leaves.items() - if 
e["constant"] - ), - key=lambda x: x[0], - ) - ) + existing = split_constant_hashes[path] + existing["hashes"].update(entry["hashes"]) + existing["count"] += entry["count"] + + n_samples_total += n_samples + + # Determine truly constant paths (same hash across all samples) + constant_paths = [ + p + for p, entry in split_constant_hashes.items() + if len(entry["hashes"]) == 1 and entry["count"] == n_samples_total + ] + + # Retrieve **values** only for constant paths from first sample + first_sample = next(generator_fn([shards_ids_list[0]])) + hf_sample, _, _ = build_hf_sample(first_sample) + split_flat_cst[split_name] = {p: hf_sample[p] for p in sorted(constant_paths)} split_var_path[split_name] = { p for p in split_all_paths[split_name] @@ -348,46 +501,61 @@ def values_equal(v1, v2): global_feature_types = { p: global_feature_types[p] for p in sorted(global_feature_types) } - var_features = sorted(list(set().union(*split_var_path.values()))) - if len(var_features) == 0: - raise ValueError( # pragma: no cover + return ( + split_all_paths, + split_flat_cst, + split_var_path, + global_cgns_types, + global_feature_types, + ) + + +def _generator_prepare_for_huggingface( + generators: dict[str, Callable], + gen_kwargs: dict, + processes_number: int = 1, + verbose: bool = True, +): + ( + split_all_paths, + split_flat_cst, + split_var_path, + global_cgns_types, + global_feature_types, + ) = preprocess_splits(generators, gen_kwargs, processes_number, verbose) + + # --- build HF features --- + var_features = sorted(list(set().union(*split_var_path.values()))) + if len(var_features) == 0: # pragma: no cover + raise ValueError( "no variable feature found, is your dataset variable through samples?" ) - # --------------------------------------------------- - # for test-like splits, some var_features are all None (e.g.: outputs): need to add '_times' counterparts to corresponding constant trees for split_name in split_flat_cst.keys(): for path in var_features: if not path.endswith("_times") and path not in split_all_paths[split_name]: split_flat_cst[split_name][path + "_times"] = None # pragma: no cover - if ( - path in split_flat_cst[split_name] - ): # remove for flat_cst the path that will be forcely included in the arrow tables + if path in split_flat_cst[split_name]: split_flat_cst[split_name].pop(path) # pragma: no cover - # ---- Constant features sanity check cst_features = { split_name: sorted(list(cst.keys())) for split_name, cst in split_flat_cst.items() } - first_split, first_value = next(iter(cst_features.items()), (None, None)) for split, value in cst_features.items(): assert value == first_value, ( - f"cst_features differ for split '{split}' (vs '{first_split}'): something went wrong in _generator_prepare_for_huggingface." 
+ f"cst_features differ for split '{split}' (vs '{first_split}')" ) - cst_features = first_value - # ---- Build global HF Features (only variable) ---- hf_features_map = {} for k in var_features: if k.endswith("_times"): hf_features_map[k] = Sequence(Value("float64")) # pragma: no cover else: hf_features_map[k] = global_feature_types[k] - hf_features = Features(hf_features_map) var_features = [path for path in var_features if not path.endswith("_times")] @@ -402,6 +570,192 @@ def values_equal(v1, v2): return split_flat_cst, key_mappings, hf_features +# # ------------------------------------------- +# # --------- Sequential version +# def _generator_prepare_for_huggingface( +# generators: dict[str, Callable], +# gen_kwargs: dict, +# processes_number: int = 1, +# verbose: bool = True, +# ) -> tuple[dict[str, dict[str, Any]], dict[str, Any], Features]: +# """Inspect PLAID dataset generators and infer Hugging Face feature schema. + +# Iterates over all samples in all provided split generators to: +# 1. Flatten each CGNS tree into a dictionary of paths → values. +# 2. Infer Hugging Face `Features` types for all variable leaves. +# 3. Detect constant leaves (values that never change across all samples). +# 4. Collect global CGNS type metadata. + +# Args: +# generators (dict[str, Callable]): +# Mapping from split names to callables returning sample generators. +# Each sample must have `sample.features.data[0.0]` compatible with `flatten_cgns_tree`. +# gen_kwargs (dict, optional, default=None): +# Optional mapping from split names to dictionaries of keyword arguments +# to be passed to each generator function, used for parallelization. +# processes_number (int, optional): Number of parallel processes to use. +# verbose (bool, optional): If True, displays progress bars while processing splits. + +# Returns: +# tuple: +# - flat_cst (dict[str, Any]): Mapping from feature path to constant values detected across all splits. +# - key_mappings (dict[str, Any]): Metadata dictionary with: +# - "variable_features" (list[str]): paths of non-constant features. +# - "constant_features" (list[str]): paths of constant features. +# - "cgns_types" (dict[str, Any]): CGNS type information for all paths. +# - hf_features (datasets.Features): Hugging Face feature specification for variable features. + +# Raises: +# ValueError: If inconsistent CGNS types or feature types are found for the same path. 
+# """ +# processes_number + +# def values_equal(v1, v2): +# if isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray): +# return np.array_equal(v1, v2) +# return v1 == v2 + +# global_cgns_types = {} +# global_feature_types = {} + +# split_flat_cst = {} +# split_var_path = {} +# split_all_paths = {} + +# # ---- Single pass over all splits and samples ---- +# for split_name, generator in generators.items(): +# split_constant_leaves = {} + +# split_all_paths[split_name] = set() + +# n_samples = 0 +# for sample in tqdm( +# generator(**gen_kwargs[split_name]), +# disable=not verbose, +# desc=f"Pre-process split {split_name}", +# ): +# # --- Build Hugging Face–compatible sample --- +# hf_sample, all_paths, sample_cgns_types = build_hf_sample(sample) + +# split_all_paths[split_name].update(hf_sample.keys()) +# # split_all_paths[split_name].update(all_paths) +# global_cgns_types.update(sample_cgns_types) + +# # --- Infer global HF feature types --- +# for path in all_paths: +# value = hf_sample[path] +# if value is None: +# continue + +# # if isinstance(value, np.ndarray) and value.dtype.type is np.str_: +# # inferred = Value("string") +# # else: +# # inferred = infer_hf_features_from_value(value) + +# inferred = infer_hf_features_from_value(value) + +# if path not in global_feature_types: +# global_feature_types[path] = inferred +# elif repr(global_feature_types[path]) != repr(inferred): +# raise ValueError( # pragma: no cover +# f"Feature type mismatch for {path} in split {split_name}" +# ) + +# # --- Update per-split constant detection --- +# for path, value in hf_sample.items(): +# if path not in split_constant_leaves: +# split_constant_leaves[path] = { +# "value": value, +# "constant": True, +# "count": 1, +# } +# else: +# entry = split_constant_leaves[path] +# entry["count"] += 1 +# if entry["constant"] and not values_equal(entry["value"], value): +# entry["constant"] = False + +# n_samples += 1 + +# # --- Record per-split constants --- +# for p, e in split_constant_leaves.items(): +# if e["count"] < n_samples: +# split_constant_leaves[p]["constant"] = False + +# split_flat_cst[split_name] = dict( +# sorted( +# ( +# (p, e["value"]) +# for p, e in split_constant_leaves.items() +# if e["constant"] +# ), +# key=lambda x: x[0], +# ) +# ) + +# split_var_path[split_name] = { +# p +# for p in split_all_paths[split_name] +# if p not in split_flat_cst[split_name] +# } + +# global_feature_types = { +# p: global_feature_types[p] for p in sorted(global_feature_types) +# } +# var_features = sorted(list(set().union(*split_var_path.values()))) + +# if len(var_features) == 0: +# raise ValueError( # pragma: no cover +# "no variable feature found, is your dataset variable through samples?" 
+# ) + +# # --------------------------------------------------- +# # for test-like splits, some var_features are all None (e.g.: outputs): need to add '_times' counterparts to corresponding constant trees +# for split_name in split_flat_cst.keys(): +# for path in var_features: +# if not path.endswith("_times") and path not in split_all_paths[split_name]: +# split_flat_cst[split_name][path + "_times"] = None # pragma: no cover +# if ( +# path in split_flat_cst[split_name] +# ): # remove for flat_cst the path that will be forcely included in the arrow tables +# split_flat_cst[split_name].pop(path) # pragma: no cover + +# # ---- Constant features sanity check +# cst_features = { +# split_name: sorted(list(cst.keys())) +# for split_name, cst in split_flat_cst.items() +# } + +# first_split, first_value = next(iter(cst_features.items()), (None, None)) +# for split, value in cst_features.items(): +# assert value == first_value, ( +# f"cst_features differ for split '{split}' (vs '{first_split}'): something went wrong in _generator_prepare_for_huggingface." +# ) + +# cst_features = first_value + +# # ---- Build global HF Features (only variable) ---- +# hf_features_map = {} +# for k in var_features: +# if k.endswith("_times"): +# hf_features_map[k] = Sequence(Value("float64")) # pragma: no cover +# else: +# hf_features_map[k] = global_feature_types[k] + +# hf_features = Features(hf_features_map) + +# var_features = [path for path in var_features if not path.endswith("_times")] +# cst_features = [path for path in cst_features if not path.endswith("_times")] + +# key_mappings = { +# "variable_features": var_features, +# "constant_features": cst_features, +# "cgns_types": global_cgns_types, +# } + +# return split_flat_cst, key_mappings, hf_features + + def to_plaid_dataset( hf_dataset: datasets.Dataset, flat_cst: dict[str, Any], @@ -484,6 +838,8 @@ def to_plaid_sample( else: if isinstance(value, pa.ListArray): row[name] = np.stack(value.to_numpy(zero_copy_only=False)) + elif isinstance(value, pa.StringArray): # pragma: no cover + row[name] = value.to_numpy(zero_copy_only=False) else: row[name] = value.to_numpy(zero_copy_only=True) @@ -595,22 +951,27 @@ def plaid_dataset_to_huggingface_datasetdict( }) """ - def generator(dataset): - for sample in dataset: - yield sample + def generator_(shards_ids): + for ids in shards_ids: + if isinstance(ids, int): + ids = [ids] # pragma: no cover + for id in ids: + yield dataset[id] + + generators = {split_name: generator_ for split_name in main_splits.keys()} - generators = { - split_name: partial(generator, dataset[ids]) - for split_name, ids in main_splits.items() + gen_kwargs = { + split_name: {"shards_ids": [ids]} for split_name, ids in main_splits.items() } return plaid_generator_to_huggingface_datasetdict( - generators, processes_number, writer_batch_size, verbose + generators, gen_kwargs, processes_number, writer_batch_size, verbose ) def plaid_generator_to_huggingface_datasetdict( generators: dict[str, Callable], + gen_kwargs: dict[str, dict[str, list[IndexType]]], processes_number: int = 1, writer_batch_size: int = 1, verbose: bool = False, @@ -632,6 +993,9 @@ def plaid_generator_to_huggingface_datasetdict( the dataset from the generators. writer_batch_size (int, optional, default=1): Batch size used when writing samples to disk in Hugging Face format. + gen_kwargs (dict, optional, default=None): + Optional mapping from split names to dictionaries of keyword arguments + to be passed to each generator function, used for parallelization. 
verbose (bool, optional, default=False): If True, displays progress bars and diagnostic messages. @@ -670,21 +1034,22 @@ def plaid_generator_to_huggingface_datasetdict( ['Zone1/FlowSolution/VelocityX', 'Zone1/FlowSolution/VelocityY', ...] """ flat_cst, key_mappings, hf_features = _generator_prepare_for_huggingface( - generators, verbose + generators, gen_kwargs, processes_number, verbose ) all_features_keys = list(hf_features.keys()) - def generator_fn(gen_func, all_features_keys): - for sample in gen_func(): + def generator_fn(gen_func, all_features_keys, **kwargs): + for sample in gen_func(**kwargs): hf_sample, _, _ = build_hf_sample(sample) yield {path: hf_sample.get(path, None) for path in all_features_keys} _dict = {} for split_name, gen_func in generators.items(): - gen = partial(generator_fn, gen_func, all_features_keys) + gen = partial(generator_fn, all_features_keys=all_features_keys) _dict[split_name] = datasets.Dataset.from_generator( generator=gen, + gen_kwargs={"gen_func": gen_func, **gen_kwargs[split_name]}, features=hf_features, num_proc=processes_number, writer_batch_size=writer_batch_size, @@ -699,6 +1064,26 @@ def generator_fn(gen_func, all_features_keys): # ------------------------------------------------------------------------------ +def _compute_num_shards(hf_dataset_dict: datasets.DatasetDict) -> dict[str, int]: + target_shard_size_mb = 500 + + num_shards = {} + for split_name, ds in hf_dataset_dict.items(): + n_samples = len(ds) + assert n_samples > 0, f"split {split_name} has no sample" + + dataset_size_bytes = ds.data.nbytes + target_shard_size_bytes = target_shard_size_mb * 1024 * 1024 + + n_shards = max( + 1, + (dataset_size_bytes + target_shard_size_bytes - 1) + // target_shard_size_bytes, + ) + num_shards[split_name] = min(n_samples, int(n_shards)) + return num_shards + + def instantiate_plaid_datasetdict_from_hub( repo_id: str, enforce_shapes: bool = True, @@ -915,7 +1300,7 @@ def load_tree_struct_from_hub( def push_dataset_dict_to_hub( - repo_id: str, hf_dataset_dict: datasets.DatasetDict, *args, **kwargs + repo_id: str, hf_dataset_dict: datasets.DatasetDict, **kwargs ) -> None: # pragma: no cover (not tested in unit tests) """Push a Hugging Face `DatasetDict` to the Hugging Face Hub. @@ -923,15 +1308,19 @@ def push_dataset_dict_to_hub( you to upload a dataset dictionary (with one or more splits such as `"train"`, `"validation"`, `"test"`) to the Hugging Face Hub. + Note: + The function automatically handles sharding of the dataset by setting `num_shards` + for each split. For each split, the number of shards is set to the minimum between + the number of samples in that split and such that shards are targetted to approx. 500 MB. + This ensures efficient chunking while preventing excessive fragmentation. Empty splits + will raise an assertion error. + Args: repo_id (str): The repository ID on the Hugging Face Hub (e.g. `"username/dataset_name"`). hf_dataset_dict (datasets.DatasetDict): The Hugging Face dataset dictionary to push. - *args: - Positional arguments forwarded to - [`DatasetDict.push_to_hub`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.push_to_hub). **kwargs: Keyword arguments forwarded to [`DatasetDict.push_to_hub`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.push_to_hub). 
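For reference, here is a small self-contained sketch of the shard-count heuristic implemented in `_compute_num_shards` above. The split size and sample count are illustrative; the formula is the one used in the patch (ceiling division toward a ~500 MB shard target, capped by the number of samples):

```python
# Worked example of the shard-count heuristic (illustrative numbers).
target_shard_size_bytes = 500 * 1024 * 1024   # same 500 MB target as _compute_num_shards
dataset_size_bytes = 2_600 * 1024 * 1024      # hypothetical split size (~2.6 GB)
n_samples = 10_000                            # hypothetical number of samples in the split

# Ceiling division, then cap by the number of samples so no shard is empty.
n_shards = max(1, (dataset_size_bytes + target_shard_size_bytes - 1) // target_shard_size_bytes)
print(min(n_samples, n_shards))               # -> 6 shards of roughly 433 MB each
```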
@@ -939,7 +1328,20 @@ def push_dataset_dict_to_hub( Returns: None """ - hf_dataset_dict.push_to_hub(repo_id, *args, **kwargs) + num_shards = _compute_num_shards(hf_dataset_dict) + num_proc = kwargs.get("num_proc", None) + if num_proc is not None: # pragma: no cover + min_num_shards = min(num_shards.values()) + if min_num_shards < num_proc: + logger.warning( + f"num_proc chaged from {num_proc} to 1 to safely adapt for num_shards={num_shards}" + ) + num_proc = 1 + del kwargs["num_proc"] + + hf_dataset_dict.push_to_hub( + repo_id, num_shards=num_shards, num_proc=num_proc, **kwargs + ) def push_infos_to_hub( @@ -1145,7 +1547,7 @@ def load_tree_struct_from_disk( def save_dataset_dict_to_disk( - path: Union[str, Path], hf_dataset_dict: datasets.DatasetDict, *args, **kwargs + path: Union[str, Path], hf_dataset_dict: datasets.DatasetDict, **kwargs ) -> None: """Save a Hugging Face DatasetDict to disk. @@ -1155,9 +1557,6 @@ def save_dataset_dict_to_disk( Args: path (Union[str, Path]): Directory path where the DatasetDict will be saved. hf_dataset_dict (datasets.DatasetDict): The Hugging Face DatasetDict to save. - *args: - Positional arguments forwarded to - [`DatasetDict.save_to_disk`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.save_to_disk). **kwargs: Keyword arguments forwarded to [`DatasetDict.save_to_disk`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.save_to_disk). @@ -1165,7 +1564,20 @@ def save_dataset_dict_to_disk( Returns: None """ - hf_dataset_dict.save_to_disk(str(path), *args, **kwargs) + num_shards = _compute_num_shards(hf_dataset_dict) + num_proc = kwargs.get("num_proc", None) + if num_proc is not None: # pragma: no cover + min_num_shards = min(num_shards.values()) + if min_num_shards < num_proc: + logger.warning( + f"num_proc chaged from {num_proc} to 1 to safely adapt for num_shards={num_shards}" + ) + num_proc = 1 + del kwargs["num_proc"] + + hf_dataset_dict.save_to_disk( + str(path), num_shards=num_shards, num_proc=num_proc, **kwargs + ) def save_infos_to_disk( @@ -1568,7 +1980,7 @@ def huggingface_description_to_problem_definition( try: func(description[key]) except KeyError: - logger.info(f"Could not retrieve key:'{key}' from description") + logger.error(f"Could not retrieve key:'{key}' from description") pass return problem_definition @@ -1596,195 +2008,141 @@ def huggingface_description_to_infos( return infos -@deprecated( - "will be removed (this hf format will not be not maintained)", - version="0.1.9", - removal="0.2.0", -) -def create_string_for_huggingface_dataset_card( - description: dict, - download_size_bytes: int, - dataset_size_bytes: int, - nb_samples: int, - owner: str, - license: str, - zenodo_url: Optional[str] = None, - arxiv_paper_url: Optional[str] = None, +def update_dataset_card( + dataset_card: str, + infos: dict[str, dict[str, str]] = None, pretty_name: Optional[str] = None, - size_categories: Optional[list[str]] = None, - task_categories: Optional[list[str]] = None, - tags: Optional[list[str]] = None, dataset_long_description: Optional[str] = None, - url_illustration: Optional[str] = None, + illustration_urls: Optional[list[str]] = None, + arxiv_paper_urls: Optional[list[str]] = None, ) -> str: - """Use this function for creating a dataset card, to upload together with the datase on the Hugging Face hub. - - Doing so ensure that load_dataset from the hub will populate the hf-dataset.description field, and be compatible for conversion to plaid. 
- - Without a dataset_card, the description field is lost. - - The parameters download_size_bytes and dataset_size_bytes can be determined after a - dataset has been uploaded on Hugging Face: - - manually by reading their values on the dataset page README.md, - - automatically as shown in the example below - - See `the hugginface examples `__ for a concrete use. + r"""Update a dataset card with PLAID-specific metadata and documentation. Args: - description (dict): Hugging Face dataset description. Obtained from - - description = hf_dataset.description - - description = generate_huggingface_description(infos, problem_definition) - download_size_bytes (int): the size of the dataset when downloaded from the hub - dataset_size_bytes (int): the size of the dataset when loaded in RAM - nb_samples (int): the number of samples in the dataset - owner (str): the owner of the dataset, usually a username or organization name on Hugging Face - license (str): the license of the dataset, e.g. "CC-BY-4.0", "CC0-1.0", etc. - zenodo_url (str, optional): the Zenodo URL of the dataset, if available - arxiv_paper_url (str, optional): the arxiv paper URL of the dataset, if available - pretty_name (str, optional): a human-readable name for the dataset, e.g. "PLAID Dataset" - size_categories (list[str], optional): size categories of the dataset, e.g. ["small", "medium", "large"] - task_categories (list[str], optional): task categories of the dataset, e.g. ["image-classification", "text-generation"] - tags (list[str], optional): tags for the dataset, e.g. ["3D", "simulation", "mesh"] - dataset_long_description (str, optional): a long description of the dataset, providing more details about its content and purpose - url_illustration (str, optional): a URL to an illustration image for the dataset, e.g. a screenshot or a sample mesh + dataset_card (str): The original dataset card content to update. + infos (dict[str, dict[str, str]]): Dictionary containing dataset information + with "legal" and "data_production" sections. Defaults to None. + pretty_name (str, optional): A human-readable name for the dataset. Defaults to None. + dataset_long_description (str, optional): Detailed description of the dataset's content, + purpose, and characteristics. Defaults to None. + illustration_urls (list[str], optional): List of URLs to images illustrating the dataset. + Defaults to None. + arxiv_paper_urls (list[str], optional): List of URLs to related arXiv papers. + Defaults to None. Returns: - dataset (Dataset): the converted dataset - problem_definition (ProblemDefinition): the problem definition generated from the Hugging Face dataset + str: The updated dataset card content as a string. Example: - .. code-block:: python - - hf_dataset.push_to_hub("chanel/dataset") - - from datasets import load_dataset_builder - - datasetInfo = load_dataset_builder("chanel/dataset").__getstate__()['info'] - - from huggingface_hub import DatasetCard + ```python + # Create initial dataset card + card = "---\ndataset_name: my_dataset\n---" + + # Update with PLAID-specific content + updated_card = update_dataset_card( + dataset_card=card, + license="mit", + pretty_name="My PLAID Dataset", + dataset_long_description="This dataset contains...", + illustration_urls=["https://example.com/image.png"], + arxiv_paper_urls=["https://arxiv.org/abs/..."] + ) - card_text = create_string_for_huggingface_dataset_card( - description = description, - download_size_bytes = datasetInfo.download_size, - dataset_size_bytes = datasetInfo.dataset_size, - ...) 
- dataset_card = DatasetCard(card_text) - dataset_card.push_to_hub("chanel/dataset") + # Push to Hugging Face Hub + from huggingface_hub import DatasetCard + dataset_card = DatasetCard(updated_card) + dataset_card.push_to_hub("username/dataset") + ``` """ - str__ = f"""--- -license: {license} -""" + lines = dataset_card.splitlines() + lines = [s for s in lines if not s.startswith("license")] - if size_categories: - str__ += f"""size_categories: - {size_categories} -""" - - if task_categories: - str__ += f"""task_categories: - {task_categories} -""" + indices = [i for i, line in enumerate(lines) if line.strip() == "---"] + assert len(indices) >= 2, ( + "Cannot find two instances of '---', you should try to update a correct dataset_card." + ) + lines = lines[: indices[1] + 1] + + count = 1 + lines.insert(count, f"license: {infos['legal']['license']}") + count += 1 + lines.insert(count, "task_categories:") + count += 1 + lines.insert(count, "- graph-ml") + count += 1 if pretty_name: - str__ += f"""pretty_name: {pretty_name} -""" - - if tags: - str__ += f"""tags: - {tags} -""" - - str__ += f"""configs: - - config_name: default - data_files: - - split: all_samples - path: data/all_samples-* -dataset_info: - description: {description} - features: - - name: sample - dtype: binary - splits: - - name: all_samples - num_bytes: {dataset_size_bytes} - num_examples: {nb_samples} - download_size: {download_size_bytes} - dataset_size: {dataset_size_bytes} ---- - -# Dataset Card -""" - if url_illustration: - str__ += f"""![image/png]({url_illustration}) - -This dataset contains a single Hugging Face split, named 'all_samples'. - -The samples contains a single Hugging Face feature, named called "sample". - -Samples are instances of [plaid.containers.sample.Sample](https://plaid-lib.readthedocs.io/en/latest/autoapi/plaid/containers/sample/index.html#plaid.containers.sample.Sample). -Mesh objects included in samples follow the [CGNS](https://cgns.github.io/) standard, and can be converted in -[Muscat.Containers.Mesh.Mesh](https://muscat.readthedocs.io/en/latest/_source/Muscat.Containers.Mesh.html#Muscat.Containers.Mesh.Mesh). - + lines.insert(count, f"pretty_name: {pretty_name}") + count += 1 + lines.insert(count, "tags:") + count += 1 + lines.insert(count, "- physics learning") + count += 1 + lines.insert(count, "- geometry learning") + count += 1 + + str__ = "\n".join(lines) + "\n" + + if illustration_urls: + str__ += "

\n" + for url in illustration_urls: + str__ += f"{url}\n" + str__ += "

\n\n" + + if infos: + str__ += ( + f"```yaml\n{yaml.dump(infos, sort_keys=False, allow_unicode=True)}\n```" + ) + str__ += """ Example of commands: ```python -import pickle from datasets import load_dataset -from plaid import Sample +from plaid.bridges import huggingface_bridge -# Load the dataset -dataset = load_dataset("chanel/dataset", split="all_samples") +repo_id = "chanel/dataset" +pb_def_name = "pb_def_name" #`pb_def_name` is to choose from the repo `problem_definitions` folder -# Get the first sample of the first split -split_names = list(dataset.description["split"].keys()) -ids_split_0 = dataset.description["split"][split_names[0]] -sample_0_split_0 = dataset[ids_split_0[0]]["sample"] -plaid_sample = Sample.model_validate(pickle.loads(sample_0_split_0)) -print("type(plaid_sample) =", type(plaid_sample)) - -print("plaid_sample =", plaid_sample) +# Load the dataset +hf_datasetdict = load_dataset(repo_id) -# Get a field from the sample -field_names = plaid_sample.get_field_names() -field = plaid_sample.get_field(field_names[0]) -print("field_names[0] =", field_names[0]) +# Load addition required data +flat_cst, key_mappings = huggingface_bridge.load_tree_struct_from_hub(repo_id) +pb_def = huggingface_bridge.load_problem_definition_from_hub(repo_id, pb_def_name) -print("field.shape =", field.shape) +# Efficient reconstruction of plaid samples +for split_name, hf_dataset in hf_datasetdict.items(): + for i in range(len(hf_dataset)): + sample = huggingface_bridge.to_plaid_sample( + hf_dataset, + i, + flat_cst[split_name], + key_mappings["cgns_types"], + ) -# Get the mesh and convert it to Muscat -from Muscat.Bridges import CGNSBridge -CGNS_tree = plaid_sample.get_mesh() -mesh = CGNSBridge.CGNSToMesh(CGNS_tree) -print(mesh) +# Extract input and output features from samples: +for t in sample.get_all_mesh_times(): + for path in pb_def.get_in_features_identifiers(): + sample.get_feature_by_path(path=path, time=t) + for path in pb_def.get_out_features_identifiers(): + sample.get_feature_by_path(path=path, time=t) ``` - -## Dataset Details - -### Dataset Description - """ + str__ += "This dataset was generated in [PLAID](https://plaid-lib.readthedocs.io/), we refer to this documentation for additional details on how to extract data from `sample` objects.\n" if dataset_long_description: - str__ += f"""{dataset_long_description} -""" - - str__ += f"""- **Language:** [PLAID](https://plaid-lib.readthedocs.io/) -- **License:** {license} -- **Owner:** {owner} + str__ += f""" +### Dataset Description +{dataset_long_description} """ - if zenodo_url or arxiv_paper_url: + if arxiv_paper_urls: str__ += """ ### Dataset Sources +- **Papers:** """ - - if zenodo_url: - str__ += f"""- **Repository:** [Zenodo]({zenodo_url}) -""" - - if arxiv_paper_url: - str__ += f"""- **Paper:** [arxiv]({arxiv_paper_url}) -""" + for url in arxiv_paper_urls: + str__ += f" - [arxiv]({url})\n" return str__ diff --git a/src/plaid/containers/dataset.py b/src/plaid/containers/dataset.py index 45f80f2a..c0dfc2bb 100644 --- a/src/plaid/containers/dataset.py +++ b/src/plaid/containers/dataset.py @@ -963,8 +963,8 @@ def set_infos(self, infos: dict[str, dict[str, str]]) -> None: f"{info_key=} not among authorized keys. 
Maybe you want to try among these keys {AUTHORIZED_INFO_KEYS[cat_key]}" ) - if len(self._infos) > 0: - logger.warning("infos not empty, replacing it anyway") + # if len(self._infos) > 0: + # logger.warning("infos not empty, replacing it anyway") self._infos = copy.deepcopy(infos) if "plaid" not in self._infos: diff --git a/src/plaid/containers/features.py b/src/plaid/containers/features.py index cadf5770..484497fb 100644 --- a/src/plaid/containers/features.py +++ b/src/plaid/containers/features.py @@ -914,6 +914,11 @@ def add_global( else: base_node = self.init_base(1, 1, "Global", time) + if isinstance(global_array, str): # pragma: no cover + global_array = np.frombuffer( + global_array.encode("ascii"), dtype="S1", count=len(global_array) + ) + if CGU.getValueByPath(base_node, name) is None: CGL.newDataArray(base_node, name, value=global_array) else: diff --git a/tests/bridges/test_huggingface_bridge.py b/tests/bridges/test_huggingface_bridge.py index 452aca6d..aab5b0dc 100644 --- a/tests/bridges/test_huggingface_bridge.py +++ b/tests/bridges/test_huggingface_bridge.py @@ -59,15 +59,29 @@ def generator_(): @pytest.fixture() -def generator_split(dataset, problem_definition) -> dict[str, Callable]: - generators_ = {} +def gen_kwargs(problem_definition) -> dict[str, dict]: + gen_kwargs = {} for split_name, ids in problem_definition.get_split().items(): + mid = len(ids) // 2 + gen_kwargs[split_name] = {"shards_ids": [ids[:mid], ids[mid:]]} + return gen_kwargs - def generator_(ids=ids): - for id in ids: - yield dataset[id] + +@pytest.fixture() +def generator_split(dataset, gen_kwargs) -> dict[str, Callable]: + generators_ = {} + + for split_name in gen_kwargs.keys(): + + def generator_(shards_ids): + for ids in shards_ids: + if isinstance(ids, int): + ids = [ids] + for id in ids: + yield dataset[id] generators_[split_name] = generator_ + return generators_ @@ -154,10 +168,10 @@ def test_with_datasetdict(self, dataset, problem_definition): dataset[0].get_mesh(), dataset[0].get_mesh() ) - def test_with_generator(self, generator_split): + def test_with_generator(self, generator_split, gen_kwargs): hf_dataset_dict, flat_cst, key_mappings = ( huggingface_bridge.plaid_generator_to_huggingface_datasetdict( - generator_split + generator_split, gen_kwargs ) ) huggingface_bridge.to_plaid_sample( @@ -176,11 +190,11 @@ def test_with_generator(self, generator_split): # ------------------------------------------------------------------------------ def test_save_load_to_disk( - self, current_directory, generator_split, infos, problem_definition + self, current_directory, generator_split, infos, problem_definition, gen_kwargs ): hf_dataset_dict, flat_cst, key_mappings = ( huggingface_bridge.plaid_generator_to_huggingface_datasetdict( - generator_split + generator_split, gen_kwargs ) ) @@ -316,20 +330,14 @@ def test_huggingface_description_to_infos(self, infos): huggingface_bridge.huggingface_description_to_infos(hf_description) # ---- Deprecated ---- - def test_create_string_for_huggingface_dataset_card(self, hf_dataset): - huggingface_bridge.create_string_for_huggingface_dataset_card( - description=hf_dataset.description, - download_size_bytes=10, - dataset_size_bytes=10, - nb_samples=10, - owner="Safran", - license="cc-by-sa-4.0", - zenodo_url="https://zenodo.org/records/10124594", - arxiv_paper_url="https://arxiv.org/pdf/2305.12871", + def test_create_string_for_huggingface_dataset_card(self, infos): + dataset_card = "---\ndataset_name: my_dataset\n---" + + huggingface_bridge.update_dataset_card( + 
dataset_card=dataset_card, + infos=infos, pretty_name="2D quasistatic non-linear structural mechanics solutions", - size_categories=["n<1K"], - task_categories=["graph-ml"], - tags=["physics learning", "geometry learning"], dataset_long_description="my long description", - url_illustration="url3", + illustration_urls=["url0", "url1"], + arxiv_paper_urls=["url2"], )
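
To tie the pieces above together, here is a minimal usage sketch of the sharded generator API introduced in this patch, mirroring `examples/bridges/huggingface_example.py` and the test fixtures; `dataset` is assumed to be an existing PLAID `Dataset` and the sample ids are illustrative:

```python
from plaid.bridges import huggingface_bridge

# One entry per split; each generator receives its split's "shards_ids"
# and yields the corresponding PLAID samples.
gen_kwargs = {
    "train": {"shards_ids": [[0, 1], [2, 3]]},  # two shards -> can be pre-processed in parallel
    "test": {"shards_ids": [[4]]},
}

def generator_(shards_ids):
    for ids in shards_ids:
        if isinstance(ids, int):
            ids = [ids]
        for id in ids:
            yield dataset[id]  # `dataset` is an existing plaid Dataset (assumed)

generators = {split_name: generator_ for split_name in gen_kwargs}

hf_datasetdict, flat_cst, key_mappings = (
    huggingface_bridge.plaid_generator_to_huggingface_datasetdict(
        generators, gen_kwargs, processes_number=2
    )
)
```

With `processes_number > 1`, each shard is pre-processed in its own worker during constant-leaf detection, so only hashes and metadata (never the arrays themselves) travel between processes; constant values are read back from the first sample once the shards are merged.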