Source code for merlin.models.tf.transforms.tensor
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Union
import tensorflow as tf
from keras.layers.preprocessing import preprocessing_utils as utils
from merlin.models.tf.core.combinators import TabularBlock
from merlin.models.tf.typing import TabularData
ONE_HOT = utils.ONE_HOT
MULTI_HOT = utils.MULTI_HOT
COUNT = utils.COUNT
[docs]@tf.keras.utils.register_keras_serializable(package="merlin.models")
class ExpandDims(TabularBlock):
"""
Expand dims of selected input tensors.
Example::
inputs = {
"cont_feat1": tf.random.uniform((NUM_ROWS,)),
"cont_feat2": tf.random.uniform((NUM_ROWS,)),
"multi_hot_categ_feat": tf.random.uniform(
(NUM_ROWS, 4), minval=1, maxval=100, dtype=tf.int32
),
}
expand_dims_op = tr.ExpandDims(expand_dims={"cont_feat2": 0, "multi_hot_categ_feat": 1})
expanded_inputs = expand_dims_op(inputs)
"""
[docs] def __init__(self, expand_dims: Union[int, Dict[str, int]] = -1, **kwargs):
"""Instantiates the `ExpandDims` transformation, which allows to expand dims
of the input tensors
Parameters
----------
expand_dims : Union[int, Dict[str, int]], optional, by default -1
Defines which dimensions should be expanded. If an `int` is provided, all input tensors
will have the same dimension expanded. If a `dict` is passed, only features matching
the dict keys will be expanded, in the dimension specified as the dict values.
"""
super().__init__(**kwargs)
self.inputs_expand_dims = expand_dims
[docs] def call(self, inputs: TabularData, **kwargs) -> TabularData:
outputs = {}
for k, v in inputs.items():
if isinstance(self.inputs_expand_dims, int):
outputs[k] = tf.expand_dims(v, self.inputs_expand_dims)
elif isinstance(self.inputs_expand_dims, dict) and k in self.inputs_expand_dims:
expand_dim = self.inputs_expand_dims[k]
outputs[k] = tf.expand_dims(v, expand_dim)
elif self.inputs_expand_dims:
outputs[k] = v
else:
raise ValueError("The expand_dims argument is not valid")
return outputs
[docs] def compute_output_shape(self, input_shape):
return input_shape
def to_dense(tensor, max_seq_length=None):
if isinstance(tensor, tf.RaggedTensor):
if max_seq_length:
shape = [None] * tensor.shape.rank
shape[1] = max_seq_length
return tensor.to_tensor(shape=shape)
result = tensor.to_tensor()
elif isinstance(tensor, tf.SparseTensor):
result = tf.sparse.to_dense(tensor)
elif isinstance(tensor, tf.Tensor):
result = tensor
else:
result = tf.convert_to_tensor(tensor)
return result
def to_sparse(tensor):
if isinstance(tensor, tf.RaggedTensor):
result = tensor.to_sparse()
elif isinstance(tensor, tf.Tensor):
result = tf.sparse.from_dense(tensor)
elif isinstance(tensor, tf.SparseTensor):
result = tensor
else:
raise ValueError(
"Only tf.RaggedTensor and tf.Tensor are acceptable "
f"for converting to tf.SparseTensor, but got a {type(tensor)}"
)
return result