transformers4rec.tf.utils package
Submodules
transformers4rec.tf.utils.repr_utils module
transformers4rec.tf.utils.schema_utils module
transformers4rec.tf.utils.testing_utils module
transformers4rec.tf.utils.testing_utils.assert_body_works_in_model(data, inputs, body, run_eagerly)
transformers4rec.tf.utils.tf_utils module
class transformers4rec.tf.utils.tf_utils.LossMixin
Bases: abc.ABC
Mixin for Keras layers that can calculate a loss.
compute_loss(inputs: Union[tensorflow.python.framework.ops.Tensor, Dict[str, tensorflow.python.framework.ops.Tensor]], targets: Union[tensorflow.python.framework.ops.Tensor, Dict[str, tensorflow.python.framework.ops.Tensor]], compute_metrics=True, training: bool = False, **kwargs) → tensorflow.python.framework.ops.Tensor
Compute the loss on a batch of data.
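Parameters
inputs (Union[tf.Tensor, Dict[str, tf.Tensor]]) – Inputs for the batch: a single tensor or a dictionary of named tensors.
targets (Union[tf.Tensor, Dict[str, tf.Tensor]]) – Targets for the batch: a single tensor or a dictionary of named tensors.
compute_metrics (bool, default=True) – Whether to also compute metrics on the batch.
training (bool, default=False) – Whether the loss is computed in training mode.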
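The mixin contract can be illustrated with a minimal, framework-free sketch. It is a hypothetical re-creation, not the library's code: `LossMixinSketch` and `MeanSquaredErrorHead` are illustrative names, plain Python lists stand in for `tf.Tensor`, and the sketch marks `compute_loss` as an `abc.abstractmethod` to enforce the contract, which the real mixin may or may not do.

```python
import abc
from typing import Dict, List, Union

# Assumption: plain lists stand in for tf.Tensor / Dict[str, tf.Tensor].
Batch = Union[List[float], Dict[str, List[float]]]


class LossMixinSketch(abc.ABC):
    """Hypothetical re-creation of the LossMixin contract."""

    @abc.abstractmethod
    def compute_loss(
        self,
        inputs: Batch,
        targets: Batch,
        compute_metrics: bool = True,
        training: bool = False,
        **kwargs,
    ) -> float:
        """Return a scalar loss for one batch."""


class MeanSquaredErrorHead(LossMixinSketch):
    """Hypothetical layer that fulfils the contract with a mean squared error."""

    def compute_loss(self, inputs, targets, compute_metrics=True, training=False, **kwargs):
        # Average of squared differences between predictions and targets;
        # compute_metrics and training are accepted but unused in this sketch.
        errors = [(p - t) ** 2 for p, t in zip(inputs, targets)]
        return sum(errors) / len(errors)


head = MeanSquaredErrorHead()
loss = head.compute_loss([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
print(loss)  # 1.3333333333333333
```

Instantiating `LossMixinSketch` directly raises `TypeError`, because an ABC with an unimplemented abstract method cannot be instantiated; only concrete layers that implement `compute_loss` can be created.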
class transformers4rec.tf.utils.tf_utils.MetricsMixin
Bases: abc.ABC
Mixin for Keras layers that can calculate metrics.
calculate_metrics(inputs: Union[tensorflow.python.framework.ops.Tensor, Dict[str, tensorflow.python.framework.ops.Tensor]], targets: Union[tensorflow.python.framework.ops.Tensor, Dict[str, tensorflow.python.framework.ops.Tensor]], mode: str = 'val', forward=True, **kwargs) → Dict[str, Union[Dict[str, tensorflow.python.framework.ops.Tensor], tensorflow.python.framework.ops.Tensor]]
Calculate metrics on a batch of data. Each metric is stateful, and calling this method updates its state.
The state of each metric can be retrieved by calling the metric_results method.
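Stateful metrics accumulate across batches: each call updates the running state, and the accumulated value is read back separately. The following sketch shows that pattern without TensorFlow. Everything in it is hypothetical: `MeanMetric`, `MetricsMixinSketch`, and `AccuracyHead` are illustrative names, lists stand in for tensors, and treating `metric_results` as an abstract method on the same mixin (taking no arguments) is an assumption.

```python
import abc
from typing import Dict, List


class MeanMetric:
    """Hypothetical stateful metric: a running mean over all values seen."""

    def __init__(self, name: str):
        self.name = name
        self.total = 0.0
        self.count = 0

    def update_state(self, values: List[float]) -> None:
        # Accumulate state instead of overwriting it.
        self.total += sum(values)
        self.count += len(values)

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0


class MetricsMixinSketch(abc.ABC):
    """Hypothetical re-creation of the MetricsMixin contract."""

    @abc.abstractmethod
    def calculate_metrics(self, inputs, targets, mode="val", forward=True, **kwargs) -> Dict[str, float]:
        """Update each metric's state with one batch."""

    @abc.abstractmethod
    def metric_results(self) -> Dict[str, float]:
        """Return the accumulated value of every metric (assumed signature)."""


class AccuracyHead(MetricsMixinSketch):
    def __init__(self):
        self.accuracy = MeanMetric("accuracy")

    def calculate_metrics(self, inputs, targets, mode="val", forward=True, **kwargs):
        # 1.0 per correct prediction, 0.0 otherwise; mode and forward are unused here.
        hits = [1.0 if p == t else 0.0 for p, t in zip(inputs, targets)]
        self.accuracy.update_state(hits)
        return {self.accuracy.name: self.accuracy.result()}

    def metric_results(self):
        return {self.accuracy.name: self.accuracy.result()}


head = AccuracyHead()
first = head.calculate_metrics([1, 0, 1, 1], [1, 1, 1, 1])  # 3 of 4 correct
head.calculate_metrics([0, 0], [0, 1])                      # 1 of 2 correct
print(head.metric_results())  # {'accuracy': 0.6666666666666666}
```

Because the state persists, the final result is 4 correct out of 6 predictions overall, not the accuracy of the last batch alone.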
transformers4rec.tf.utils.tf_utils.get_output_sizes_from_schema(schema, batch_size=0, max_sequence_length=None)
transformers4rec.tf.utils.tf_utils.get_tf_main_layer(hf_model)
Extract the serializable custom Keras layer TF*MainLayer from a Hugging Face (HF) Transformers model.
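In Hugging Face TF model classes the main layer is stored as an attribute of the model object (for example, `TFGPT2Model` keeps its `TFGPT2MainLayer` in the `transformer` attribute). The sketch below shows one way such a layer could be located. It is not this function's implementation: the stand-in classes are hypothetical, and matching on a class name ending in `MainLayer` is an assumed heuristic.

```python
from typing import Any


class FakeMainLayer:
    """Hypothetical stand-in for an HF TF*MainLayer such as TFGPT2MainLayer."""


class FakeHFModel:
    """Hypothetical stand-in for an HF TF model such as TFGPT2Model."""

    def __init__(self):
        self.config = {"n_layer": 2}
        self.transformer = FakeMainLayer()  # where TFGPT2Model keeps its main layer


def find_main_layer(hf_model: Any) -> Any:
    """Return the first instance attribute whose class name ends with 'MainLayer'."""
    for value in vars(hf_model).values():
        if type(value).__name__.endswith("MainLayer"):
            return value
    raise ValueError("no *MainLayer attribute found on the model")


model = FakeHFModel()
layer = find_main_layer(model)
print(type(layer).__name__)  # FakeMainLayer
```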
transformers4rec.tf.utils.tf_utils.maybe_serialize_keras_objects(self, config, maybe_serialize_keys)
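The explicit `self` parameter suggests this helper is meant to be called from a layer's `get_config`, serializing selected attributes into the config dictionary. The sketch below illustrates that pattern under stated assumptions; it is not the library's code. `maybe_serialize_objects`, `toy_serialize`, and the `Head`, `Loss`, and `Metric` classes are hypothetical, and the serializer is passed in explicitly (in Keras code one would presumably use something like `tf.keras.utils.serialize_keras_object`).

```python
from typing import Any, Callable, Dict, Iterable


def maybe_serialize_objects(
    obj: Any,
    config: Dict[str, Any],
    keys: Iterable[str],
    serialize: Callable[[Any], Any],
) -> Dict[str, Any]:
    """Hypothetical helper: serialize the listed attributes of obj into config."""
    for key in keys:
        value = getattr(obj, key, None)
        if value is None:
            continue  # attribute missing or unset: leave config untouched
        if isinstance(value, dict):
            config[key] = {k: serialize(v) for k, v in value.items()}
        elif isinstance(value, (list, tuple)):
            config[key] = [serialize(v) for v in value]
        else:
            config[key] = serialize(value)
    return config


class Loss: ...
class Metric: ...


class Head:
    def __init__(self):
        self.loss = Loss()
        self.metrics = [Metric(), Metric()]
        self.pre = None  # unset, so it is skipped


def toy_serialize(o: Any) -> Dict[str, str]:
    # Stand-in serializer that only records the class name.
    return {"class_name": type(o).__name__}


config = maybe_serialize_objects(Head(), {"name": "head"}, ["loss", "metrics", "pre"], toy_serialize)
print(config)
# {'name': 'head', 'loss': {'class_name': 'Loss'}, 'metrics': [{'class_name': 'Metric'}, {'class_name': 'Metric'}]}
```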