# Source code for nvtabular.ops.reduce_dtype_size

# Copyright (c) 2021, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dask.dataframe as dd
import numpy as np

from merlin.core.dispatch import DataFrameType, annotate
from merlin.dag.ops.stat_operator import StatOperator
from merlin.schema import Schema
from nvtabular.ops.operator import ColumnSelector, Operator

# Candidate integer dtypes ordered from smallest to largest; the first one
# whose representable range covers a column's observed (min, max) is chosen.
_INT_DTYPES = [np.int8, np.int16, np.int32, np.int64]

class ReduceDtypeSize(StatOperator):
    """
    ReduceDtypeSize changes the dtypes of numeric columns. For integer columns
    this will choose a dtype such that the minimum and maximum values in the
    column will fit. For float columns this will cast to a float32.
    """

    def __init__(self, float_dtype=np.float32):
        # float_dtype: target dtype for float columns (default np.float32)
        super().__init__()
        self.float_dtype = float_dtype
        # column name -> (min, max) lazy dask aggregations, set by fit/fit_finalize
        self.ranges = {}
        # column name -> target dtype, populated by compute_output_schema
        self.dtypes = {}

    @annotate("reduce_dtype_size_fit", color="green", domain="nvt_python")
    def fit(self, col_selector: ColumnSelector, ddf: dd.DataFrame):
        # Collect the (min, max) range of every selected column; the executor
        # materializes these lazy aggregations before fit_finalize receives them.
        return {col: (ddf[col].min(), ddf[col].max()) for col in col_selector.names}

    def fit_finalize(self, dask_stats):
        self.ranges = dask_stats

    def clear(self):
        self.dtypes = {}
        self.ranges = {}

    @annotate("reduce_dtype_size_transform", color="darkgreen", domain="nvt_python")
    def transform(self, col_selector: ColumnSelector, df: DataFrameType) -> DataFrameType:
        # Cast each column to the reduced dtype chosen in compute_output_schema.
        for col_name, col_dtype in self.dtypes.items():
            np_dtype = col_dtype.to_numpy
            df[col_name] = df[col_name].astype(np_dtype)
        return df

    def compute_output_schema(self, input_schema, selector, prev_output_schema=None):
        # Before fit has run there are no ranges, so the schema is unchanged.
        if not self.ranges:
            return input_schema

        output_columns = []
        for column, (min_value, max_value) in self.ranges.items():
            column = input_schema[column]
            dtype = column.dtype
            if column.dtype.element_type.value == "int":
                # Pick the smallest integer dtype whose range covers the
                # observed min/max of the column.
                for possible_dtype in _INT_DTYPES:
                    dtype_range = np.iinfo(possible_dtype)
                    if min_value >= dtype_range.min and max_value <= dtype_range.max:
                        dtype = possible_dtype
                        break
            elif column.dtype.element_type.value == "float":
                dtype = self.float_dtype

            output_columns.append(column.with_dtype(dtype))

        # BUG FIX: the original built a *set* here
        # ({column.dtype for column in output_columns}), but transform()
        # iterates self.dtypes.items() — a set has no .items(), so transform
        # would raise AttributeError. Store a name -> dtype mapping instead.
        self.dtypes = {column.name: column.dtype for column in output_columns}
        return Schema(output_columns)

    transform.__doc__ = Operator.transform.__doc__
    compute_output_schema.__doc__ = Operator.compute_output_schema.__doc__
    # BUG FIX: fit previously received fit_finalize's docstring; give each
    # method its own base-class docstring.
    fit.__doc__ = StatOperator.fit.__doc__
    fit_finalize.__doc__ = StatOperator.fit_finalize.__doc__
    clear.__doc__ = StatOperator.clear.__doc__