Source code for sparse_operation_kit.experiment.dynamic_variable

#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.resource_variable_ops import ResourceVariable, variable_accessed
from tensorflow.python.eager import context
from tensorflow.python.ops import resource_variable_ops

from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
from sparse_operation_kit.experiment.communication import num_gpus


[docs]class DynamicVariable(ResourceVariable): """ Abbreviated as ``sok.experiment.DynamicVariable``. A variable that allocates memory dynamically. Parameters ---------- dimension: int The last dimension of this variable(that is, the embedding vector size of embedding table). initializer: string a string to specify how to initialize this variable. Currently, only support "random" or string of a float value(meaning const initializer). Default value is "random". key_type: dtype specify the data type of indices. Unlike the static variable of tensorflow, this variable is dyanmically allocated and contains a hash table inside it. So the data type of indices must be specified to construct the hash table. Default value is tf.int64. dtype: dtype specify the data type of values. Default value is tf.float32. Example ------- .. code-block:: python import numpy as np import tensorflow as tf import horovod.tensorflow as hvd from sparse_operation_kit import experiment as sok v = sok.DynamicVariable(dimension=3, initializer="13") print("v.shape:", v.shape) print("v.size:", v.size) indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64) embedding = tf.nn.embedding_lookup(v, indices) print("embedding:", embedding) print("v.shape:", v.shape) print("v.size:", v.size) """ def __init__( self, dimension, initializer=None, var_type=None, name=None, constraint=None, trainable=True, key_type=None, dtype=None, mode=None, ): self._key_type = key_type if key_type is not None else tf.int64 self._handle_dtype = dtype if dtype is not None else tf.float32 self._dimension = dimension self._indices = None self._mode = mode self._base = super(DynamicVariable, self) self._base.__init__( initial_value=[[0.0] * dimension], trainable=trainable, name="DynamicVariableBuffer", dtype=self._handle_dtype, constraint=constraint, distribute_strategy=None, synchronization=None, aggregation=None, shape=[None, dimension], ) with ops.init_scope(): name = "DynamicVariable" if name is None else name with ops.name_scope(name) as name_scope: self._dummy_name = ops.name_from_scope_name(name_scope) if context.executing_eagerly(): self._dummy_name = "%s_%d" % (name, ops.uid()) with ops.NullContextmanager(): shape = [None, dimension] initializer = "" if initializer is None else initializer var_type = "dynamic" if var_type is None else var_type handle = dynamic_variable_ops.dummy_var_handle( container="DummyVariableContainer", shared_name=self._dummy_name, key_type=self._key_type, dtype=self._handle_dtype, shape=shape, ) if type(initializer) is str: init_op = dynamic_variable_ops.dummy_var_initialize( handle, initializer=initializer, var_type=var_type, unique_name=self._dummy_name, key_type=self._key_type, dtype=self._handle_dtype, ) else: with tf.control_dependencies([initializer._initializer_op]): initial_val = initializer.read_value() init_op = dynamic_variable_ops.dummy_var_initialize( handle, initializer=initial_val, var_type=var_type, unique_name=self._dummy_name, key_type=self._key_type, dtype=self._handle_dtype, ) # TODO: Add is_initialized_op # is_initialized_op = ops.convert_to_tensor(True) self._tf_handle = self._handle self._dummy_handle = handle # Note that the default handle will be sok's handle self._handle = self._dummy_handle self._initializer_op = tf.group([self._initializer_op, init_op]) # self._is_initialized_op = tf.group([self._is_initialized_op, is_initialized_op]) handle_data = ( resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData() ) handle_data.is_set = True handle_data.shape_and_type.append( resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType( shape=self.shape.as_proto(), dtype=self.dtype.as_datatype_enum ) ) resource_variable_ops._set_handle_shapes_and_types( self._handle, handle_data, graph_mode=False if context.executing_eagerly() else True ) def is_static(self): return self._handle is self._tf_handle def to_static(self, indices): if not self.is_static() and self._indices is None: buffer = self.sparse_read(indices) self._indices = indices self._handle = self._tf_handle return self.assign(buffer) else: raise RuntimeError("to_static() must be called in dynamic mode.") def to_dynamic(self): if self.is_static(): buffer = self.read_value() sparse_delta = ops.IndexedSlices(buffer, self._indices, self.shape) self._indices = None self._handle = self._dummy_handle return self.scatter_update(sparse_delta) else: raise RuntimeError("to_dynamic() must be called in static mode.") def __repr__(self): if self.is_static(): return self._base.__repr__() return "<sok.DynamicVariable '%s' shape=%s dtype=%s>" % ( self._dummy_name, self.shape, self.dtype.name, ) @property def size(self): return dynamic_variable_ops.dummy_var_shape( self._dummy_handle, key_type=self._key_type, dtype=self._handle_dtype ) @property def indices(self): return self._indices @property def dimension(self): return self._dimension @property def key_type(self): return self._key_type @property def handle_dtype(self): return self._handle_dtype @property def target_gpu(self): if self._mode is not None and self._mode[: len("localized")] == "localized": target_gpu = int(self._mode.split(":")[1]) if target_gpu >= num_gpus(): raise RuntimeError( "There are only %d GPU(s), cannot put embedding table on %dth(zero-indexed) GPU." % (num_gpus(), target_gpu) ) return target_gpu return -1 @property def num_gpus(self): return num_gpus() def key_map(self, indices): return indices # ------------------------------------------------------------------------- # Methods supported both in static mode and dynamic mode # ------------------------------------------------------------------------- def sparse_read(self, indices, name=None): if self.is_static(): return self._base.sparse_read(indices, name) variable_accessed(self) if indices.dtype == tf.int32: indices = tf.cast(indices, tf.int64) return dynamic_variable_ops.dummy_var_sparse_read( self._dummy_handle, indices, dtype=self._handle_dtype ) def scatter_sub(self, sparse_delta, use_locking=False, name=None): if self.is_static(): return self._base.scatter_sub(sparse_delta, use_locking, name) if not isinstance(sparse_delta, ops.IndexedSlices): raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) return dynamic_variable_ops.dummy_var_scatter_add( self._dummy_handle, sparse_delta.indices, ops.convert_to_tensor(-sparse_delta.values, self.dtype), ) def scatter_add(self, sparse_delta, use_locking=False, name=None): if self.is_static(): return self._base.scatter_add(sparse_delta, use_locking, name) if not isinstance(sparse_delta, ops.IndexedSlices): raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) return dynamic_variable_ops.dummy_var_scatter_add( self._dummy_handle, sparse_delta.indices, ops.convert_to_tensor(sparse_delta.values, self.dtype), ) def scatter_update(self, sparse_delta, use_locking=False, name=None): if self.is_static(): return self._base.scatter_update(sparse_delta, use_locking, name) if not isinstance(sparse_delta, ops.IndexedSlices): raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) return dynamic_variable_ops.dummy_var_scatter_update( self._dummy_handle, sparse_delta.indices, ops.convert_to_tensor(sparse_delta.values, self.dtype), ) # ------------------------------------------------------------------------- # Methods not supported both in static mode and dynamic mode # ------------------------------------------------------------------------- def __deepcopy__(self, *args, **kwargs): raise NotImplementedError("__deepcopy__() is not supported.") def __reduce__(self, *args, **kwargs): raise NotImplementedError("__reduce__() is not supported.") def to_proto(self, *args, **kwargs): raise NotImplementedError("to_proto() is not supported.") @staticmethod def from_proto(*args, **kwargs): raise NotImplementedError("from_proto() is not supported.") def set_shape(self, *args, **kwargs): raise NotImplementedError("set_shape() is not supported.") # ------------------------------------------------------------------------- # Methods only supported in static mode # ------------------------------------------------------------------------- def is_initialized(self): if self.is_static(): return self._base.is_initialized() raise NotImplementedError("is_initialized() is not supported in dynamic mode.") def _read_variable_op(self): if self.is_static(): return self._base._read_variable_op() raise NotImplementedError("_read_variable_op() is not supported in dynamic mode.") def value(self): if self.is_static(): return self._base.value() raise NotImplementedError("value() is not supported in dynamic mode.") def _dense_var_to_tensor(self, *args, **kwargs): if self.is_static(): return self._base._dense_var_to_tensor(*args, **kwargs) raise NotImplementedError("_dense_var_to_tensor() is not supported in dynamic mode.") def _gather_saveables_for_checkpoint(self): if self.is_static(): return self._base._gather_saveables_for_checkpoint() raise NotImplementedError( "_gather_saveables_for_checkpoint() is not supported in dynamic mode." ) def gather_nd(self, *args, **kwargs): if self.is_static(): return self._base.gather_nd(*args, **kwargs) raise NotImplementedError("gather_nd() is not supported in dynamic mode.") def assign_add(self, *args, **kwargs): if self.is_static(): return self._base.assign_add(*args, **kwargs) raise NotImplementedError("assign_add() is not supported in dynamic mode.") def assign(self, *args, **kwargs): if self.is_static(): return self._base.assign(*args, **kwargs) raise NotImplementedError("assign() is not supported in dynamic mode.") def scatter_max(self, *args, **kwargs): if self.is_static(): return self._base.scatter_max(*args, **kwargs) raise NotImplementedError("scatter_max() is not supported in dynamic mode.") def scatter_min(self, *args, **kwargs): if self.is_static(): return self._base.scatter_min(*args, **kwargs) raise NotImplementedError("scatter_min() is not supported in dynamic mode.") def scatter_mul(self, *args, **kwargs): if self.is_static(): return self._base.scatter_mul(*args, **kwargs) raise NotImplementedError("scatter_mul() is not supported in dynamic mode.") def scatter_dim(self, *args, **kwargs): if self.is_static(): return self._base.scatter_dim(*args, **kwargs) raise NotImplementedError("scatter_dim() is not supported in dynamic mode.") def batch_scatter_update(self, *args, **kwargs): if self.is_static(): return self._base.batch_scatter_update(*args, **kwargs) raise NotImplementedError("batch_scatter_update() is not supported in dynamic mode.") def scatter_nd_sub(self, *args, **kwargs): if self.is_static(): return self._base.scatter_nd_sub(*args, **kwargs) raise NotImplementedError("scatter_nd_sub() is not supported in dynamic mode.") def scatter_nd_update(self, *args, **kwargs): if self.is_static(): return self._base.scatter_nd_update(*args, **kwargs) raise NotImplementedError("scatter_nd_update() is not supported in dynamic mode.") def _strided_slice_assign(self, *args, **kwargs): if self.is_static(): return self._base._strided_slice_assign(*args, **kwargs) raise NotImplementedError("_strided_slice_assign() is not supported in dynamic mode.") def __int__(self, *args, **kwargs): if self.is_static(): return self._base.__int__(*args, **kwargs) raise NotImplementedError("__int__() is not supported in dynamic mode.")
@tf.RegisterGradient("DummyVarSparseRead") def _SparseReadGrad(op, grad): """Gradient for sparse_read.""" handle = op.inputs[0] indices = op.inputs[1] key_type = op.get_attr("key_type") dtype = op.get_attr("dtype") variable_shape = dynamic_variable_ops.dummy_var_shape(handle, key_type=key_type, dtype=dtype) size = array_ops.expand_dims(array_ops.size(indices), 0) values_shape = array_ops.concat([size, variable_shape[1:]], 0) grad = array_ops.reshape(grad, values_shape) indices = array_ops.reshape(indices, size) return (ops.IndexedSlices(grad, indices, variable_shape), None)
[docs]def export(var): """ Abbreviated as ``sok.experiment.export``. Export the indices and value tensor from the given variable. Parameters ---------- var: sok.DynamicVariable The variable to extract indices and values. Returns ------- indices: tf.Tensor The indices of the given variable. values: tf.Tensor the values of the given variable. """ if isinstance(var, DynamicVariable): indices, values = dynamic_variable_ops.dummy_var_export( var.handle, key_type=var.key_type, dtype=var.handle_dtype ) with tf.device("CPU"): indices = tf.identity(indices) values = tf.identity(values) return indices, values
[docs]def assign(var, indices, values): """ Abbreviated as ``sok.experiment.assign``. Assign the indices and value tensor to the target variable. Parameters ---------- var: sok.DynamicVariable The target variable of assign. indices: tf.Tensor indices to be assigned to the variable. values: tf.Tensor values to be assigned to the variable Returns ------- variable: sok.DynamicVariable """ if isinstance(var, DynamicVariable): return dynamic_variable_ops.dummy_var_assign(var.handle, indices, values)