#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
from typing import Dict, Optional, Union

import tensorflow as tf

from ..typing import TabularData
class LossMixin(abc.ABC):
    """Mixin to use for Keras Layers that can calculate a loss."""

    def compute_loss(
        self,
        inputs: Union[tf.Tensor, TabularData],
        targets: Union[tf.Tensor, TabularData],
        compute_metrics: bool = True,
        training: bool = False,
        **kwargs,
    ) -> tf.Tensor:
        """Compute the loss on a batch of data.

        Subclasses must override this method; the base implementation only
        raises ``NotImplementedError``.

        Parameters
        ----------
        inputs: Union[tf.Tensor, TabularData]
            Inputs for this batch (presumably the layer's inputs/predictions —
            confirm in the implementing subclass).
        targets: Union[tf.Tensor, TabularData]
            Target values for this batch.
        compute_metrics: bool, default=True
            Whether metrics should also be computed while computing the loss;
            the exact behaviour is defined by the implementing subclass.
        training: bool, default=False
            Whether the call happens in training mode.
        **kwargs:
            Additional implementation-specific keyword arguments.

        Returns
        -------
        tf.Tensor
            The loss value.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError()
class MetricsMixin(abc.ABC):
    """Mixin to use for Keras Layers that can calculate metrics."""

    def calculate_metrics(
        self,
        inputs: Union[tf.Tensor, TabularData],
        targets: Union[tf.Tensor, TabularData],
        mode: str = "val",
        forward: bool = True,
        **kwargs,
    ) -> Dict[str, Union[Dict[str, tf.Tensor], tf.Tensor]]:
        """Calculate metrics on a batch of data, each metric is stateful and this updates the state.

        The state of each metric can be retrieved by calling the `metric_results` method.
        Subclasses must override this method; the base implementation only
        raises ``NotImplementedError``.

        Parameters
        ----------
        inputs: Union[tf.Tensor, TabularData]
            Inputs for this batch (presumably the layer's inputs/predictions —
            confirm in the implementing subclass).
        targets: Union[tf.Tensor, TabularData]
            Target values for this batch.
        mode: str, default="val"
            Mode/phase identifier the metrics are computed for; its meaning is
            defined by the implementing subclass.
        forward: bool, default=True
            Implementation-defined flag; presumably whether ``inputs`` should
            first be passed through the layer — confirm in subclasses.
        **kwargs:
            Additional implementation-specific keyword arguments.

        Returns
        -------
        Dict[str, Union[Dict[str, tf.Tensor], tf.Tensor]]
            Metric results, presumably keyed by metric name.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def metric_results(self, mode: Optional[str] = None) -> Dict[str, Union[float, tf.Tensor]]:
        """Return the current state of each metric.

        The state is typically updated each batch by calling the `calculate_metrics` method.
        Subclasses must override this method.

        Parameters
        ----------
        mode: Optional[str], default=None
            Mode/phase to report metrics for. ``None`` leaves the choice to the
            implementing subclass.

        Returns
        -------
        Dict[str, Union[float, tf.Tensor]]
            Current metric values, presumably keyed by metric name.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def reset_metrics(self):
        """Reset all metrics.

        Subclasses must override this method; the base implementation only
        raises ``NotImplementedError``.
        """
        raise NotImplementedError()
def get_output_sizes_from_schema(schema, batch_size=0, max_sequence_length=None):
    """Build a ``tf.TensorShape`` for every feature declared in ``schema``.

    The shape chosen for each feature depends on which protobuf field it has set:

    * ``value_count`` -> ``[batch_size, L]`` where ``L`` is ``max_sequence_length``
      when it is truthy, otherwise ``feature.value_count.max`` (so both ``None``
      and ``0`` fall back to the schema value);
    * ``shape`` -> ``[batch_size, *dims]`` with the sizes of ``feature.shape.dim``;
    * neither -> ``[batch_size, 1]``.

    Parameters
    ----------
    schema:
        Object exposing an iterable ``feature`` attribute whose items support
        ``HasField`` (presumably a TensorFlow-Metadata ``Schema`` protobuf —
        confirm at call sites).
    batch_size: int, default=0
        Leading dimension used for every shape.
    max_sequence_length: int, optional
        Overrides the second dimension of ``value_count`` features.

    Returns
    -------
    dict
        Mapping of feature name to ``tf.TensorShape``.
    """
    output_shapes = {}
    for feat in schema.feature:
        if feat.HasField("value_count"):
            seq_len = max_sequence_length or feat.value_count.max
            dims = [batch_size, seq_len]
        elif feat.HasField("shape"):
            dims = [batch_size, *(dim.size for dim in feat.shape.dim)]
        else:
            dims = [batch_size, 1]
        output_shapes[feat.name] = tf.TensorShape(dims)
    return output_shapes
def get_tf_main_layer(hf_model):
    """Extract the serializable custom Keras layer ``TF*MainLayer`` from a HF model.

    Scans the instance attributes of ``hf_model`` (``vars(hf_model)``) in
    insertion order and returns the first value that is a
    ``tf.keras.layers.Layer``.

    Parameters
    ----------
    hf_model:
        A Hugging Face TensorFlow model instance (presumably a
        ``TFPreTrainedModel`` — confirm at call sites).

    Returns
    -------
    tf.keras.layers.Layer
        The first Keras layer found among the model's attributes.

    Raises
    ------
    IndexError
        If no attribute of ``hf_model`` is a Keras layer. ``IndexError`` is
        kept for backward compatibility with the previous ``[...][0]`` lookup.
    """
    # Stop at the first match instead of materializing every layer attribute.
    main_layer = next(
        (
            value
            for value in vars(hf_model).values()
            if isinstance(value, tf.keras.layers.Layer)
        ),
        None,
    )
    if main_layer is None:
        raise IndexError(
            f"No tf.keras.layers.Layer attribute found on {type(hf_model).__name__}"
        )
    return main_layer
def maybe_serialize_keras_objects(
    self,
    config,
    maybe_serialize_keys,
):
    """Serialize selected Keras-object attributes of ``self`` into ``config``.

    Module-level helper that receives the owning object explicitly as ``self``.
    For every name in ``maybe_serialize_keys`` whose attribute on ``self``
    exists and is not ``None``:

    * a ``dict`` is stored as ``{name: serialize(obj), ...}``;
    * a ``list`` is stored as ``[serialize(obj), ...]``;
    * anything else is stored as ``serialize(value)``;

    where ``serialize`` is ``tf.keras.utils.serialize_keras_object``. Missing
    or ``None`` attributes leave ``config`` untouched.

    Parameters
    ----------
    self:
        Object whose attributes are read.
    config: dict
        Config mapping to update; it is mutated in place.
    maybe_serialize_keys: iterable of str
        Attribute names to look up on ``self``.

    Returns
    -------
    dict
        The same ``config`` object.
    """
    for attr_name in maybe_serialize_keys:
        attr_value = getattr(self, attr_name, None)
        if attr_value is None:
            continue
        serialize = tf.keras.utils.serialize_keras_object
        if isinstance(attr_value, dict):
            serialized = {name: serialize(obj) for name, obj in attr_value.items()}
        elif isinstance(attr_value, list):
            serialized = [serialize(obj) for obj in attr_value]
        else:
            serialized = serialize(attr_value)
        config[attr_name] = serialized
    return config
def maybe_deserialize_keras_objects(
    config, to_deserialize, deserialize_fn=tf.keras.utils.deserialize_keras_object
):
    """Deserialize, in place, the values stored under selected keys of ``config``.

    Parameters
    ----------
    config: dict
        Config mapping to update; it is mutated in place.
    to_deserialize: list or dict
        Either a list of keys, all deserialized with ``deserialize_fn``, or a
        mapping from key to the function used for that key.
    deserialize_fn: callable, default=tf.keras.utils.deserialize_keras_object
        Function used when ``to_deserialize`` is a list.

    Every function is called as ``fn(value, custom_objects=custom_objects)``.
    A single ``custom_objects`` dict is shared by all calls. Keys whose value is
    missing or falsy (``None``, empty list, ...) are skipped. A list value is
    deserialized element by element; any other value is passed as a whole.

    Returns
    -------
    dict
        The same ``config`` object.
    """
    if isinstance(to_deserialize, list):
        key_to_fn = dict.fromkeys(to_deserialize, deserialize_fn)
    else:
        key_to_fn = to_deserialize

    shared_custom_objects = {}
    for key, fn in key_to_fn.items():
        serialized = config.get(key, None)
        if not serialized:
            continue
        if isinstance(serialized, list):
            restored = [fn(item, custom_objects=shared_custom_objects) for item in serialized]
        else:
            restored = fn(serialized, custom_objects=shared_custom_objects)
        config[key] = restored
    return config
def create_output_placeholder(scores, ks):
    """Create a zero-filled ``tf.float32`` ``tf.Variable`` of shape ``[rows, len(ks)]``.

    ``rows`` is the dynamic size of the first axis of ``scores``
    (``tf.shape(scores)[0]``); there is one column per entry of ``ks``.
    """
    num_rows = tf.shape(scores)[0]
    num_cols = len(ks)
    zeros = tf.zeros([num_rows, num_cols], tf.float32)
    return tf.Variable(zeros)
def gather_torch_like(labels, indices, max_k):
    """Gather ``labels`` along axis 1 with per-row ``indices``.

    Equivalent to ``torch.gather(labels, 1, indices)``:
    ``out[i, j] = labels[i, indices[i, j]]``.

    Parameters
    ----------
    labels: tf.Tensor
        Tensor whose first axis is the batch axis; values are taken from its
        second axis (2-D ``(batch, n)`` in the original usage).
    indices: tf.Tensor
        Integer tensor of shape ``(batch, k)`` with positions into axis 1 of
        ``labels`` (int32 or int64).
    max_k: int
        Number of columns of ``indices``. The vectorized implementation does
        not need it; it is kept so existing callers keep working.

    Returns
    -------
    tf.Tensor
        Gathered values; for 2-D ``labels`` the result has the shape of ``indices``.
    """
    del max_k  # Unused: the column count is taken from ``indices`` itself.
    # ``batch_dims=1`` pairs row i of ``indices`` with row i of ``labels`` (and the
    # gather axis defaults to 1). A single op replaces the per-row TensorArray loop
    # plus ``gather_nd``, and no longer requires int32 indices.
    return tf.gather(labels, indices, batch_dims=1)