Source code for merlin_standard_lib.utils.proto_utils

# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
from typing import TypeVar

import betterproto
from betterproto import Message as BetterProtoMessage
from google.protobuf import json_format, text_format
from google.protobuf.message import Message as ProtoMessage

ProtoMessageType = TypeVar("ProtoMessageType", bound=BetterProtoMessage)


[docs]def has_field(self, field_name): return betterproto.serialized_on_wire(getattr(self, field_name))
[docs]def copy_better_proto_message(better_proto_message: ProtoMessageType, **kwargs) -> ProtoMessageType: output = better_proto_message.__class__().parse(bytes(better_proto_message)) for key, val in kwargs.items(): setattr(output, key, val) return output
[docs]def better_proto_to_proto_text( better_proto_message: BetterProtoMessage, message: ProtoMessage ) -> str: message.ParseFromString(bytes(better_proto_message)) return text_format.MessageToString(message)
[docs]def proto_text_to_better_proto( better_proto_message: ProtoMessageType, path_proto_text: str, message: ProtoMessage ) -> ProtoMessageType: proto_text = path_proto_text if os.path.isfile(proto_text): with open(path_proto_text, "r") as f: proto_text = f.read() proto = text_format.Parse(proto_text, message) # This is a hack because as of now we can't parse the Any representation. # TODO: Fix this. d = json_format.MessageToDict(proto) for f in d["feature"]: if "extraMetadata" in f["annotation"]: # type: ignore extra_metadata = f["annotation"].pop("extraMetadata") # type: ignore f["annotation"]["comment"] = [json.dumps(extra_metadata[0]["value"])] # type: ignore json_str = json_format.MessageToJson(json_format.ParseDict(d, message)) return better_proto_message.__class__().from_json(json_str)