#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
from dataclasses import dataclass
from typing import Dict, Optional, Union
import torch
from merlin_standard_lib import Schema
from merlin_standard_lib.utils.proto_utils import has_field
from ...config.schema import SchemaMixin
from ..typing import TabularData


class OutputSizeMixin(SchemaMixin, abc.ABC):
    def build(self, input_size, schema=None, **kwargs):
        self.check_schema(schema=schema)
        self.input_size = input_size
        if schema and not getattr(self, "schema", None):
            self.schema = schema
        return self

    def output_size(self, input_size=None):
        input_size = input_size or getattr(self, "input_size", None)
        if not input_size:
            # TODO: log warning here
            return None
        return self.forward_output_size(input_size)

    def forward_output_size(self, input_size):
        raise NotImplementedError()
    def __rrshift__(self, other):
        from ..block.base import right_shift_block
        return right_shift_block(self, other) 
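
# A minimal sketch of how `OutputSizeMixin` is typically subclassed, using a
# hypothetical block that doubles the last input dimension (the class name and
# sizes here are illustrative only):
#
#     >>> class DoublingBlock(OutputSizeMixin, torch.nn.Module):
#     ...     def forward_output_size(self, input_size):
#     ...         return torch.Size([*input_size[:-1], input_size[-1] * 2])
#     >>> block = DoublingBlock().build(torch.Size([32, 64]))
#     >>> block.output_size()
#     torch.Size([32, 128])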


class LossMixin:
    """Mixin to use for a `torch.nn.Module` that can calculate a loss."""

    def compute_loss(
        self,
        inputs: Union[torch.Tensor, TabularData],
        targets: Union[torch.Tensor, TabularData],
        compute_metrics: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Compute the loss on a batch of data.
        Parameters
        ----------
        inputs: Union[torch.Tensor, TabularData]
            TODO
        targets: Union[torch.Tensor, TabularData]
            TODO
        compute_metrics: bool, default=True
            Boolean indicating whether or not to update the state of the metrics
            (if they are defined).
        """
        raise NotImplementedError()  
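
# A sketch of a concrete `LossMixin` implementation for a single-tensor task
# with a pointwise loss (the task and loss here are illustrative, not part of
# this module):
#
#     >>> class BinaryTask(LossMixin, torch.nn.Module):
#     ...     def __init__(self):
#     ...         super().__init__()
#     ...         self.loss_fn = torch.nn.BCEWithLogitsLoss()
#     ...     def compute_loss(self, inputs, targets, compute_metrics=True, **kwargs):
#     ...         return self.loss_fn(inputs, targets)
#     >>> task = BinaryTask()
#     >>> loss = task.compute_loss(torch.randn(8), torch.randint(2, (8,)).float())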


class MetricsMixin:
    """Mixin to use for a `torch.nn.Module` that can calculate metrics."""

    def calculate_metrics(
        self,
        inputs: Union[torch.Tensor, TabularData],
        targets: Union[torch.Tensor, TabularData],
        mode: str = "val",
        forward: bool = True,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Calculate metrics on a batch of data, each metric is stateful and this updates the state.
        The state of each metric can be retrieved by calling the `compute_metrics` method.
        Parameters
        ----------
        inputs: Union[torch.Tensor, TabularData]
            TODO
        targets: Union[torch.Tensor, TabularData]
            TODO
        forward: bool, default True
        mode: str, default="val"
        """
        raise NotImplementedError()

    def compute_metrics(
        self, mode: Optional[str] = None
    ) -> Dict[str, Union[float, torch.Tensor]]:
        """Returns the current state of each metric.
        The state is typically updated each batch by calling the `calculate_metrics` method.
        Parameters
        ----------
        mode: str, default="val"
        Returns
        -------
        Dict[str, Union[float, torch.Tensor]]
        """
        raise NotImplementedError()

    def reset_metrics(self):
        """Reset all metrics."""
        raise NotImplementedError()  
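
# The intended per-epoch call cycle for `MetricsMixin`, assuming `task` is a
# prediction task implementing this mixin:
#
#     >>> task.calculate_metrics(predictions, targets, mode="val")  # once per batch
#     >>> task.compute_metrics(mode="val")  # aggregated state, e.g. at epoch end
#     >>> task.reset_metrics()  # clear state before the next evaluation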


def requires_schema(module):
    """Mark a module (or class) as requiring a schema when it is built."""
    module.REQUIRES_SCHEMA = True
    return module 
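
# `requires_schema` is used as a decorator; a sketch with an illustrative class:
#
#     >>> @requires_schema
#     ... class SchemaAwareBlock(OutputSizeMixin, torch.nn.Module):
#     ...     pass
#     >>> SchemaAwareBlock.REQUIRES_SCHEMA
#     True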


def check_gpu(module):
    """Return True if `module` has at least one parameter on a CUDA device."""
    try:
        return next(module.parameters()).is_cuda
    except StopIteration:
        return False 
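
# `check_gpu` returns False both for CPU-resident modules and for modules with
# no parameters at all:
#
#     >>> check_gpu(torch.nn.Linear(4, 2))
#     False
#     >>> check_gpu(torch.nn.ReLU())  # no parameters, so StopIteration is caught
#     False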


def get_output_sizes_from_schema(schema: Schema, batch_size=-1, max_sequence_length=None):
    """Infer the output size of each feature in the schema, as a `torch.Size`.

    Sequential (multi-hot) features map to `(batch_size, sequence_length)`,
    features with an explicit shape keep their dimensions, and scalar features
    map to `(batch_size,)`.
    """
    sizes = {}
    for feature in schema.feature:
        name = feature.name
        # Sequential or multi-hot feature
        if has_field(feature, "value_count"):
            sizes[name] = torch.Size(
                [
                    batch_size,
                    max_sequence_length if max_sequence_length else feature.value_count.max,
                ]
            )
        elif has_field(feature, "shape"):
            sizes[name] = torch.Size([batch_size] + [d.size for d in feature.shape.dim])
        else:
            sizes[name] = torch.Size([batch_size])
    return sizes 
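
# Illustrative output for a schema with one sequential feature capped at 20
# steps and one scalar feature (feature names are hypothetical, and schema
# construction is omitted):
#
#     >>> get_output_sizes_from_schema(schema, batch_size=128)
#     {'item_id_seq': torch.Size([128, 20]), 'user_age': torch.Size([128])}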


def create_output_placeholder(scores, ks):
    """Create a float32 placeholder of shape `(batch_size, len(ks))` on the device of `scores`."""
    return torch.zeros(scores.shape[0], len(ks)).to(device=scores.device, dtype=torch.float32)


def one_hot_1d(
    labels: torch.Tensor,
    num_classes: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = torch.float32,
) -> torch.Tensor:
    r"""Coverts a 1d label tensor to one-hot representation
    Args:
        labels (torch.Tensor) : tensor with labels of shape :math:`(N, H, W)`,
                                where N is batch size. Each value is an integer
                                representing correct classification.
        num_classes (int): number of classes in labels.
        device (Optional[torch.device]): the desired device of returned tensor.
         Default: if None, uses the current device for the default tensor type
         (see torch.set_default_tensor_type()). device will be the CPU for CPU
         tensor types and the current CUDA device for CUDA tensor types.
        dtype (Optional[torch.dtype]): the desired data type of returned
         tensor. Default: torch.float32
    Returns:
        torch.Tensor: the labels in one hot tensor.
    Examples::
        >>> labels = torch.LongTensor([0, 1, 2, 0])
        >>> one_hot_1d(labels, num_classes=3)
        tensor([[1., 0., 0.],
                [0., 1., 0.],
                [0., 0., 1.],
                [1., 0., 0.],
               ])
    """
    if not torch.is_tensor(labels):
        raise TypeError("Input labels type is not a torch.Tensor. Got {}".format(type(labels)))
    if not len(labels.shape) == 1:
        raise ValueError("Expected a 1-d tensor. Got shape: {}".format(labels.shape))
    if not labels.dtype == torch.int64:
        raise ValueError("labels must be of dtype torch.int64. Got: {}".format(labels.dtype))
    if num_classes < 1:
        raise ValueError("The number of classes must be at least one. Got: {}".format(num_classes))
    if device is None:
        device = labels.device
    labels_size = labels.shape[0]
    one_hot = torch.zeros(labels_size, num_classes, device=device, dtype=dtype)
    return one_hot.scatter_(1, labels.unsqueeze(-1), 1.0)


class LambdaModule(torch.nn.Module):
    """Wraps a lambda function in a `torch.nn.Module` so it can be composed with other modules."""

    def __init__(self, lambda_fn):
        super().__init__()
        import types

        assert isinstance(lambda_fn, types.LambdaType), "lambda_fn must be a lambda function"
        self.lambda_fn = lambda_fn

    def forward(self, x):
        return self.lambda_fn(x)
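
# A minimal usage sketch for `LambdaModule`:
#
#     >>> double = LambdaModule(lambda x: x * 2)
#     >>> double(torch.tensor([1.0, 2.0]))
#     tensor([2., 4.])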