# Source code for merlin_standard_lib.utils.misc_utils

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import glob
import inspect
import itertools
import logging
import os
import sys
import time
from typing import Any, Dict

logger = logging.getLogger(__name__)


def filter_kwargs(kwargs, thing_with_kwargs, filter_positional_or_keyword=True):
    """Select the entries of ``kwargs`` whose keys name parameters of a callable.

    Parameters
    ----------
    kwargs: dict
        Candidate keyword arguments.
    thing_with_kwargs: callable
        Function/class whose signature (via ``inspect.signature``) decides
        which keys are kept.
    filter_positional_or_keyword: bool
        If True, only ``POSITIONAL_OR_KEYWORD`` parameters are accepted
        (``*args``, keyword-only and ``**kwargs`` parameters are excluded);
        otherwise every named parameter is accepted.

    Returns
    -------
    dict
        New dict, ordered by the callable's parameter order.
    """
    params = inspect.signature(thing_with_kwargs).parameters.values()
    if filter_positional_or_keyword:
        accepted = [p.name for p in params if p.kind == p.POSITIONAL_OR_KEYWORD]
    else:
        accepted = [p.name for p in params]
    return {name: kwargs[name] for name in accepted if name in kwargs}
def safe_json(data):
    """Return True when ``data`` is built only from JSON-friendly Python values.

    Accepted leaves are ``None``, ``bool``, ``int``, ``float`` and ``str``.
    Tuples and lists are accepted when all of their items are accepted.
    Dicts are accepted when every key is a ``str`` and every value is accepted.
    Anything else yields False.
    """
    if data is None or isinstance(data, (bool, int, float, str)):
        return True
    if isinstance(data, (tuple, list)):
        return all(map(safe_json, data))
    if isinstance(data, dict):
        return all(isinstance(key, str) and safe_json(value) for key, value in data.items())
    return False
def get_filenames(data_paths, files_filter_pattern="*"):
    """Glob ``files_filter_pattern`` inside each directory of ``data_paths``.

    Returns a flat list of matches, grouped in the order of ``data_paths``.
    Within each directory the order is whatever ``glob.glob`` returns.
    """
    matches = []
    for directory in data_paths:
        matches.extend(glob.glob(os.path.join(directory, files_filter_pattern)))
    return matches
def get_label_feature_name(feature_map: Dict[str, Any]) -> str:
    """Return the name of the single feature flagged as the label.

    A feature counts as the label when its config contains a truthy
    ``"is_label"`` entry (e.g. ``{"item_id": {"is_label": True}}``).

    Raises
    ------
    ValueError
        If no feature, or more than one feature, is flagged as label.
    """
    candidates = [
        name
        for name, config in feature_map.items()
        if "is_label" in config and config["is_label"]
    ]
    if not candidates:
        raise ValueError("One feature have be configured as label (is_label = True)")
    if len(candidates) > 1:
        raise ValueError("Only one feature can be selected as label (is_label = True)")
    (label_name,) = candidates
    return label_name
def get_timestamp_feature_name(feature_map: Dict[str, Any]) -> str:
    """Return the name of the single feature whose ``dtype`` is ``"timestamp"``.

    Feature configs without a ``"dtype"`` entry are skipped rather than
    raising ``KeyError``. This mirrors how ``get_label_feature_name``
    tolerates configs without ``"is_label"``.

    Raises
    ------
    ValueError
        If no feature, or more than one feature, has ``dtype == "timestamp"``.
        ``ValueError`` subclasses ``Exception``, so handlers written for the
        previous bare ``Exception`` still catch it.
    """
    timestamp_features = [
        name
        for name, config in feature_map.items()
        if "dtype" in config and config["dtype"] == "timestamp"
    ]
    if not timestamp_features:
        raise ValueError('No feature have be configured as timestamp (dtype = "timestamp")')
    if len(timestamp_features) > 1:
        raise ValueError('Only one feature can be configured as timestamp (dtype = "timestamp")')
    return timestamp_features[0]
def get_parquet_files_names(data_args, time_indices, is_train, eval_on_test_set=False):
    """Build the parquet path of each time window for the requested split.

    Each path is ``<data_args.data_path>/<zero-padded index>/<split>.parquet``.
    The index is left-padded with zeros to
    ``data_args.time_window_folder_pad_digits`` digits. The split file is
    ``train.parquet`` when ``is_train``, otherwise ``test.parquet`` if
    ``eval_on_test_set`` else ``valid.parquet``.

    ``time_indices`` may be a single index or a list of indices. Anything
    that is not a ``list`` is treated as one index.
    """
    indices = time_indices if isinstance(time_indices, list) else [time_indices]
    if is_train:
        split_file = "train.parquet"
    elif eval_on_test_set:
        split_file = "test.parquet"
    else:
        split_file = "valid.parquet"
    return [
        os.path.join(
            data_args.data_path,
            str(idx).zfill(data_args.time_window_folder_pad_digits),
            split_file,
        )
        for idx in indices
    ]
class Timing:
    """A context manager that reports the execution time of the block it manages.

    Parameters
    ----------
    message: str
        Text emitted when the block is entered. Unless ``one_line`` is in
        effect, it is also repeated in front of the timing text on exit.
    file: file-like
        Stream used by the default ``print``-based reporter. Ignored when
        ``logger`` is given.
    logger: object, optional
        Object with an ``info(str)`` method. When given, both messages go
        through ``logger.info`` and ``one_line`` is forced to False.
    one_line: bool
        Only for the default reporter. Prints ``message`` without a trailing
        newline on entry so ``"Done in ..."`` is appended on the same line.

    ``__enter__`` returns the instance, so ``with Timing(...) as t:`` works.
    After the block exits, ``t.elapsed`` holds the duration in seconds,
    measured with the monotonic ``time.perf_counter``. ``t.start`` keeps the
    wall-clock entry time from ``time.time()``, as before.
    """

    def __init__(self, message, file=sys.stdout, logger=None, one_line=True):
        self.message = message
        self.file = file
        self.start = None
        self.elapsed = None
        self._t0 = None
        if logger is not None:
            self.default_logger = False
            self.one_line = False
            self.logger = logger
        else:
            self.default_logger = True
            self.one_line = one_line
            self.logger = None

    def _log(self, message, newline=True):
        # pylint: disable=broad-except
        if self.default_logger:
            print(message, end="\n" if newline else "", file=self.file)
            try:
                self.file.flush()
            except Exception:
                # Best effort: some file-like targets have no flush() or are closed.
                pass
        else:
            self.logger.info(message)

    def __enter__(self):
        self.start = time.time()
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump when the system clock is adjusted.
        self._t0 = time.perf_counter()
        self._log(self.message, not self.one_line)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.elapsed = time.perf_counter() - self._t0
        self._log(
            "{}Done in {:.3f}s".format("" if self.one_line else self.message, self.elapsed)
        )
def get_object_size(obj, seen=None):
    """Recursively estimate the size of ``obj`` in bytes.

    Sums ``sys.getsizeof``, which is shallow, over ``obj`` and what it
    references. The function recurses into:

    * dict values and keys;
    * an object's ``__dict__``;
    * the items of any other iterable except ``str``/``bytes``/``bytearray``.

    Every object is counted at most once, tracked by ``id`` in ``seen``.
    This handles shared and self-referential structures.

    Iterators (generators, ``iter(list)`` and similar) are not traversed,
    because iterating them would consume the caller's data. Only their own
    shallow size is counted. Attributes stored in ``__slots__`` are not
    followed.

    Parameters
    ----------
    obj: object
        Object to measure.
    seen: set, optional
        ids already counted; used internally by the recursion.
    """
    from collections.abc import Iterator

    if seen is None:
        seen = set()
    obj_id = id(obj)
    if obj_id in seen:
        return 0
    # Mark as seen *before* recursing so self-referential objects terminate.
    seen.add(obj_id)
    size = sys.getsizeof(obj)
    if isinstance(obj, dict):
        size += sum(get_object_size(v, seen) for v in obj.values())
        size += sum(get_object_size(k, seen) for k in obj.keys())
    elif hasattr(obj, "__dict__"):
        size += get_object_size(obj.__dict__, seen)
    elif (
        hasattr(obj, "__iter__")
        and not isinstance(obj, (str, bytes, bytearray))
        and not isinstance(obj, Iterator)
    ):
        size += sum(get_object_size(item, seen) for item in obj)
    return size
def validate_dataset(paths_or_dataset, batch_size, buffer_size, engine, reader_kwargs):
    """Util function to load NVTabular Dataset from disk

    Parameters
    ----------
    paths_or_dataset: Union[nvtabular.Dataset, str]
        Path to dataset to load of nvtabular Dataset, if Dataset, return the object.
        A ``str`` is treated as a glob pattern; any other value is converted
        with ``list(...)`` and used as the file list.
    batch_size: int
        batch size for Dataloader.
    buffer_size: float
        parameter, which refers to the fraction of batches to load at once.
        Values ``>= 1`` are forwarded to the reader as ``batch_size``. If
        ``buffer_size < batch_size``, the forwarded value is
        ``int(batch_size * buffer_size)``; otherwise ``buffer_size`` itself.
        Values ``< 1`` are forwarded as ``part_mem_fraction``.
    engine: str
        parameter to specify the file format,
        possible values are: ["parquet", "csv", "csv-no-header"].
    reader_kwargs: dict
        Additional arguments of the specified reader. The dict is copied,
        so the caller's dict is never modified.

    Returns
    -------
    nvtabular.io.Dataset

    Raises
    ------
    ValueError
        If NVTabular is not installed (chained from the ``ImportError``), or
        if no files were found or given.
    """
    try:
        from nvtabular.io import Dataset
    except ImportError as err:
        raise ValueError(
            "NVTabular is necessary for this function, please install: nvtabular."
        ) from err

    # TODO: put this in parent class and allow
    # torch dataset to leverage as well?

    # if a dataset was passed, just return it
    if isinstance(paths_or_dataset, Dataset):
        return paths_or_dataset

    # otherwise initialize a dataset from paths or glob pattern
    if isinstance(paths_or_dataset, str):
        files = glob.glob(paths_or_dataset)
        _is_empty_msg = "Couldn't find file pattern {} in directory {}".format(
            *os.path.split(paths_or_dataset)
        )
    else:
        # TODO: some checking around attribute error here?
        files = list(paths_or_dataset)
        _is_empty_msg = "paths_or_dataset list must contain at least one filename"

    if not files:
        raise ValueError(_is_empty_msg)

    # implement buffer size logic
    # TODO: IMPORTANT
    # should we divide everything by 3 to account
    # for extra copies laying around due to asynchronicity?
    # Copy so that adding batch_size / part_mem_fraction below does not
    # leak into the dict the caller passed in.
    reader_kwargs = dict(reader_kwargs or {})
    if buffer_size >= 1:
        if buffer_size < batch_size:
            reader_kwargs["batch_size"] = int(batch_size * buffer_size)
        else:
            reader_kwargs["batch_size"] = buffer_size
    else:
        reader_kwargs["part_mem_fraction"] = buffer_size
    return Dataset(files, engine=engine, **reader_kwargs)