dist_model_parallel
DistributedEmbedding
- class distributed_embeddings.python.layers.dist_model_parallel.DistributedEmbedding(*args, **kwargs)
Distributed embedding wrapper.
This class is a hybrid-parallel wrapper around embedding layers. It handles the all-to-all communication in both the forward and backward passes of the embeddings.
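The all-to-all exchange can be pictured with a small single-process NumPy simulation. This is a sketch, not the library's implementation. It assumes the input shapes described under dp_input, and it uses a hypothetical helper `simulate_all_to_all` and a hypothetical table placement in which worker 0 owns inputs 0 and 1 and worker 1 owns inputs 2 and 3. It models only the forward redistribution of inputs. A reverse exchange, not shown, would be needed to return embedding outputs to the data-parallel layout.

```python
import numpy as np


def simulate_all_to_all(dp_inputs, tables_per_worker):
    """Redistribute data-parallel inputs into model-parallel inputs.

    dp_inputs[w] has shape [local_batch_size, global_num_embeddings]:
    worker w holds its own part of the batch, one column per embedding input.
    tables_per_worker[w] lists the input columns owned by worker w (hypothetical).
    Returns mp_inputs[w] of shape [global_batch_size, len(tables_per_worker[w])].
    """
    mp_inputs = []
    for owned in tables_per_worker:
        # Each destination worker receives, from every source worker, the
        # columns for the tables it owns. Stacking them along the batch axis
        # yields the global batch for those tables.
        pieces = [src[:, owned] for src in dp_inputs]
        mp_inputs.append(np.concatenate(pieces, axis=0))
    return mp_inputs


num_workers, local_batch, num_inputs = 2, 3, 4
rng = np.random.default_rng(0)
dp_inputs = [rng.integers(0, 100, size=(local_batch, num_inputs)) for _ in range(num_workers)]
tables_per_worker = [[0, 1], [2, 3]]  # hypothetical placement of inputs on workers
mp_inputs = simulate_all_to_all(dp_inputs, tables_per_worker)
print([x.shape for x in mp_inputs])  # [(6, 2), (6, 2)]
```

Each worker goes from 3 rows of all 4 inputs to 6 rows (the global batch) of its own 2 inputs. This matches the [local_batch_size x global_num_embeddings] to [global_batch_size x local_num_embeddings] change described for dp_input.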
- Parameters:
embeddings (list of Keras Embedding layers) – Embedding tables to be distributed.
strategy (str) – A string indicating how embedding tables are distributed. Choices are [“basic”, “memory_balanced”]. Default “basic”.
column_slice_threshold (int or None) – If None, column slicing happens only when there are more workers than tables. In that case, column_slice_threshold is chosen automatically so that each worker receives at least one slice. If not None, embedding tables with more elements than column_slice_threshold are divided into N even slices along the embedding width dimension, where N is the smallest power of 2 that makes each slice smaller than column_slice_threshold. Default None.
row_slice_threshold – Embedding tables larger than this are evenly row-sliced across all workers.
dp_input (bool) – If True, take data-parallel input, i.e., of shape [local_batch_size x global_num_embeddings]. Otherwise, take model-parallel input of shape [global_batch_size x local_num_embeddings]. Default True.
input_table_map (list or None) – A list of the same length as the inputs that maps input[i] to table[input_table_map[i]]. None means there are the same number of inputs and tables, and input[i] maps to table[i]. Default None.
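The rule for an integer column_slice_threshold can be written out as a short function. This is a minimal sketch of the stated rule only, using a hypothetical helper `num_column_slices`. It does not model the automatic choice made when the threshold is None. It also does not check that the width divides evenly into N slices, which the real layer presumably has to handle.

```python
def num_column_slices(num_rows, width, column_slice_threshold):
    """Return N, the number of even column slices for a [num_rows x width] table.

    Hypothetical helper: tables with more elements than the threshold are
    split along the width into N pieces, where N is the smallest power of 2
    that makes each slice smaller than the threshold.
    """
    num_elements = num_rows * width
    if num_elements <= column_slice_threshold:
        return 1  # not larger than the threshold: no column slicing
    n = 1
    # Double N while a slice (num_elements / n) is still >= the threshold;
    # the comparison is rearranged to stay in integer arithmetic.
    while num_elements >= column_slice_threshold * n:
        n *= 2
    # Note: this sketch does not verify that width % n == 0 (or n <= width).
    return n


# A 1,000,000 x 128 table (128M elements) with a 40M-element threshold.
n = num_column_slices(1_000_000, 128, 40_000_000)
print(n, 128 // n)  # prints: 4 32
```

With N = 2 each slice would still hold 64M elements, so N doubles to 4, giving four slices of width 32 with 32M elements each.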
- get_weights(all_ranks=False)
Returns the current weights of the layer, as NumPy arrays.
This override outputs global weights for all tables.
- Parameters:
all_ranks (bool) – If True, return weights on all ranks; otherwise, only on rank 0. Default False.
- Returns:
result (list) – List of weight tensors.
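When a table has been column-sliced across workers, returning its global weights means joining the slices back into one array. The following is a minimal NumPy sketch of that reassembly, not the layer's actual code. It assumes a hypothetical shard layout of `(table_id, slice_id, array)` tuples that have already been gathered into one process. It uses a hypothetical helper `assemble_global_weights` and leaves out both the communication and row-sliced tables.

```python
import numpy as np


def assemble_global_weights(shards, num_tables):
    """Rebuild one full weight array per table from column-sliced shards.

    shards is a hypothetical list of (table_id, slice_id, array) tuples, as if
    gathered from every worker. All slices of a table share the same rows.
    """
    by_table = {t: [] for t in range(num_tables)}
    for table_id, slice_id, array in shards:
        by_table[table_id].append((slice_id, array))
    tables = []
    for t in range(num_tables):
        # Order the slices, then re-join them along the width (axis 1),
        # the dimension they were cut along.
        pieces = [array for _, array in sorted(by_table[t], key=lambda p: p[0])]
        tables.append(np.concatenate(pieces, axis=1))
    return tables


full = np.arange(24).reshape(4, 6)        # table 0, split into two column slices
left, right = np.split(full, 2, axis=1)   # each slice has shape (4, 3)
other = np.ones((5, 2))                   # table 1, not sliced
shards = [(0, 1, right), (1, 0, other), (0, 0, left)]  # arrival order is arbitrary
weights = assemble_global_weights(shards, num_tables=2)
print([w.shape for w in weights])  # [(4, 6), (5, 2)]
```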
- set_weights(weights, chunk=134217728, use_lock=False)
Sets the weights of the layer, from NumPy arrays.
- Parameters:
weights (list) – List containing global weights for all tables. Each item in the list can be either a NumPy array or a file path to load from.
chunk (int) – Maximum number of elements per chunk when setting weights on the GPU in chunks. This is rounded to a whole number of rows based on the weight shape. Default 134217728 (2^27).
use_lock (bool) – If True, set weights one rank at a time, in lock step, to avoid out-of-memory (OOM) errors. Default False.
- Raises:
ValueError – If the length of weights does not match the number of expected weights.
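The chunk parameter's rounding to whole rows can be sketched in NumPy. This assumes a 2-D weight whose second dimension is the embedding width, and it uses hypothetical helpers `row_chunks` and `copy_in_chunks`. A NumPy array stands in for the GPU variable, and the sketch does not show how the layer actually performs each copy.

```python
import numpy as np


def row_chunks(shape, chunk=134217728):
    """Yield (start, end) row ranges for a 2-D weight of the given shape.

    The element budget `chunk` is rounded down to whole rows. If a single row
    is wider than `chunk`, this sketch still uses one row per chunk (assumption).
    """
    num_rows, width = shape
    rows_per_chunk = max(1, chunk // width)
    for start in range(0, num_rows, rows_per_chunk):
        yield start, min(start + rows_per_chunk, num_rows)


def copy_in_chunks(dst, src, chunk=134217728):
    """Copy src into dst one row range at a time (dst stands in for a device buffer)."""
    for start, end in row_chunks(src.shape, chunk):
        dst[start:end] = src[start:end]


src = np.arange(40, dtype=np.float32).reshape(10, 4)
dst = np.zeros_like(src)
print(list(row_chunks(src.shape, chunk=10)))  # [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10)]
copy_in_chunks(dst, src, chunk=10)
```

With chunk=10 and a width of 4, each chunk is rounded down to 2 rows (8 elements), so no chunk exceeds the element budget.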
- distributed_embeddings.python.layers.dist_model_parallel.broadcast_variables(model_vars, root_rank=0)
Broadcasts variables from the root rank to all other processes in a process set.
Replaces Horovod’s broadcast_variables when running hybrid parallel.
See https://horovod.readthedocs.io/en/stable/api.html for more details.
- distributed_embeddings.python.layers.dist_model_parallel.DistributedGradientTape(*args, **kwargs)
Gradient tape that supports hybrid parallelism.
Replaces Horovod’s DistributedGradientTape when running hybrid parallel.
See https://horovod.readthedocs.io/en/stable/api.html for more details.