# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Embedding layers"""
import numpy as np
import tensorflow as tf
from tensorflow.keras import constraints
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import backend
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops.ragged import ragged_tensor
from distributed_embeddings.python.ops import embedding_lookup_ops


class CPUInitializer(tf.keras.initializers.Initializer):
  """Initializer wrapper to force one-time init onto CPU, avoiding OOM."""

  def __init__(self, initializer):
    self._initializer = initializer

  def __call__(self, shape, dtype=None, **kwargs):
    # Run the wrapped initializer under a CPU device scope so the full table is
    # materialized in host memory
    with tf.device('/CPU:0'):
      res = self._initializer(shape, **kwargs)
    return res
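
# A minimal usage sketch for CPUInitializer (the initializer choice here is arbitrary):
#   init = CPUInitializer(tf.keras.initializers.GlorotUniform())
#   table = init((1000, 16))   # values are materialized under the '/CPU:0' device scope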


class Embedding(tf.keras.layers.Layer):
  """Turns indices into vectors of fixed size.

  Args:
    input_dim (int): Size of the vocabulary, i.e. maximum index + 1.
    output_dim (int): Length of embedding vectors.
    embeddings_initializer: Initializer for the `embeddings`
      matrix (see `keras.initializers`).
    embeddings_regularizer: Regularizer function applied to
      the `embeddings` matrix (see `keras.regularizers`).
    embeddings_constraint: Constraint function applied to
      the `embeddings` matrix (see `keras.constraints`).
    combiner (str): Reduction method, ['sum', 'mean'] or None. Default None.
      When combiner is not None, the supported inputs and their respective output shapes are:
        N-D `Tensor`: `(d1,...,dn)`, output shape: `(d1,...,dn-1,output_dim)`, N >= 2
        2-D `RaggedTensor`: `(batch_size, ragged_dim)`, output shape: `(batch_size, output_dim)`
        2-D `SparseTensor`: `(batch_size, max_hotness)`, output shape: `(batch_size, output_dim)`
      Embeddings picked along the last input dimension are reduced with the given combiner.
  """

  def __init__(self,
               input_dim,
               output_dim,
               embeddings_initializer='uniform',
               embeddings_regularizer=None,
               activity_regularizer=None,
               embeddings_constraint=None,
               combiner=None,
               **kwargs):
    if 'input_shape' not in kwargs:
      kwargs['input_shape'] = (None,)
    if input_dim <= 0 or output_dim <= 0:
      raise ValueError(
          f'Both input_dim and output_dim should be positive, found {input_dim} and {output_dim}')
    if (not base_layer_utils.v2_dtype_behavior_enabled() and 'dtype' not in kwargs):
      # In TF1, the dtype defaults to the input dtype which is typically int32,
      # so explicitly set it to floatx
      kwargs['dtype'] = backend.floatx()
    # No autocast.
    kwargs['autocast'] = False
    super().__init__(**kwargs)
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.embeddings_initializer = initializers.get(embeddings_initializer)
    self.embeddings_initializer_cpu = CPUInitializer(self.embeddings_initializer)
    self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.embeddings_constraint = constraints.get(embeddings_constraint)
    self.combiner = combiner

  @tf_utils.shape_type_conversion
  def build(self, input_shape):  # pylint: disable=unused-argument
    self.embeddings = self.add_weight(shape=(self.input_dim, self.output_dim),
                                      initializer=self.embeddings_initializer_cpu,
                                      name='embeddings',
                                      regularizer=self.embeddings_regularizer,
                                      constraint=self.embeddings_constraint,
                                      experimental_autocast=False)
    self.built = True

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if self.combiner is None:
      return input_shape + (self.output_dim,)
    return input_shape[:-1] + (self.output_dim,)

  def call(self, inputs):  # pylint: disable=missing-function-docstring
    dtype = backend.dtype(inputs)
    if dtype not in ['int64', 'int32']:
      inputs = tf.cast(inputs, 'int32')
    # When a reshape is needed, compute the output shape up front and replace the
    # possibly-None leading dimension with -1
    out_shape = None
    if len(inputs.shape) != 2:
      out_shape = [-1] + list(self.compute_output_shape(inputs.shape))[1:]
    # Check for unsupported cases and reshape non-2D dense inputs
    if isinstance(inputs, ragged_tensor.RaggedTensor):
      if len(inputs.shape) > 2:
        raise ValueError('Ragged input should be 2D. Nested ragged is not supported.')
    else:
      if len(inputs.shape) == 1:
        if self.combiner is not None:
          raise ValueError('1D input with combiner is ambiguous. Please create batch dimension.')
        inputs = tf.reshape(inputs, [-1, 1])
      if len(inputs.shape) > 2:
        inputs = tf.reshape(inputs, [-1, inputs.shape[-1]])
    out = embedding_lookup_ops.embedding_lookup(self.embeddings, inputs, combiner=self.combiner)
    if out_shape is not None:
      out = tf.reshape(out, out_shape)
    return out

  def get_config(self):  # pylint: disable=missing-function-docstring
    config = {
        'input_dim': self.input_dim,
        'output_dim': self.output_dim,
        'embeddings_initializer': initializers.serialize(self.embeddings_initializer),
        'embeddings_regularizer': regularizers.serialize(self.embeddings_regularizer),
        'activity_regularizer': regularizers.serialize(self.activity_regularizer),
        'embeddings_constraint': constraints.serialize(self.embeddings_constraint),
        'combiner': self.combiner
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    """Creates a layer from its config.

    Overridden to enable instantiating this fast embedding layer from Keras Embedding configs.
    """
    config.pop('mask_zero', None)
    config.pop('input_length', None)
    return super().from_config(config)
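
  # Sketch of the intended interop (a hypothetical round-trip, not from the library's docs):
  #   keras_layer = tf.keras.layers.Embedding(1000, 16, mask_zero=True)
  #   fast_layer = Embedding.from_config(keras_layer.get_config())
  #   # 'mask_zero' and 'input_length' are dropped because they are not supported here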


class ConcatOneHotEmbedding(tf.keras.layers.Layer):
  """Concatenated one-hot embedding.

  All tables share a single concatenated parameter matrix; input indices are shifted
  by per-table offsets so a single gather serves every table.

  Args:
    feature_sizes (list): A list of integers, the number of features of each embedding table
    embedding_width (int): Width of the embedding vectors
  """

  def __init__(self, feature_sizes, embedding_width):
    super().__init__(dtype=tf.float32)
    self.embedding_width = embedding_width
    # Row offsets of each table inside the concatenated matrix,
    # e.g. feature_sizes [3, 5] -> offsets [0, 3, 8]
    self._offsets_np = np.array([0] + feature_sizes).cumsum()
    self.params = self.add_weight("params",
                                  shape=[self._offsets_np[-1], self.embedding_width],
                                  dtype=tf.float32)
    self.offsets = tf.constant(self._offsets_np, dtype=tf.int32)

  def call(self, inputs):
    assert inputs.shape[1] == len(self.offsets) - 1
    # Shift each column's indices into its own table's row range, then gather once
    offset_indices = inputs + self.offsets[:-1]
    embedding_out = tf.gather(params=self.params, indices=offset_indices, axis=None)
    return embedding_out
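
# A minimal usage sketch for ConcatOneHotEmbedding (feature sizes and ids are made up):
#   layer = ConcatOneHotEmbedding(feature_sizes=[3, 5], embedding_width=8)
#   ids = tf.constant([[0, 4], [2, 1]])   # one index per table, per sample
#   out = layer(ids)                      # shape (2, 2, 8)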


# pylint: disable=missing-class-docstring
class IntegerLookup(tf.keras.layers.Layer):
  """A preprocessing layer which maps integer features to contiguous ranges.

  The vocabulary is generated on the fly; static vocabularies and adapt() will be supported.
  Partially supports the features of tf.keras.layers.IntegerLookup.
  Key frequencies are counted when the GPU algorithm is used.
  """

  def __init__(self, max_tokens, use_gpu=True):
    super().__init__()
    max_tokens = int(max_tokens)
    self.capacity = max_tokens + 1
    self.use_gpu = use_gpu
    if self.use_gpu:
      # TODO(deyuf): run some benchmark to make sure 32bit here does not cause problem
      self.count = tf.Variable(tf.zeros((max_tokens + 1,), tf.uint32), trainable=False)
      # Reserve the first index for the OOV token. The implementation doesn't require an OOV
      # key like -1. Alternatively, we could define an OOV key and insert it with value 0.
      self.count.scatter_update(tf.IndexedSlices(1, 0))
      # TODO: explore adjusting table size on the fly
      # TODO: the table is initialized on first lookup for now; the init op should be separated out
      # Initialize the table so it can be used across ops. Since all cuCollections needs is a
      # pointer, we don't need a "table op" like the native one to return a resource handle.
      # 1.5x load factor, 2x for interleaved keys + values
      self.table = tf.Variable(tf.zeros((2 * int(1.5 * self.capacity),), tf.int64),
                               trainable=False)
    else:
      with tf.device('/CPU'):
        self.table = tf.lookup.experimental.DenseHashTable(key_dtype=tf.int64,
                                                           value_dtype=tf.int64,
                                                           default_value=0,
                                                           empty_key=-2,
                                                           deleted_key=-3)
        # TODO(deyuf): benchmark code that handles max_tokens for the CPU table
        self.table.insert(-1, 0)
        self.num_empty_slot = max_tokens

  def call(self, inputs):
    if self.use_gpu:
      return embedding_lookup_ops.integer_lookup(self.table, self.count, inputs, self.capacity)
    # This is efficient on CPU, especially with power-law distributed data
    with tf.device('/CPU'):
      input_shape = tf.shape(inputs)
      inputs = tf.reshape(inputs, [-1])
      if self.num_empty_slot > 0:
        # Deduplicate the batch, find keys not yet in the table and assign them the next
        # contiguous ids while free slots remain
        keys, _ = tf.unique(inputs)
        vals = self.table.lookup(keys)
        new_keys = tf.gather(keys, tf.reshape(tf.where(vals <= 0), [-1]))[:self.num_empty_slot]
        num_insert = tf.shape(new_keys, out_type=tf.int64)[0]
        self.num_empty_slot -= num_insert
        self.table.insert(new_keys, tf.range(self.table.size(), self.table.size() + num_insert))
      return tf.reshape(self.table.lookup(inputs), input_shape)

  def get_vocabulary(self):
    # TODO: may need a new API as GPU lookup values may not be contiguous
    if self.use_gpu:
      # sort the indices returned by where so each key/value pair stays together
      used_ids = tf.sort(tf.reshape(tf.where(self.table != -1), [-1]))
      kv = tf.gather(self.table, used_ids)
      # split interleaved keys and values
      keys = kv[0::2]
      vals = kv[1::2]
      # filter out OOV keys that mapped to 0
      non_oov_ids = tf.sort(tf.reshape(tf.where(vals > 0), [-1]))
      keys = tf.gather(keys, non_oov_ids)
      vals = tf.gather(vals, non_oov_ids)
      # sort output into lookup-value order
      sort_order = tf.argsort(vals)
      keys = tf.gather(keys, sort_order)
      return [-1] + keys.numpy().tolist()
    keys, _ = self.table.export()
    unique_keys, _ = tf.unique(tf.reshape(keys, [-1]))
    # remove the reserved empty and deleted keys
    unique_keys = tf.gather(unique_keys, tf.reshape(tf.where(unique_keys != -2), [-1]))
    unique_keys = tf.gather(unique_keys, tf.reshape(tf.where(unique_keys != -3), [-1]))
    unique_vals = self.table.lookup(unique_keys)
    sort_order = tf.argsort(unique_vals)
    unique_keys = tf.gather(unique_keys, sort_order)
    return unique_keys.numpy().tolist()