# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
import json
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from google.protobuf import json_format, text_format
from google.protobuf.message import Message as ProtoMessage
from merlin.models.utils import schema_utils as mm_schema_utils
from merlin.schema import Schema as CoreSchema
from merlin.schema import Tags, TagSet, TagsType
from merlin.schema.io import proto_utils
try:
    from functools import cached_property  # type: ignore
except ImportError:
    # Polyfill cached_property for python <= 3.7 (functools.cached_property is 3.8+).
    # A ``property(lru_cache()(func))`` polyfill would require instances to be hashable (classes
    # that define ``__eq__`` without ``__hash__`` are not) and would keep every instance alive in
    # the cache, so the computed value is stored in the instance ``__dict__`` instead.
    class cached_property:  # type: ignore # noqa
        """Non-data descriptor computing ``func(instance)`` once and caching it per instance."""

        def __init__(self, func):
            self.func = func
            self.attrname = func.__name__
            self.__doc__ = func.__doc__

        def __get__(self, instance, owner=None):
            if instance is None:
                return self
            cache = instance.__dict__
            if self.attrname not in cache:
                cache[self.attrname] = self.func(instance)
            return cache[self.attrname]
import betterproto # noqa
from betterproto import Message as BetterProtoMessage
from ..proto.schema_bp import * # noqa
from ..proto.schema_bp import (
Annotation,
Feature,
FeatureType,
FixedShape,
FixedShapeDim,
FloatDomain,
IntDomain,
ValueCount,
ValueCountList,
_Schema,
)
# Type variable for any better-proto message; lets helpers such as
# ``_proto_text_to_better_proto`` return the same message class they were given.
ProtoMessageType = TypeVar("ProtoMessageType", bound=BetterProtoMessage)
def _parse_shape_and_value_count(shape, value_count) -> Dict[str, Any]:
output: Dict[str, Union[ValueCount, ValueCountList, FixedShape]] = {}
if shape:
output["shape"] = FixedShape([FixedShapeDim(d) for d in shape])
if value_count:
if isinstance(value_count, ValueCount):
output["value_count"] = value_count
elif isinstance(value_count, ValueCountList):
output["value_counts"] = value_count
else:
raise ValueError("Unknown value_count type.")
return output
class ColumnSchema(Feature):
    """Schema of a single column: a ``Feature`` message plus convenience constructors/helpers.

    NOTE(review): ``with_tags`` is called by the constructors below but is not defined in this
    class body; presumably it is provided by ``Feature`` or another definition — confirm.
    """

    @classmethod
    def create_categorical(
        cls,
        name: str,
        num_items: int,
        shape: Optional[Union[Tuple[int, ...], List[int]]] = None,
        value_count: Optional[Union[ValueCount, ValueCountList]] = None,
        min_index: int = 0,
        tags: Optional[TagsType] = None,
        **kwargs,
    ) -> "ColumnSchema":
        """Create an ``INT`` column carrying a categorical ``IntDomain``.

        Args:
            name: Column name; also used as the domain name.
            num_items: Stored as ``int_domain.max``.
            shape: Optional fixed shape (list/tuple of dimension sizes).
            value_count: Optional ``ValueCount`` or ``ValueCountList``.
            min_index: Stored as ``int_domain.min``.
            tags: Extra tags; ``Tags.CATEGORICAL`` is always added.
            **kwargs: Forwarded to the ``Feature`` constructor.

        Returns:
            The new ``ColumnSchema``.
        """
        _tags: List[str] = [t.value for t in TagSet(tags or [])]
        extra = _parse_shape_and_value_count(shape, value_count)
        int_domain = IntDomain(name=name, min=min_index, max=num_items, is_categorical=True)
        # Order-preserving de-duplication: ``list(set(...))`` yields an order that depends on the
        # string hash seed, which makes serialized schemas and ``__eq__`` (which compares
        # ``to_dict()`` output, tag order included) unstable across runs.
        _tags = list(dict.fromkeys(_tags + [Tags.CATEGORICAL.value]))
        extra["type"] = FeatureType.INT
        return cls(name=name, int_domain=int_domain, **extra, **kwargs).with_tags(_tags)

    @classmethod
    def create_continuous(
        cls,
        name: str,
        is_float: bool = True,
        min_value: Optional[Union[int, float]] = None,
        max_value: Optional[Union[int, float]] = None,
        disallow_nan: bool = False,
        disallow_inf: bool = False,
        is_embedding: bool = False,
        shape: Optional[Union[Tuple[int, ...], List[int]]] = None,
        value_count: Optional[Union[ValueCount, ValueCountList]] = None,
        tags: Optional[TagsType] = None,
        **kwargs,
    ) -> "ColumnSchema":
        """Create a continuous column of type ``FLOAT`` (or ``INT`` when ``is_float=False``).

        A domain is attached only when both ``min_value`` and ``max_value`` are given. Float
        columns get a ``FloatDomain`` carrying ``disallow_nan``, ``disallow_inf`` and
        ``is_embedding``. Int columns get a non-categorical ``IntDomain``.

        Args:
            name: Column name; also used as the domain name.
            is_float: Selects ``FeatureType.FLOAT`` (default) or ``FeatureType.INT``.
            min_value: Lower bound of the domain.
            max_value: Upper bound of the domain.
            disallow_nan: Stored on the ``FloatDomain`` (float columns only).
            disallow_inf: Stored on the ``FloatDomain`` (float columns only).
            is_embedding: Stored on the ``FloatDomain`` (float columns only).
            shape: Optional fixed shape (list/tuple of dimension sizes).
            value_count: Optional ``ValueCount`` or ``ValueCountList``.
            tags: Extra tags; ``Tags.CONTINUOUS`` is always added.
            **kwargs: Forwarded to the ``Feature`` constructor.

        Returns:
            The new ``ColumnSchema``.
        """
        _tags: List[str] = [t.value for t in TagSet(tags or [])]
        extra = _parse_shape_and_value_count(shape, value_count)
        if min_value is not None and max_value is not None:
            if is_float:
                extra["float_domain"] = FloatDomain(
                    name=name,
                    min=float(min_value),
                    max=float(max_value),
                    disallow_nan=disallow_nan,
                    disallow_inf=disallow_inf,
                    is_embedding=is_embedding,
                )
            else:
                extra["int_domain"] = IntDomain(
                    name=name, min=int(min_value), max=int(max_value), is_categorical=False
                )
        extra["type"] = FeatureType.FLOAT if is_float else FeatureType.INT
        # Order-preserving de-duplication (see ``create_categorical`` for why not ``set``).
        _tags = list(dict.fromkeys(_tags + [Tags.CONTINUOUS.value]))
        return cls(name=name, **extra, **kwargs).with_tags(_tags)

    def copy(self, **kwargs) -> "ColumnSchema":
        """Return a copy of this column, with ``kwargs`` applied by ``copy_better_proto_message``."""
        return proto_utils.copy_better_proto_message(self, **kwargs)

    def with_name(self, name: str):
        """Return a copy of this column renamed to ``name``."""
        return self.copy(name=name)

    def with_properties(self, properties: Dict[str, Union[str, int, float]]) -> "ColumnSchema":
        """Return a copy whose first ``annotation.extra_metadata`` entry includes ``properties``.

        An existing first entry is updated in place on the copy. Otherwise a new entry (or a
        new ``Annotation``) is created holding ``properties``.
        """
        output = self.copy()
        if output.annotation:
            if len(output.annotation.extra_metadata) > 0:
                output.annotation.extra_metadata[0].update(properties)
            else:
                output.annotation.extra_metadata = [properties]
        else:
            output.annotation = Annotation(extra_metadata=[properties])
        return output

    def to_proto_text(self) -> str:
        """Serialize this column to protobuf text format (``tensorflow_metadata`` ``Feature``)."""
        # Imported lazily so tensorflow_metadata is only needed when serializing.
        from tensorflow_metadata.proto.v0 import schema_pb2

        return proto_utils.better_proto_to_proto_text(self, schema_pb2.Feature())

    @property
    def tags(self):
        """The raw tag list stored in ``annotation.tag``."""
        return self.annotation.tag

    @property
    def properties(self) -> Dict[str, Union[str, float, int]]:
        """The first ``annotation.extra_metadata`` entry, or ``{}`` when there is none."""
        if self.annotation.extra_metadata:
            properties: Dict[str, Union[str, float, int]] = self.annotation.extra_metadata[0]
            return properties
        return {}

    def _set_tags(self, tags: List[str]):
        """Merge ``tags`` into this column's annotation IN PLACE (no copy is made)."""
        if self.annotation:
            # Order-preserving merge so the resulting tag order is deterministic.
            self.annotation.tag = list(dict.fromkeys([*self.annotation.tag, *tags]))
        else:
            # Copy (and de-duplicate) so the caller's list is not aliased by the annotation.
            self.annotation = Annotation(tag=list(dict.fromkeys(tags)))

    def __str__(self) -> str:
        return self.name

    def __eq__(self, other) -> bool:
        """Compare columns by their full ``to_dict()`` representation."""
        if not isinstance(other, ColumnSchema):
            return NotImplemented
        return self.to_dict() == other.to_dict()
# Accepted wherever a column can be given either as a full ColumnSchema or just by name.
ColumnSchemaOrStr = Union[ColumnSchema, str]
# Type of the value extracted from each column (name, type, tags, ...) by Schema's filter helpers.
FilterT = TypeVar("FilterT")
class Schema(_Schema):
    """A collection of column schemas for a dataset."""

    feature: List["ColumnSchema"] = betterproto.message_field(1)

    @classmethod
    def create(
        cls,
        column_schemas: Optional[
            Union[List[ColumnSchemaOrStr], Dict[str, ColumnSchemaOrStr]]
        ] = None,
        **kwargs,
    ):
        """Build a schema from column schemas and/or plain column names.

        Args:
            column_schemas: A list (or tuple) whose items are ``ColumnSchema`` objects or column
                names, or a dict whose values are such items (keys are ignored). ``None`` yields
                an empty schema.
            **kwargs: Forwarded to the ``Schema`` constructor.

        Raises:
            TypeError: If ``column_schemas`` is neither a list/tuple nor a dict.
        """
        column_schemas = column_schemas or []
        if isinstance(column_schemas, dict):
            column_schemas = list(column_schemas.values())

        if not isinstance(column_schemas, (list, tuple)):
            raise TypeError("The `column_schemas` parameter must be a list or dict.")

        features: List[ColumnSchema] = [
            ColumnSchema(column_schema) if isinstance(column_schema, str) else column_schema
            for column_schema in column_schemas
        ]
        return cls(feature=features, **kwargs)

    def apply(self, selector) -> "Schema":
        """Select the columns named by ``selector.names``; return ``self`` when there are none."""
        if selector and selector.names:
            return self.select_by_name(selector.names)
        return self

    def apply_inverse(self, selector) -> "Schema":
        """Return the columns NOT named by ``selector.names``; return ``self`` without a selector."""
        if selector:
            output_schema: Schema = self - self.select_by_name(selector.names)
            return output_schema
        return self

    def filter_columns_from_dict(self, input_dict):
        """Return the items of ``input_dict`` whose keys are column names of this schema."""
        # Build the lookup set once instead of re-creating ``column_names`` for every key.
        names = set(self.column_names)
        return {key: val for key, val in input_dict.items() if key in names}

    def select_by_type(self, to_select) -> "Schema":
        """Keep columns whose ``type`` is in ``to_select`` (single value, list/tuple, or predicate)."""
        if not isinstance(to_select, (list, tuple)) and not callable(to_select):
            to_select = [to_select]

        def collection_filter_fn(column_type):
            return column_type in to_select

        output: Schema = self._filter_column_schemas(
            to_select, collection_filter_fn, lambda x: x.type
        )
        return output

    def remove_by_type(self, to_remove) -> "Schema":
        """Drop columns whose ``type`` is in ``to_remove`` (single value, list/tuple, or predicate)."""
        if not isinstance(to_remove, (list, tuple)) and not callable(to_remove):
            to_remove = [to_remove]

        def collection_filter_fn(column_type):
            return column_type in to_remove

        output: Schema = self._filter_column_schemas(
            to_remove, collection_filter_fn, lambda x: x.type, negate=True
        )
        return output

    def select_by_tag(self, to_select) -> "Schema":
        """Keep columns carrying ALL of the given tags.

        ``to_select`` may be a single tag, a list/tuple of tags, or a predicate that receives a
        column's raw tag list.
        """
        if callable(to_select):
            return self._filter_column_schemas(to_select, lambda x: False, lambda x: x.tags)

        if not isinstance(to_select, (list, tuple)):
            to_select = [to_select]
        # Column tags are stored as their string values, so normalise the requested tags (which
        # may be ``Tags`` enum members) through ``TagSet`` and compare against
        # ``TagSet(column.tags)``.
        tags_to_select = TagSet(to_select)

        def collection_filter_fn(column_tags):
            return all(tag in column_tags for tag in tags_to_select)

        return self._filter_column_schemas(
            list(tags_to_select), collection_filter_fn, lambda x: TagSet(x.tags)
        )

    def remove_by_tag(self, to_remove) -> "Schema":
        """Drop columns carrying ALL of the given tags.

        ``to_remove`` may be a single tag, a list/tuple of tags, or a predicate that receives a
        column's raw tag list (columns for which it returns a truthy value are dropped).
        """
        if callable(to_remove):
            # A predicate cannot be wrapped in a TagSet; apply it directly (mirrors select_by_tag).
            return self._filter_column_schemas(
                to_remove, lambda x: False, lambda x: x.tags, negate=True
            )

        if not isinstance(to_remove, (list, tuple)):
            to_remove = [to_remove]
        tags_to_remove = TagSet(to_remove)

        def collection_filter_fn(column_tags):
            return all(tag in column_tags for tag in tags_to_remove)

        return self._filter_column_schemas(
            list(tags_to_remove), collection_filter_fn, lambda x: TagSet(x.tags), negate=True
        )

    def select_by_name(self, to_select) -> "Schema":
        """Keep columns whose name is in ``to_select`` (single name, list/tuple, or predicate)."""
        if not isinstance(to_select, (list, tuple)) and not callable(to_select):
            to_select = [to_select]

        def collection_filter_fn(column_name):
            return column_name in to_select

        output: Schema = self._filter_column_schemas(
            to_select, collection_filter_fn, lambda x: x.name
        )
        return output

    def remove_by_name(self, to_remove) -> "Schema":
        """Drop columns whose name is in ``to_remove`` (single name, list/tuple, or predicate)."""
        if not isinstance(to_remove, (list, tuple)) and not callable(to_remove):
            to_remove = [to_remove]

        def collection_filter_fn(column_name):
            return column_name in to_remove

        return self._filter_column_schemas(
            to_remove, collection_filter_fn, lambda x: x.name, negate=True
        )

    def map_column_schemas(self, map_fn: Callable[[ColumnSchema], ColumnSchema]) -> "Schema":
        """Return a new schema made of ``map_fn`` applied to every column, in order."""
        return Schema([map_fn(column_schema) for column_schema in self.column_schemas])

    def filter_column_schemas(
        self, filter_fn: Callable[[ColumnSchema], bool], negate=False
    ) -> "Schema":
        """Keep the columns for which ``filter_fn`` is truthy (falsy when ``negate`` is set)."""
        return Schema(
            [
                column_schema
                for column_schema in self.column_schemas
                if self._check_column_schema(column_schema, filter_fn, negate=negate)
            ]
        )

    @property
    def column_names(self) -> List[str]:
        """Names of all columns, in schema order."""
        return [f.name for f in self.feature]

    @property
    def column_schemas(self) -> Sequence[ColumnSchema]:
        """The underlying ``feature`` list (not a copy)."""
        return self.feature

    @cached_property
    def item_id_column_name(self):
        """Name of the first column tagged ``Tags.ITEM_ID``; computed once per instance.

        Raises:
            ValueError: If no column carries the item-id tag.
        """
        item_id_col = self.select_by_tag(Tags.ITEM_ID)
        if len(item_id_col.column_names) == 0:
            raise ValueError("There is no column tagged as item id.")
        return item_id_col.column_names[0]

    def from_json(self, value: Union[str, bytes]) -> "Schema":
        """Load from a JSON string/bytes, or from the JSON file at path ``value``.

        The actual parsing is delegated to the base class ``from_json``.
        """
        if os.path.isfile(value):
            with open(value, "rb") as f:
                value = f.read()
        return super().from_json(value)

    def to_proto_text(self) -> str:
        """Serialize to protobuf text format (``tensorflow_metadata`` ``Schema``)."""
        # Imported lazily so tensorflow_metadata is only needed when serializing.
        from tensorflow_metadata.proto.v0 import schema_pb2

        return proto_utils.better_proto_to_proto_text(self, schema_pb2.Schema())

    def from_proto_text(self, path_or_proto_text: str) -> "Schema":
        """Parse protobuf text (or a file containing it) into a NEW instance of this class.

        ``self`` is not modified; only its class is used to build the result.
        """
        from tensorflow_metadata.proto.v0 import schema_pb2

        return _proto_text_to_better_proto(self, path_or_proto_text, schema_pb2.Schema())

    def copy(self, **kwargs) -> "Schema":
        """Return a copy of this schema, with ``kwargs`` applied by ``copy_better_proto_message``."""
        return proto_utils.copy_better_proto_message(self, **kwargs)

    def add(self, other, allow_overlap=True) -> "Schema":
        """Return a new schema containing the columns of ``self`` followed by those of ``other``.

        Args:
            other: A ``Schema``, a single column name, or a sequence of ``ColumnSchema`` objects
                and/or column names.
            allow_overlap: When ``True``, a column of ``other`` whose name already exists in
                ``self`` is not appended. Instead its tags are merged into the existing column
                via ``with_tags``. When ``False``, any shared name raises.

        Raises:
            ValueError: If ``allow_overlap`` is ``False`` and column names overlap.
        """
        if isinstance(other, str):
            other = Schema.create([other])
        elif isinstance(other, collections.abc.Sequence):  # type: ignore
            # ``create`` accepts both ColumnSchema objects and plain names.
            other = Schema.create(list(other))

        if not allow_overlap:
            # check if there are any columns with the same name in both column groups
            overlap = set(self.column_names).intersection(other.column_names)
            if overlap:
                raise ValueError(f"duplicate column names found: {overlap}")
            new_columns = [*self.column_schemas, *other.column_schemas]
        else:
            new_columns = list(self.column_schemas)
            # name -> index in ``new_columns``; avoids O(n) ``list.index`` scans (which compare
            # whole ``to_dict()`` outputs) and betterproto truthiness checks on columns.
            positions = {col.name: i for i, col in enumerate(new_columns)}
            other_column_dict = {col.name: col for col in other.column_schemas}
            for name, col in other_column_dict.items():
                idx = positions.get(name)
                if idx is None:
                    new_columns.append(col)
                else:
                    new_columns[idx] = new_columns[idx].with_tags(col.tags)
        return Schema(new_columns)

    def _filter_column_schemas(
        self,
        to_filter: Union[list, tuple, Callable[[FilterT], bool]],
        collection_filter_fn: Callable[[FilterT], bool],
        column_select_fn: Callable[[ColumnSchema], FilterT],
        negate=False,
    ) -> "Schema":
        """Shared filtering engine for the ``select_by_*`` / ``remove_by_*`` methods.

        Args:
            to_filter: A list/tuple (then ``collection_filter_fn`` is the predicate) or a
                predicate used as-is.
            collection_filter_fn: Predicate used when ``to_filter`` is a list/tuple.
            column_select_fn: Extracts the value (name, type, tags, ...) the predicate receives.
            negate: Keep the columns for which the predicate is falsy instead.

        Raises:
            ValueError: If ``to_filter`` is neither a list/tuple nor callable.
        """
        if isinstance(to_filter, (list, tuple)):
            check_fn = collection_filter_fn
        elif callable(to_filter):
            check_fn = to_filter
        else:
            raise ValueError(f"Expected either a collection or function, got: {to_filter}.")

        return Schema(
            [
                column_schema
                for column_schema in self.column_schemas
                if self._check_column_schema(
                    column_select_fn(column_schema), check_fn, negate=negate
                )
            ]
        )

    def _check_column_schema(
        self, inputs: FilterT, filter_fn: Callable[[FilterT], bool], negate=False
    ) -> bool:
        """Return the truthiness of ``filter_fn(inputs)``, inverted when ``negate`` is truthy."""
        return bool(filter_fn(inputs)) != bool(negate)

    def __iter__(self):
        return iter(self.column_schemas)

    def __len__(self):
        return len(self.column_schemas)

    def __repr__(self):
        return str(
            [
                col_schema.to_dict(casing=betterproto.Casing.SNAKE)
                for col_schema in self.column_schemas
            ]
        )

    def __eq__(self, other):
        """Schemas are equal when they hold equal columns, irrespective of column order."""
        if not isinstance(other, Schema):
            # Consistent with ColumnSchema.__eq__: let Python try the reflected comparison.
            return NotImplemented
        if len(self.column_schemas) != len(other.column_schemas):
            return False
        return sorted(self.column_schemas, key=lambda x: x.name) == sorted(
            other.column_schemas, key=lambda x: x.name
        )

    def __add__(self, other):
        return self.add(other, allow_overlap=True)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        """Return a new schema without the columns whose names appear in ``other``.

        ``None`` is treated as an empty schema. ``self`` is never modified.
        """
        if other is None:
            return self
        if not isinstance(other, Schema):
            raise TypeError(f"unsupported operand type(s) for -: 'Schema' and {type(other)}")

        names_to_remove = set(other.column_names)
        return Schema([col for col in self.column_schemas if col.name not in names_to_remove])
def _proto_text_to_better_proto(
    better_proto_message: ProtoMessageType, path_proto_text: str, message: ProtoMessage
) -> ProtoMessageType:
    """Parse protobuf text format into a NEW instance of ``better_proto_message``'s class.

    Args:
        better_proto_message: Instance whose class determines the returned type. It is not
            modified.
        path_proto_text: Either protobuf text or a path to a file containing it.
        message: An empty protobuf (``google.protobuf``) message used as the parse target.

    Returns:
        A freshly constructed better-proto message populated from the parsed text.
    """
    proto_text = path_proto_text
    if os.path.isfile(proto_text):
        # Protobuf text format is UTF-8; don't depend on the platform's locale encoding.
        with open(path_proto_text, "r", encoding="utf-8") as f:
            proto_text = f.read()

    proto = text_format.Parse(proto_text, message)

    # This is a hack because as of now we can't parse the Any representation: the first
    # ``extraMetadata`` entry's ``value`` is moved into ``annotation.comment`` as a JSON string.
    # TODO: Fix this.
    d = json_format.MessageToDict(proto)
    # MessageToDict omits empty fields, so ``feature`` and per-feature ``annotation`` may be absent.
    for f in d.get("feature", []):
        annotation = f.get("annotation")  # type: ignore
        if annotation and "extraMetadata" in annotation:
            extra_metadata = annotation.pop("extraMetadata")
            annotation["comment"] = [json.dumps(extra_metadata[0]["value"])]
    json_str = json_format.MessageToJson(json_format.ParseDict(d, message))
    return better_proto_message.__class__().from_json(json_str)
def categorical_cardinalities(schema) -> Dict[str, int]:
    """Map each categorical column name to its cardinality.

    A ``merlin.schema`` (core) schema is delegated to ``merlin.models``. For any other iterable of
    columns, every column with a categorical ``int_domain`` contributes
    ``int_domain.max + 1``.

    Args:
        schema: A core ``merlin.schema.Schema`` or an iterable of ``ColumnSchema`` objects.

    Returns:
        Dict of column name to cardinality.
    """
    if isinstance(schema, CoreSchema):
        return mm_schema_utils.categorical_cardinalities(schema)

    return {
        col.name: col.int_domain.max + 1
        for col in schema
        if col.int_domain and col.int_domain.is_categorical
    }