#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
from merlin_standard_lib import Schema
class SchemaMixin:
    """Mixin that attaches an optional ``Schema`` to an object and can require it.

    The schema is stored on the private ``_schema`` attribute and exposed via
    the ``schema`` property. Classes that set ``REQUIRES_SCHEMA = True`` (for
    example through the ``requires_schema`` decorator) make ``check_schema``
    raise ``ValueError`` when no schema is available.
    """

    # When True, ``check_schema`` raises if neither an explicit schema argument
    # nor an already-attached schema is present.
    REQUIRES_SCHEMA = False

    def set_schema(self, schema=None):
        """Attach ``schema`` to this object unless one is already attached.

        Args:
            schema: The schema to attach. A falsy value leaves the current
                schema untouched.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If ``REQUIRES_SCHEMA`` is True and neither ``schema``
                nor an already-attached schema is available.
        """
        self.check_schema(schema=schema)
        # An already-attached schema is never overwritten by this method.
        if schema and not getattr(self, "schema", None):
            self._schema = schema

        return self

    @property
    def schema(self) -> Optional[Schema]:
        """The attached schema, or ``None`` if none has been set."""
        return getattr(self, "_schema", None)

    @schema.setter
    def schema(self, value):
        # Truthy values go through ``set_schema``, which validates but does not
        # replace an existing schema. Falsy values are stored directly, which
        # allows clearing the schema.
        if value:
            self.set_schema(value)
        else:
            self._schema = value

    def check_schema(self, schema=None):
        """Raise if a schema is required but none is available.

        Args:
            schema: Optional candidate schema. When truthy, it satisfies the
                requirement even if no schema is attached yet.

        Raises:
            ValueError: If ``REQUIRES_SCHEMA`` is True and neither ``schema``
                nor the attached schema is truthy.
        """
        if self.REQUIRES_SCHEMA and not getattr(self, "schema", None) and not schema:
            raise ValueError(f"{self.__class__.__name__} requires a schema.")

    def __call__(self, *args, **kwargs):
        # Validate before delegating to the next class in the MRO. That class
        # is presumably a callable base such as a layer/module (TODO confirm).
        self.check_schema()

        return super().__call__(*args, **kwargs)

    def _maybe_set_schema(self, input, schema):
        """Propagate ``schema`` to ``input`` if it exposes ``set_schema``.

        Args:
            input: Object that may support ``set_schema``. Falsy values and
                objects without ``set_schema`` are skipped.
            schema: Schema to forward.
        """
        # The ``None`` default makes objects lacking ``set_schema`` a no-op
        # instead of raising AttributeError.
        if input and getattr(input, "set_schema", None):
            input.set_schema(schema)

    def get_padding_mask_from_item_id(self, inputs, pad_token=0):
        """Build a mask that is true where the item ids differ from ``pad_token``.

        ``get_item_ids_from_inputs`` is not defined in this mixin. It must be
        provided by the concrete class.

        Args:
            inputs: Raw inputs passed to ``get_item_ids_from_inputs``.
            pad_token: Item id value treated as padding. Defaults to 0.

        Returns:
            The element-wise ``item_ids != pad_token`` result. This is assumed
            to be a boolean tensor of the same shape as the item ids (TODO confirm).

        Raises:
            ValueError: If the extracted item ids do not have exactly 2 dims.
        """
        item_id_inputs = self.get_item_ids_from_inputs(inputs)
        num_dims = len(item_id_inputs.shape)
        if num_dims != 2:
            raise ValueError(
                "To extract the padding mask from item id tensor "
                "it is expected to have 2 dims, but it has {} dims.".format(num_dims)
            )

        # Reuse the ids already extracted above instead of recomputing them.
        return item_id_inputs != pad_token
def requires_schema(module):
    """Class decorator that marks ``module`` as requiring a schema.

    Sets ``REQUIRES_SCHEMA = True`` on the decorated object so that
    ``SchemaMixin.check_schema`` raises when no schema is available.

    Args:
        module: The class (or object) to mark.

    Returns:
        The same ``module`` object, so it can be used with ``@`` syntax.
    """
    module.REQUIRES_SCHEMA = True

    return module