#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pathlib
import tempfile
# Unconditionally force the pure-Python protobuf implementation (any value the user
# already set is overwritten). This assignment must run before any module that
# imports protobuf (e.g. tensorflow, imported just below) — presumably because the
# implementation is chosen when protobuf is first imported.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import tensorflow as tf # noqa
from merlin.core.protocols import Transformable # noqa
from merlin.dag import ColumnSelector # noqa
from merlin.schema import ColumnSchema, Schema # noqa
from merlin.systems.dag.ops.operator import InferenceOperator # noqa
class PredictTensorflow(InferenceOperator):
    """TensorFlow Model Prediction Operator.

    Wraps a TensorFlow/Keras model for use as an inference step. When built with a
    model or a model path, the operator's input and output schemas are derived from
    the model's ``serving_default`` signature.
    """

    def __init__(self, model_or_path, custom_objects: dict = None, backend="tensorflow"):
        """
        Instantiate a PredictTensorflow inference operator.

        Parameters
        ----------
        model_or_path : Tensorflow model or string
            This can be a tensorflow model or a path to a tensorflow model
            (``str`` or ``os.PathLike``). If ``None``, nothing is loaded and the
            ``path``, ``model`` and schema attributes are not set by this method.
        custom_objects : dict, optional
            Any custom objects that need to be loaded with the model, by default None.
            Only used when ``model_or_path`` is a path.
        backend : str, optional
            Currently not used by this operator, by default "tensorflow".
        """
        super().__init__()
        if model_or_path is None:
            return

        if isinstance(model_or_path, (str, os.PathLike)):
            self.path = model_or_path
            self.model = tf.keras.models.load_model(
                self.path, custom_objects=custom_objects or {}
            )
        else:
            # An in-memory model has no on-disk location to record.
            self.path = None
            self.model = model_or_path

        self.input_schema, self.output_schema = self._construct_schemas_from_model(self.model)

    def __getstate__(self) -> dict:
        """Return state of instance when pickled.

        Every instance attribute except ``model`` is included, so the model object
        itself is never pickled.

        Returns
        -------
        dict
            Returns object state excluding model attribute.
        """
        return {k: v for k, v in self.__dict__.items() if k != "model"}

    @property
    def export_name(self):
        """
        Provides a clear common english identifier for this operator.

        Returns
        -------
        String
            Name of the current class, lower-cased.
        """
        return self.__class__.__name__.lower()

    @classmethod
    def from_path(cls, path, **kwargs):
        """Create an operator by loading the TensorFlow model stored at ``path``.

        Parameters
        ----------
        path : str or os.PathLike
            Location of the saved model.
        **kwargs
            Forwarded to the constructor (e.g. ``custom_objects``).

        Returns
        -------
        PredictTensorflow
            A new operator wrapping the loaded model.
        """
        # Instantiate the class; calling ``cls.__init__(path, ...)`` directly would
        # bind ``path`` to ``self`` and return ``None``.
        return cls(path, **kwargs)

    def compute_output_schema(
        self, input_schema: Schema, col_selector: ColumnSelector, prev_output_schema: Schema = None
    ) -> Schema:
        """
        Use the output schema supplied during object creation.

        The arguments are ignored; the schema derived from the model's serving
        signature in ``__init__`` is returned as-is.
        """
        return self.output_schema

    def _construct_schemas_from_model(self, model):
        """Derive ``(input_schema, output_schema)`` from the model's serving signature.

        If the model exposes no ``serving_default`` signature, it is saved to a
        temporary directory and reloaded so that TensorFlow generates one.
        """
        signatures = getattr(model, "signatures", {}) or {}
        default_signature = signatures.get("serving_default")
        if default_signature is None:
            # Round-trip the model through disk to generate the missing signature.
            self._ensure_input_spec_includes_names(model)
            with tempfile.TemporaryDirectory() as tmp_dir:
                tf_model_path = pathlib.Path(tmp_dir) / "model.savedmodel"
                model.save(tf_model_path, include_optimizer=False)
                # Only the signature is needed, so skip re-compiling the reloaded model
                # (avoids failures from custom losses/metrics that are not registered).
                reloaded = tf.keras.models.load_model(tf_model_path, compile=False)
                default_signature = reloaded.signatures["serving_default"]

        input_specs = default_signature.structured_input_signature[1]
        output_specs = default_signature.structured_outputs
        return self._schema_from_specs(input_specs), self._schema_from_specs(output_specs)

    def _schema_from_specs(self, specs):
        """Build a Schema with one column per ``name -> TensorSpec`` entry of ``specs``."""
        schema = Schema()
        for col_name, spec in specs.items():
            col_schema = ColumnSchema(col_name, dtype=spec.dtype.as_numpy_dtype)
            list_length = self._fixed_list_length(spec.shape)
            if list_length is not None:
                col_schema = self._set_list_length(col_schema, list_length)
            schema.column_schemas[col_name] = col_schema
        return schema

    @staticmethod
    def _fixed_list_length(shape):
        """Return the fixed list length encoded in ``shape``, or ``None``.

        A column is treated as a fixed-length list when its second dimension is
        known and greater than 1. Shapes of rank < 2 (e.g. ``(None,)``) or of
        unknown rank are not lists; indexing ``shape[1]`` on them would fail.
        """
        if shape.rank is None or shape.rank < 2:
            return None
        list_length = shape[1]
        return list_length if list_length and list_length > 1 else None

    def _ensure_input_spec_includes_names(self, model):
        """Rename every entry of the model's saved-model input spec to its dict key.

        Mutates the private Keras attribute ``model._saved_model_inputs_spec`` in
        place, presumably so that the signature generated on save uses the input
        keys as tensor names. Every component of a tuple entry gets the same name.
        Models without this attribute, or with a non-dict spec, are left untouched.

        Returns
        -------
        The same ``model`` object.
        """
        inputs_spec = getattr(model, "_saved_model_inputs_spec", None)
        if isinstance(inputs_spec, dict):
            # Iterate over a snapshot since entries are reassigned in the loop.
            for key, spec in list(inputs_spec.items()):
                if isinstance(spec, tuple):
                    inputs_spec[key] = tuple(
                        tf.TensorSpec(shape=part.shape, dtype=part.dtype, name=key)
                        for part in spec
                    )
                else:
                    inputs_spec[key] = tf.TensorSpec(
                        shape=spec.shape, dtype=spec.dtype, name=key
                    )
        return model

    def _set_list_length(self, col_schema, list_length):
        """Mark ``col_schema`` as a non-ragged list with exactly ``list_length`` values."""
        return col_schema.with_dtype(
            col_schema.dtype, is_list=True, is_ragged=False
        ).with_properties({"value_count": {"min": list_length, "max": list_length}})