Source code for merlin_standard_lib.schema.schema

# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
import json
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

from google.protobuf import json_format, text_format
from google.protobuf.message import Message as ProtoMessage
from merlin.models.utils import schema_utils as mm_schema_utils
from merlin.schema import Schema as CoreSchema
from merlin.schema import Tags, TagSet, TagsType
from merlin.schema.io import proto_utils

try:
    from functools import cached_property  # type: ignore
except ImportError:
    # polyfill cached_property for python <= 3.7 (using lru_cache which was introduced in python3.2)
    from functools import lru_cache

    cached_property = lambda func: property(lru_cache()(func))  # type: ignore  # noqa

import betterproto  # noqa
from betterproto import Message as BetterProtoMessage

from ..proto.schema_bp import *  # noqa
from ..proto.schema_bp import (
    Annotation,
    Feature,
    FeatureType,
    FixedShape,
    FixedShapeDim,
    FloatDomain,
    IntDomain,
    ValueCount,
    ValueCountList,
    _Schema,
)

# Type variable constrained to betterproto messages; it lets helpers in this module
# declare that they return the same message type they were given.
ProtoMessageType = TypeVar("ProtoMessageType", bound=BetterProtoMessage)


def _parse_shape_and_value_count(shape, value_count) -> Dict[str, Any]:
    output: Dict[str, Union[ValueCount, ValueCountList, FixedShape]] = {}
    if shape:
        output["shape"] = FixedShape([FixedShapeDim(d) for d in shape])

    if value_count:
        if isinstance(value_count, ValueCount):
            output["value_count"] = value_count
        elif isinstance(value_count, ValueCountList):
            output["value_counts"] = value_count
        else:
            raise ValueError("Unknown value_count type.")

    return output


class ColumnSchema(Feature):
    """A single column of a schema.

    Extends the generated ``Feature`` message with constructors for categorical and
    continuous columns and with helpers that return modified copies (name, tags,
    properties).

    Tags are always merged with an order-preserving de-duplication: existing tags keep
    their position and new tags are appended in the order they were given. This keeps
    ``to_dict()`` (and therefore ``__eq__`` and serialized output) independent of string
    hash randomization.
    """

    @classmethod
    def create_categorical(
        cls,
        name: str,
        num_items: int,
        shape: Optional[Union[Tuple[int, ...], List[int]]] = None,
        value_count: Optional[Union[ValueCount, ValueCountList]] = None,
        min_index: int = 0,
        tags: Optional[TagsType] = None,
        **kwargs,
    ) -> "ColumnSchema":
        """Create an integer column with a categorical ``IntDomain``.

        Parameters
        ----------
        name:
            Column name; also used as the name of the int domain.
        num_items:
            Stored as the ``max`` of the int domain.
        shape, value_count:
            Optional shape information, see ``_parse_shape_and_value_count``.
        min_index:
            Stored as the ``min`` of the int domain.
        tags:
            Extra tags; ``Tags.CATEGORICAL`` is always added.
        **kwargs:
            Forwarded to the class constructor.
        """
        _tags: List[str] = [t.value for t in TagSet(tags or [])]

        extra = _parse_shape_and_value_count(shape, value_count)
        int_domain = IntDomain(name=name, min=min_index, max=num_items, is_categorical=True)
        # Order-preserving de-duplication (a plain ``set`` would scramble the order).
        _tags = list(dict.fromkeys(_tags + [Tags.CATEGORICAL.value]))
        extra["type"] = FeatureType.INT

        return cls(name=name, int_domain=int_domain, **extra, **kwargs).with_tags(_tags)

    @classmethod
    def create_continuous(
        cls,
        name: str,
        is_float: bool = True,
        min_value: Optional[Union[int, float]] = None,
        max_value: Optional[Union[int, float]] = None,
        disallow_nan: bool = False,
        disallow_inf: bool = False,
        is_embedding: bool = False,
        shape: Optional[Union[Tuple[int, ...], List[int]]] = None,
        value_count: Optional[Union[ValueCount, ValueCountList]] = None,
        tags: Optional[TagsType] = None,
        **kwargs,
    ) -> "ColumnSchema":
        """Create a continuous column of type FLOAT (default) or INT.

        A domain is only attached when *both* ``min_value`` and ``max_value`` are given:
        a ``FloatDomain`` when ``is_float`` is true, otherwise a non-categorical
        ``IntDomain``. ``disallow_nan``, ``disallow_inf`` and ``is_embedding`` are only
        used for the float domain. ``Tags.CONTINUOUS`` is always added to ``tags``.
        Remaining ``**kwargs`` are forwarded to the class constructor.
        """
        _tags: List[str] = [t.value for t in TagSet(tags or [])]

        extra = _parse_shape_and_value_count(shape, value_count)
        if min_value is not None and max_value is not None:
            if is_float:
                extra["float_domain"] = FloatDomain(
                    name=name,
                    min=float(min_value),
                    max=float(max_value),
                    disallow_nan=disallow_nan,
                    disallow_inf=disallow_inf,
                    is_embedding=is_embedding,
                )
            else:
                extra["int_domain"] = IntDomain(
                    name=name, min=int(min_value), max=int(max_value), is_categorical=False
                )
        extra["type"] = FeatureType.FLOAT if is_float else FeatureType.INT
        # Order-preserving de-duplication (a plain ``set`` would scramble the order).
        _tags = list(dict.fromkeys(_tags + [Tags.CONTINUOUS.value]))

        return cls(name=name, **extra, **kwargs).with_tags(_tags)

    def copy(self, **kwargs) -> "ColumnSchema":
        """Return a copy produced by ``proto_utils.copy_better_proto_message``.

        ``kwargs`` are forwarded to that helper (presumably as field overrides; see
        ``with_name`` for the usage in this class).
        """
        return proto_utils.copy_better_proto_message(self, **kwargs)

    def with_name(self, name: str) -> "ColumnSchema":
        """Return a copy of this column with ``name`` passed to ``copy``."""
        return self.copy(name=name)

    def with_tags(self, tags: TagsType) -> "ColumnSchema":
        """Return a copy of this column with ``tags`` added.

        Each tag is converted with ``str``. When the column already has an annotation,
        the new tags are appended to the existing ones and duplicates are dropped while
        keeping first-seen order; otherwise a new ``Annotation`` holding ``tags`` is set.
        """
        tags = [str(t) for t in tags]
        output = self.copy()
        if self.annotation:
            # dict.fromkeys keeps insertion order, unlike set().
            output.annotation.tag = list(dict.fromkeys(list(self.annotation.tag) + tags))
        else:
            output.annotation = Annotation(tag=tags)

        return output

    def with_tags_based_on_properties(
        self, using_value_count=True, using_domain=True
    ) -> "ColumnSchema":
        """Return a copy tagged according to which fields are set on this column.

        - ``value_count`` set (and ``using_value_count``) adds ``Tags.LIST``.
        - ``int_domain`` set (and ``using_domain``) adds ``Tags.CATEGORICAL`` when the
          domain is categorical, otherwise ``Tags.CONTINUOUS``.
        - ``float_domain`` set (and ``using_domain``) adds ``Tags.CONTINUOUS``.

        When nothing applies, a plain copy is returned.
        """
        extra_tags = []

        if using_value_count and proto_utils.has_field(self, "value_count"):
            extra_tags.append(str(Tags.LIST))

        if using_domain and proto_utils.has_field(self, "int_domain"):
            if self.int_domain.is_categorical:
                extra_tags.append(str(Tags.CATEGORICAL))
            else:
                extra_tags.append(str(Tags.CONTINUOUS))

        if using_domain and proto_utils.has_field(self, "float_domain"):
            extra_tags.append(str(Tags.CONTINUOUS))

        return self.with_tags(extra_tags) if extra_tags else self.copy()

    def with_properties(self, properties: Dict[str, Union[str, int, float]]) -> "ColumnSchema":
        """Return a copy whose first ``extra_metadata`` entry includes ``properties``.

        If the copy already has an extra-metadata entry it is updated in place on the
        copy; otherwise ``properties`` becomes the only entry.
        """
        output = self.copy()
        if output.annotation:
            if len(output.annotation.extra_metadata) > 0:
                output.annotation.extra_metadata[0].update(properties)
            else:
                output.annotation.extra_metadata = [properties]
        else:
            output.annotation = Annotation(extra_metadata=[properties])

        return output

    def to_proto_text(self) -> str:
        """Serialize this column to proto text via a ``schema_pb2.Feature`` message."""
        # Imported lazily so tensorflow_metadata is only needed when this is called.
        from tensorflow_metadata.proto.v0 import schema_pb2

        return proto_utils.better_proto_to_proto_text(self, schema_pb2.Feature())

    @property
    def tags(self):
        """Tags stored in this column's annotation."""
        return self.annotation.tag

    @property
    def properties(self) -> Dict[str, Union[str, float, int]]:
        """First ``extra_metadata`` entry of the annotation, or ``{}`` when there is none."""
        if self.annotation.extra_metadata:
            properties: Dict[str, Union[str, float, int]] = self.annotation.extra_metadata[0]

            return properties

        return {}

    def _set_tags(self, tags: List[str]):
        """Add ``tags`` to this column in place (same merge rule as ``with_tags``)."""
        if self.annotation:
            # dict.fromkeys keeps insertion order, unlike set().
            self.annotation.tag = list(dict.fromkeys(list(self.annotation.tag) + tags))
        else:
            self.annotation = Annotation(tag=tags)

    def __str__(self) -> str:
        return self.name

    def __eq__(self, other) -> bool:
        """Two column schemas are equal when their ``to_dict()`` representations match."""
        if not isinstance(other, ColumnSchema):
            return NotImplemented

        return self.to_dict() == other.to_dict()
# A column can be given either as a full ``ColumnSchema`` or as a plain string.
ColumnSchemaOrStr = Union[ColumnSchema, str]
# Type of the per-column value (e.g. name, tags, type) that is handed to a filter function.
FilterT = TypeVar("FilterT")
class Schema(_Schema):
    """A collection of column schemas for a dataset."""

    feature: List["ColumnSchema"] = betterproto.message_field(1)

    @classmethod
    def create(
        cls,
        column_schemas: Optional[
            Union[List[ColumnSchemaOrStr], Dict[str, ColumnSchemaOrStr]]
        ] = None,
        **kwargs,
    ):
        """Create a schema from column schemas and/or column names.

        Parameters
        ----------
        column_schemas:
            A list of ``ColumnSchema`` objects or strings, or a dict whose *values* are
            such items (the keys are ignored). Strings are turned into
            ``ColumnSchema(name)``. ``None`` or an empty collection gives an empty schema.
        **kwargs:
            Forwarded to the class constructor.

        Raises
        ------
        TypeError
            If ``column_schemas`` is neither a list nor a dict.
        """
        column_schemas = column_schemas or []

        if isinstance(column_schemas, dict):
            column_schemas = list(column_schemas.values())

        if not isinstance(column_schemas, list):
            raise TypeError("The `column_schemas` parameter must be a list or dict.")

        features: List[ColumnSchema] = [
            ColumnSchema(column_schema) if isinstance(column_schema, str) else column_schema
            for column_schema in column_schemas
        ]

        return cls(feature=features, **kwargs)

    def with_tags_based_on_properties(self, using_value_count=True, using_domain=True) -> "Schema":
        """Return a new schema where every column got its property-based tags.

        Delegates to ``ColumnSchema.with_tags_based_on_properties`` for each column.
        """
        column_schemas = [
            column.with_tags_based_on_properties(
                using_value_count=using_value_count, using_domain=using_domain
            )
            for column in self.column_schemas
        ]

        return Schema(column_schemas)

    def apply(self, selector) -> "Schema":
        """Select the columns named in ``selector.names``.

        Returns ``self`` unchanged when ``selector`` or ``selector.names`` is falsy.
        """
        if selector and selector.names:
            return self.select_by_name(selector.names)
        else:
            return self

    def apply_inverse(self, selector) -> "Schema":
        """Return the schema without the columns named in ``selector.names``.

        Returns ``self`` unchanged when ``selector`` is falsy.
        """
        if selector:
            output_schema: Schema = self - self.select_by_name(selector.names)

            return output_schema
        else:
            return self

    def filter_columns_from_dict(self, input_dict):
        """Return the entries of ``input_dict`` whose key is a column name of this schema."""
        # Build the lookup once instead of re-creating the name list for every key.
        known_names = set(self.column_names)

        return {key: val for key, val in input_dict.items() if key in known_names}

    def select_by_type(self, to_select) -> "Schema":
        """Keep the columns whose ``type`` matches.

        ``to_select`` is a single type, a list/tuple of types, or a callable that receives
        a column's type and returns whether to keep the column.
        """
        if not isinstance(to_select, (list, tuple)) and not callable(to_select):
            to_select = [to_select]

        def collection_filter_fn(column_type):
            return column_type in to_select

        output: Schema = self._filter_column_schemas(
            to_select, collection_filter_fn, lambda x: x.type
        )

        return output

    def remove_by_type(self, to_remove) -> "Schema":
        """Drop the columns whose ``type`` matches.

        ``to_remove`` is a single type, a list/tuple of types, or a callable that receives
        a column's type and returns whether to drop the column.
        """
        if not isinstance(to_remove, (list, tuple)) and not callable(to_remove):
            to_remove = [to_remove]

        def collection_filter_fn(column_type):
            return column_type in to_remove

        output: Schema = self._filter_column_schemas(
            to_remove, collection_filter_fn, lambda x: x.type, negate=True
        )

        return output

    def select_by_tag(self, to_select) -> "Schema":
        """Keep the columns that carry the requested tags.

        ``to_select`` is a single tag, a list/tuple of tags (a column must carry *all* of
        them), or a callable that receives a column's tag list and returns whether to
        keep the column. Tags may be given as strings or ``Tags`` members.
        """
        if not isinstance(to_select, (list, tuple)) and not callable(to_select):
            to_select = [to_select]

        if callable(to_select):
            # The collection filter is never used for a callable, so pass a dummy.
            return self._filter_column_schemas(to_select, lambda x: False, lambda x: x.tags)
        else:
            # Schema.tags always returns a List[str] with the tag values, so if the user wants to
            # filter using the Tags Enum, we need to convert those to their string value
            to_select = [tag.value if isinstance(tag, Tags) else tag for tag in to_select]

            def collection_filter_fn(column_tags: List[str]):
                return all(x in column_tags for x in to_select)

            return self._filter_column_schemas(to_select, collection_filter_fn, lambda x: x.tags)

    def remove_by_tag(self, to_remove) -> "Schema":
        """Drop the columns that carry the given tags.

        ``to_remove`` is a single tag, a list/tuple of tags (a column is dropped only if it
        carries *all* of them), or a callable that receives a column's tag list and
        returns whether to drop the column. Tags may be given as strings or ``Tags``
        members.
        """
        if not isinstance(to_remove, (list, tuple)) and not callable(to_remove):
            to_remove = [to_remove]

        if not callable(to_remove):
            # Column tags are stored as strings, so Tags members must be converted to
            # their value before comparing (same as in ``select_by_tag``).
            to_remove = [tag.value if isinstance(tag, Tags) else tag for tag in to_remove]

        def collection_filter_fn(column_tags):
            return all(x in column_tags for x in to_remove)

        return self._filter_column_schemas(
            to_remove, collection_filter_fn, lambda x: x.tags, negate=True
        )

    def select_by_name(self, to_select) -> "Schema":
        """Keep the columns whose name matches.

        ``to_select`` is a single name, a list/tuple of names, or a callable that receives
        a column name and returns whether to keep the column.
        """
        if not isinstance(to_select, (list, tuple)) and not callable(to_select):
            to_select = [to_select]

        def collection_filter_fn(column_name):
            return column_name in to_select

        output: Schema = self._filter_column_schemas(
            to_select, collection_filter_fn, lambda x: x.name
        )

        return output

    def remove_by_name(self, to_remove) -> "Schema":
        """Drop the columns whose name matches.

        ``to_remove`` is a single name, a list/tuple of names, or a callable that receives
        a column name and returns whether to drop the column.
        """
        if not isinstance(to_remove, (list, tuple)) and not callable(to_remove):
            to_remove = [to_remove]

        def collection_filter_fn(column_name):
            return column_name in to_remove

        return self._filter_column_schemas(
            to_remove, collection_filter_fn, lambda x: x.name, negate=True
        )

    def map_column_schemas(self, map_fn: Callable[[ColumnSchema], ColumnSchema]) -> "Schema":
        """Return a new schema built from ``map_fn`` applied to every column."""
        return Schema([map_fn(column_schema) for column_schema in self.column_schemas])

    def filter_column_schemas(
        self, filter_fn: Callable[[ColumnSchema], bool], negate=False
    ) -> "Schema":
        """Return a new schema with the columns accepted by ``filter_fn``.

        With ``negate=True`` the columns *rejected* by ``filter_fn`` are kept instead.
        """
        selected_schemas = [
            column_schema
            for column_schema in self.column_schemas
            if self._check_column_schema(column_schema, filter_fn, negate=negate)
        ]

        return Schema(selected_schemas)

    @property
    def column_names(self) -> List[str]:
        """Names of all columns, in schema order."""
        return [f.name for f in self.feature]

    @property
    def column_schemas(self) -> Sequence[ColumnSchema]:
        """The underlying ``feature`` list (not a copy)."""
        return self.feature

    @cached_property
    def item_id_column_name(self):
        """Name of the first column tagged ``Tags.ITEM_ID``.

        Raises
        ------
        ValueError
            If no column carries that tag.
        """
        item_id_col = self.select_by_tag(Tags.ITEM_ID)
        if len(item_id_col.column_names) == 0:
            raise ValueError("There is no column tagged as item id.")

        return item_id_col.column_names[0]

    def from_json(self, value: Union[str, bytes]) -> "Schema":
        """Load the schema from JSON.

        ``value`` is either the JSON content itself or the path of an existing file, in
        which case the file's bytes are read and parsed instead.
        """
        if os.path.isfile(value):
            with open(value, "rb") as f:
                value = f.read()

        return super().from_json(value)

    def to_proto_text(self) -> str:
        """Serialize this schema to proto text via a ``schema_pb2.Schema`` message."""
        # Imported lazily so tensorflow_metadata is only needed when this is called.
        from tensorflow_metadata.proto.v0 import schema_pb2

        return proto_utils.better_proto_to_proto_text(self, schema_pb2.Schema())

    def from_proto_text(self, path_or_proto_text: str) -> "Schema":
        """Load a schema from proto text, given either the text or a path to a file."""
        # Imported lazily so tensorflow_metadata is only needed when this is called.
        from tensorflow_metadata.proto.v0 import schema_pb2

        return _proto_text_to_better_proto(self, path_or_proto_text, schema_pb2.Schema())

    def copy(self, **kwargs) -> "Schema":
        """Return a copy produced by ``proto_utils.copy_better_proto_message``."""
        return proto_utils.copy_better_proto_message(self, **kwargs)

    def add(self, other, allow_overlap=True) -> "Schema":
        """Return a new schema containing the columns of ``self`` and ``other``.

        Parameters
        ----------
        other:
            A ``Schema``, a single column name (``str``), or a sequence of column schemas.
        allow_overlap:
            When ``False``, a column name present on both sides raises ``ValueError``.
            When ``True`` (default), a column of ``self`` whose name also appears in
            ``other`` is kept and only receives the *tags* of the other column; columns
            that exist only in ``other`` are appended.

        Raises
        ------
        ValueError
            If ``allow_overlap`` is ``False`` and both schemas share a column name.
        """
        if isinstance(other, str):
            other = Schema.create([other])
        elif isinstance(other, collections.abc.Sequence):  # type: ignore
            other = Schema(other)

        if not allow_overlap:
            # check if there are any columns with the same name in both column groups
            overlap = set(self.column_names).intersection(other.column_names)

            if overlap:
                raise ValueError(f"duplicate column names found: {overlap}")
            new_columns = self.column_schemas + other.column_schemas
        else:
            new_columns = list(self.column_schemas)
            # Look columns up by name instead of scanning the list with `==`, which
            # would serialize both columns for every comparison.
            position_by_name = {col.name: idx for idx, col in enumerate(new_columns)}
            other_by_name = {col.name: col for col in other.column_schemas}

            for name, other_col in other_by_name.items():
                idx = position_by_name.get(name)
                if idx is None:
                    new_columns.append(other_col)
                else:
                    new_columns[idx] = new_columns[idx].with_tags(other_col.tags)

        return Schema(new_columns)

    def _filter_column_schemas(
        self,
        to_filter: Union[list, tuple, Callable[[FilterT], bool]],
        collection_filter_fn: Callable[[FilterT], bool],
        column_select_fn: Callable[[ColumnSchema], FilterT],
        negate=False,
    ) -> "Schema":
        """Shared implementation of the ``select_by_*`` / ``remove_by_*`` methods.

        Parameters
        ----------
        to_filter:
            The user's filter. A list/tuple means ``collection_filter_fn`` is used as the
            predicate; a callable is used as the predicate directly.
        collection_filter_fn:
            Predicate used when ``to_filter`` is a collection.
        column_select_fn:
            Extracts from each column the value handed to the predicate.
        negate:
            Keep the columns for which the predicate is false instead of true.

        Raises
        ------
        ValueError
            If ``to_filter`` is neither a list/tuple nor callable.
        """
        if isinstance(to_filter, (list, tuple)):
            check_fn = collection_filter_fn
        elif callable(to_filter):
            check_fn = to_filter
        else:
            raise ValueError(f"Expected either a collection or function, got: {to_filter}.")

        selected_schemas = [
            column_schema
            for column_schema in self.column_schemas
            if self._check_column_schema(column_select_fn(column_schema), check_fn, negate=negate)
        ]

        return Schema(selected_schemas)

    def _check_column_schema(
        self, inputs: FilterT, filter_fn: Callable[[FilterT], bool], negate=False
    ) -> bool:
        """Return whether ``filter_fn(inputs)`` is truthy, inverted when ``negate`` is set."""
        return bool(filter_fn(inputs)) != bool(negate)

    def __iter__(self):
        return iter(self.column_schemas)

    def __len__(self):
        return len(self.column_schemas)

    def __repr__(self):
        return str(
            [
                col_schema.to_dict(casing=betterproto.Casing.SNAKE)
                for col_schema in self.column_schemas
            ]
        )

    def __eq__(self, other):
        """Schemas are equal when they hold equal columns, regardless of column order."""
        if not isinstance(other, Schema) or len(self.column_schemas) != len(other.column_schemas):
            return False

        return sorted(self.column_schemas, key=lambda x: x.name) == sorted(
            other.column_schemas, key=lambda x: x.name
        )

    def __add__(self, other):
        return self.add(other, allow_overlap=True)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        """Return a new schema without the columns whose name appears in ``other``.

        ``self`` is returned unchanged when ``other`` is ``None``. Neither operand is
        modified.

        Raises
        ------
        TypeError
            If ``other`` is neither ``None`` nor a ``Schema``.
        """
        if other is None:
            return self

        if not isinstance(other, Schema):
            raise TypeError(f"unsupported operand type(s) for -: 'Schema' and {type(other)}")

        # Columns are stored in a list, so build a filtered list rather than popping
        # entries (which would also mutate the list shared with ``self``).
        excluded_names = set(other.column_names)

        return Schema([col for col in self.column_schemas if col.name not in excluded_names])
def _proto_text_to_better_proto(
    better_proto_message: ProtoMessageType, path_proto_text: str, message: ProtoMessage
) -> ProtoMessageType:
    """Parse proto text into a new betterproto message.

    Parameters
    ----------
    better_proto_message:
        Instance whose class determines the type of the returned message.
    path_proto_text:
        Either the proto text itself or the path of an existing file containing it.
    message:
        Protobuf message instance the text is parsed into; it is modified by this call.

    Returns
    -------
    ProtoMessageType
        A new instance of ``better_proto_message``'s class, loaded from the JSON form
        of the parsed message.
    """
    proto_text = path_proto_text
    if os.path.isfile(proto_text):
        with open(path_proto_text, "r") as f:
            proto_text = f.read()

    proto = text_format.Parse(proto_text, message)

    # This is a hack because as of now we can't parse the Any representation.
    # TODO: Fix this.
    d = json_format.MessageToDict(proto)
    # MessageToDict leaves out unset fields, so both the feature list and a feature's
    # annotation can be missing from the dict.
    for feature in d.get("feature", []):
        annotation = feature.get("annotation")
        if annotation and "extraMetadata" in annotation:
            extra_metadata = annotation.pop("extraMetadata")
            # Carry the first extra-metadata value over as a JSON string in `comment`.
            annotation["comment"] = [json.dumps(extra_metadata[0]["value"])]

    json_str = json_format.MessageToJson(json_format.ParseDict(d, message))

    return better_proto_message.__class__().from_json(json_str)
def categorical_cardinalities(schema) -> Dict[str, int]:
    """Map each categorical column name to its cardinality.

    A ``merlin.schema.Schema`` is handed to the ``merlin.models`` helper of the same
    name. For any other iterable of columns, every column with a categorical
    ``int_domain`` contributes ``int_domain.max + 1``.
    """
    if isinstance(schema, CoreSchema):
        return mm_schema_utils.categorical_cardinalities(schema)

    return {
        column.name: column.int_domain.max + 1
        for column in schema
        if column.int_domain and column.int_domain.is_categorical
    }