Source code for sparse_operation_kit.saver.Saver

"""
 Copyright (c) 2021, NVIDIA CORPORATION.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
from sparse_operation_kit import kit_lib
from sparse_operation_kit.core.embedding_layer_handle import GraphKeys
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

# TODO: make it inherit from trackable???
class Saver(object):
    """
    Abbreviated as ``sok.Saver()``.

    Dumps, restores and directly assigns the parameters of
    SparseOperationKit embedding variables.
    """

    def __init__(self):
        # TODO: how to get all emb_var from Model???
        pass

    def __call__(self):
        pass

    @staticmethod
    def _get_emb_handle(embedding_variable):
        """Return the SOK embedding handle backing ``embedding_variable``.

        If the object exposes ``emb_handle`` directly (the horovod branch,
        e.g. a ``sok.EmbeddingVariable``), that attribute is returned.
        Otherwise the object is treated as a distributed variable (the
        strategy branch) and the ``emb_handle`` of its first component,
        ``values[0]``, is returned.
        """
        if hasattr(embedding_variable, "emb_handle"):
            # horovod branch
            return embedding_variable.emb_handle
        # strategy branch
        # NOTE(review): only the first component's handle is used; this
        # presumes all components refer to the same SOK handle -- confirm.
        return embedding_variable.values[0].emb_handle

    @staticmethod
    def _initializer_context():
        """Return the context manager under which restore/load ops are built.

        Under TensorFlow 2 this is a no-op context. Under TensorFlow 1 it is
        a ``control_dependencies`` scope on the ``initializer`` of every layer
        registered in ``GraphKeys.SparseOperationKitEmbeddingLayers``, so ops
        created inside it depend on those initializers.
        """
        if kit_lib.in_tensorflow2():
            return ops.NullContextmanager()
        # in case the embedding layer has not been created
        layers = ops.get_collection(GraphKeys.SparseOperationKitEmbeddingLayers)
        return ops.control_dependencies([layer.initializer for layer in layers])

    def dump_to_file(self, embedding_variable, filepath):
        """
        This function is used to save the specified embedding variables to
        host file. When multiple CPU processes are used, this function must
        be called within each CPU process.

        Parameters
        ----------
        embedding_variable: sok.EmbeddingVariable, tf.DistributedVariable
            The variable from embedding layer which needs to be dumped to file.
        filepath: string
            The directory where the parameters will be dumped to.

        Returns
        -------
        status: tf.Tensor
            If this op executed successfully, then 'OK' will be returned.
        """
        # TODO: check whether embedding_variable is an instance of DistributedVariable
        # Unlike restore_from_file / load_embedding_values, no initializer
        # control dependencies are added here.
        emb_handle = self._get_emb_handle(embedding_variable)
        return kit_lib.dump_to_file(emb_handle, filepath)

    def restore_from_file(self, embedding_variable, filepath):
        """
        This function is used to restore dumped parameters to the specified
        embedding variable. When multiple CPU processes are used, this
        function must be called within each CPU process.

        Parameters
        ----------
        embedding_variable: sok.EmbeddingVariable, tf.DistributedVariable
            The embedding variable which needs to be restored from file.
        filepath: string
            The directory where the parameters will be restored from.

        Returns
        -------
        status: tf.Tensor
            If this op executed successfully, then 'OK' will be returned.
        """
        with self._initializer_context():
            emb_handle = self._get_emb_handle(embedding_variable)
            return kit_lib.restore_from_file(emb_handle, filepath)

    def load_embedding_values(self, embedding_variable, tensors):
        """
        This function is used to assign embedding_variable's value with
        tf.Tensors. When multiple CPU processes are used, this function must
        be called within each CPU process.

        Parameters
        ----------
        embedding_variable: sok.EmbeddingVariable, tf.DistributedVariable
            Which embedding_variable's value will be assigned.
        tensors: tf.Tensor, list of tf.Tensor, tuple of tf.Tensor
            Each tf.Tensor must be 2-rank and the shape must be
            `[None, embedding_vec_size]`, where the `embedding_vec_size` must
            be equal to that of embedding_variable's. A list/tuple of
            tf.Tensors is concatenated along dim-0 into one big tensor. For
            example, `[tf.Tensor(shape=(bs_0, embedding_vec_size)),
            tf.Tensor(shape=(bs_1, embedding_vec_size)),
            tf.Tensor(shape=(bs_2, embedding_vec_size))]` will be treated as
            `tf.Tensor(shape=(bs_0 + bs_1 + bs_2, embedding_vec_size))`.

        Returns
        -------
        status: tf.Tensor
            If this op executed successfully, then 'OK' will be returned.
        """
        if isinstance(tensors, (list, tuple)):
            # stack those tensors along dim-0
            tensors = array_ops.concat(tensors, axis=0)
        with self._initializer_context():
            emb_handle = self._get_emb_handle(embedding_variable)
            return kit_lib.load_embedding_values(emb_handle, tensors)