class distributed_embeddings.python.layers.dist_model_parallel.DistributedEmbedding(*args, **kwargs)[source]

Distributed embedding wrapper

This class is a hybrid parallel wrapper around embedding layers. It handles the all-to-all communication in the forward and backward passes of the embeddings.
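To illustrate what the all-to-all step does to the input layout, the sketch below simulates it with plain Python lists instead of real GPU collectives. It assumes the data-parallel and model-parallel shapes described for dp_input below, and it assumes a hypothetical round-robin placement of tables on ranks. The helpers all_to_all and dp_to_mp are illustrative names, not part of the library. Only the input redistribution of the forward pass is shown; the layer also performs all-to-all exchanges in the backward pass.

```python
from typing import List

Matrix = List[List[int]]


def all_to_all(send: List[List[Matrix]]) -> List[List[Matrix]]:
    """Simulated all-to-all: send[src][dst] is the chunk rank `src` sends to `dst`.

    Returns recv, where recv[dst][src] == send[src][dst].
    """
    world = len(send)
    return [[send[src][dst] for src in range(world)] for dst in range(world)]


def dp_to_mp(dp_inputs: List[Matrix], table_owner: List[int]) -> List[Matrix]:
    """Redistribute data-parallel input into model-parallel input (hypothetical helper).

    dp_inputs[r] is rank r's [local_batch_size x global_num_embeddings] input.
    table_owner[t] is the rank holding table t (an assumed placement).
    Returns, per rank, a [global_batch_size x local_num_embeddings] input.
    """
    world = len(dp_inputs)
    # Columns (tables) owned by each destination rank, in table order.
    owned = [[t for t, owner in enumerate(table_owner) if owner == dst] for dst in range(world)]
    # Each source rank slices out, for every destination, only that rank's columns.
    send = [
        [[[row[t] for t in owned[dst]] for row in dp_inputs[src]] for dst in range(world)]
        for src in range(world)
    ]
    recv = all_to_all(send)
    # Stack the chunks from all source ranks along the batch dimension.
    return [[row for chunk in recv[dst] for row in chunk] for dst in range(world)]


# Two ranks, local batch of 2, four tables placed round-robin (t % 2).
dp = [[[0, 1, 2, 3], [10, 11, 12, 13]],
      [[20, 21, 22, 23], [30, 31, 32, 33]]]
mp = dp_to_mp(dp, table_owner=[0, 1, 0, 1])
print(mp[0])  # [[0, 2], [10, 12], [20, 22], [30, 32]]
```

Each rank starts with its own slice of the batch and all tables, and ends with the whole global batch but only the columns for the tables it owns.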

Parameters:

  • embeddings (list of Keras Embedding layers) – Embedding tables to be distributed.

  • strategy (str) – A string indicating how embedding tables are distributed. Choices are [“basic”, “memory_balanced”]. Default “basic”.

  • column_slice_threshold (int or None) – If not None, embedding tables with more elements than column_slice_threshold will be divided into N even pieces along the embedding width dimension, where N is the smallest power of 2 that makes each slice smaller than column_slice_threshold. Default None.

  • row_slice (TBD) – Describes which embeddings need to be row sliced.

  • dp_input (bool) – If True, takes data-parallel input, i.e. in shape [local_batch_size x global_num_embeddings]. Otherwise, takes model-parallel input in shape [global_batch_size x local_num_embeddings]. Default True.

  • input_table_map (list or None) – A list of the same length as the inputs that maps input[i] to table[input_table_map[i]]. None means there are the same number of inputs and tables and input[i] maps to table[i]. Default None.
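The column_slice_threshold rule above can be made concrete with a few lines of arithmetic. The sketch below is a pure-Python reading of that wording, not the library's code: num_column_slices is a hypothetical helper, it treats "smaller than" as strict, and it assumes the table width is divisible by the resulting N (the "even pieces" case), which it does not check.

```python
from typing import Optional


def num_column_slices(rows: int, width: int, threshold: Optional[int]) -> int:
    """Number of column slices N for a [rows x width] table (hypothetical helper).

    Tables with more than `threshold` elements are split along the width into
    N even pieces, where N is the smallest power of 2 that makes each slice
    strictly smaller than `threshold`. Assumes `width` is divisible by N.
    """
    elements = rows * width
    if threshold is None or elements <= threshold:
        return 1
    n = 1
    # Double N until each slice (elements / N) is strictly below the threshold.
    while elements / n >= threshold:
        n *= 2
    return n


t = 2 ** 25  # 33,554,432 elements (an arbitrary example threshold)
n = num_column_slices(1_000_000, 128, t)
print(n, 128 // n)  # 4 32
```

Here a 1,000,000 x 128 table (128M elements) is split into 4 slices of width 32, each holding 32M elements, just under the example threshold.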


get_weights(all_ranks=False)[source]

Returns the current weights of the layer, as NumPy arrays.

This override outputs global weights for all tables.

Parameters:

  • all_ranks (bool) – If True, return weights on all ranks; otherwise, only on rank 0. Default False.

Returns:

result (list) – List of weight tensors.
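The text above does not say how column-sliced tables are turned back into global weights. Assuming a table was split along its width as described for column_slice_threshold, its global weight is the left-to-right concatenation of its slices. The sketch below shows only that reassembly step with a hypothetical helper, concat_column_slices; the layer's actual gathering logic, which also moves data across ranks, is not shown.

```python
from typing import List

Matrix = List[List[float]]


def concat_column_slices(slices: List[Matrix]) -> Matrix:
    """Rebuild a [rows x width] table from column slices given in column order.

    Hypothetical helper for illustration only.
    """
    rows = len(slices[0])
    if any(len(s) != rows for s in slices):
        raise ValueError("all column slices must have the same number of rows")
    # Row r of the full table is row r of every slice, joined left to right.
    return [[v for s in slices for v in s[r]] for r in range(rows)]


left = [[1.0, 2.0], [5.0, 6.0]]   # columns 0-1
right = [[3.0, 4.0], [7.0, 8.0]]  # columns 2-3
print(concat_column_slices([left, right]))  # [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
```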

set_weights(weights, chunk=134217728, use_lock=False)[source]

Sets the weights of the layer, from NumPy arrays.

Parameters:

  • weights (list) – List containing global weights for all tables. Each item in the list can be either a NumPy array or a file path to load from.

  • chunk (int) – Maximum number of elements per chunk when setting weights on the GPU in chunks. This is rounded to a whole number of rows based on the weight shape. Default 134217728.

  • use_lock (bool) – If True, set weights rank by rank in lockstep to avoid out-of-memory (OOM) errors. Default False.


Raises:

ValueError – If the length of weights does not match the number of expected weights.
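The chunk parameter is described as an element budget rounded to whole rows. A minimal pure-Python sketch of that rounding, assuming rows per chunk is chunk // width with a floor of one row (the floor is an assumption, not stated above), is shown below; row_chunks is a hypothetical helper that returns the [start, end) row ranges a chunked copy would visit.

```python
from typing import List, Tuple


def row_chunks(num_rows: int, width: int, chunk: int = 134217728) -> List[Tuple[int, int]]:
    """Split a [num_rows x width] table into [start, end) row ranges (hypothetical helper).

    Each range holds at most `chunk` elements, rounded down to whole rows.
    At least one row per range is assumed even if width > chunk.
    """
    rows_per_chunk = max(1, chunk // width)
    return [(start, min(start + rows_per_chunk, num_rows))
            for start in range(0, num_rows, rows_per_chunk)]


# With the default chunk of 2**27 elements, a 64-wide table gets
# 2,097,152 rows per chunk, so 10M rows need 5 chunks.
ranges = row_chunks(10_000_000, 64)
print(len(ranges), ranges[-1])  # 5 (8388608, 10000000)
```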

distributed_embeddings.python.layers.dist_model_parallel.broadcast_variables(model_vars, root_rank=0)[source]

Broadcasts variables from the root rank to all other processes in a process set.

Replaces Horovod’s broadcast_variables when running hybrid parallel training.

See https://horovod.readthedocs.io/en/stable/api.html for more details
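The reference does not spell out why a hybrid-parallel replacement is needed. One plausible reason, assumed here and not confirmed by the text, is that sharded embedding variables differ on every rank and must not be overwritten, while replicated (data-parallel) variables should take the root rank's values. The sketch below simulates that behavior with dictionaries standing in for per-rank variables; hybrid_broadcast and the variable names are hypothetical.

```python
from typing import Dict, List, Set

Vars = Dict[str, object]


def hybrid_broadcast(rank_vars: List[Vars], model_parallel: Set[str], root_rank: int = 0) -> List[Vars]:
    """Simulate broadcasting replicated variables from `root_rank` (hypothetical helper).

    Variables named in `model_parallel` are sharded (different on every rank)
    and are left untouched; all other variables take the root rank's value.
    """
    root = rank_vars[root_rank]
    result = []
    for local in rank_vars:
        updated = dict(local)
        for name, value in root.items():
            if name not in model_parallel:
                updated[name] = value
        result.append(updated)
    return result


ranks = [{"dense": 1.0, "emb_shard": "tables 0, 2"},
         {"dense": 7.0, "emb_shard": "tables 1, 3"}]
synced = hybrid_broadcast(ranks, model_parallel={"emb_shard"})
print(synced[1])  # {'dense': 1.0, 'emb_shard': 'tables 1, 3'}
```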

distributed_embeddings.python.layers.dist_model_parallel.DistributedGradientTape(*args, **kwargs)[source]

Gradient tape that supports hybrid parallelism.

Replaces Horovod’s DistributedGradientTape when running hybrid parallel training.

See https://horovod.readthedocs.io/en/stable/api.html for more details
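As with broadcast_variables, the reference does not state what differs from Horovod's tape. A plausible reading, assumed here rather than taken from the text, is that gradients of replicated variables are averaged across ranks (Horovod's default allreduce behavior), while gradients of sharded embedding variables stay on the rank that owns the shard. The sketch below simulates that split with scalar gradients; hybrid_reduce_gradients and the variable names are hypothetical.

```python
from typing import Dict, List, Set

Grads = Dict[str, float]


def hybrid_reduce_gradients(rank_grads: List[Grads], model_parallel: Set[str]) -> List[Grads]:
    """Simulate gradient reduction in hybrid parallel training (hypothetical helper).

    Gradients of replicated (data-parallel) variables are averaged across
    ranks, like an averaging allreduce; gradients of sharded (model-parallel)
    variables are left as computed on each rank.
    """
    world = len(rank_grads)
    averaged = {
        name: sum(g[name] for g in rank_grads) / world
        for name in rank_grads[0]
        if name not in model_parallel
    }
    # Overwrite only the replicated entries on every rank.
    return [{**g, **averaged} for g in rank_grads]


grads = [{"dense": 1.0, "emb_shard": 0.5},
         {"dense": 3.0, "emb_shard": -2.0}]
reduced = hybrid_reduce_gradients(grads, model_parallel={"emb_shard"})
print(reduced)  # [{'dense': 2.0, 'emb_shard': 0.5}, {'dense': 2.0, 'emb_shard': -2.0}]
```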