Source code for distributed_embeddings.python.layers.dist_model_parallel

# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Distributed Embedding layers and utils"""
import types
import math
import numpy as np
import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.python.keras.utils import tf_utils
import horovod
import horovod.tensorflow as hvd
import horovod.tensorflow.keras as hvd_keras
from distributed_embeddings.python.ops.embedding_lookup_ops import read_var_no_copy
from .embedding import Embedding


class ConcatInitializer(tf.keras.initializers.Initializer):
  """ initializer wrapper to handle automatic concat table on first dimension
  """

  def __init__(self, initializer, sizes):
    self._initializer = initializer
    self.sizes = sizes

  def __call__(self, shape, **kwargs):
    weights = [self._initializer([size, shape[1]], **kwargs) for size in self.sizes]
    weights = tf.concat(weights, axis=0)
    return weights
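
  # Illustrative sketch (not part of the original source): how ConcatInitializer fills a
  # concatenated table from per-table initializations. Assuming two tables of 10 and 6 rows,
  # both 8 columns wide, concatenated on the first dimension:
  #
  #   init = ConcatInitializer(tf.keras.initializers.GlorotUniform(), sizes=[10, 6])
  #   weights = init([16, 8])  # shape [16, 8]; rows 0:10 and 10:16 are initialized as if
  #                            # they were standalone [10, 8] and [6, 8] tables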


def _get_shape(tensor):
  """Return shape of tensor

  Static shape is not always available, in which case we use tf.shape to get dynamic shape

  Args:
      tensor (Tensor): Input tensor

  Returns:
      tf.TensorShape or tf.Tensor: static shape if fully defined, otherwise dynamic shape
  """
  if tensor.shape is not None and None not in tensor.shape:
    return tensor.shape
  return tf.shape(tensor)


@tf.custom_gradient
def grouped_reducescatter_unscaled(inputs):
  outputs = hvd.grouped_reducescatter(inputs, op=hvd.Sum)

  def grad(*upstream):
    return hvd.grouped_allgather(upstream)

  return outputs, grad


class DistEmbeddingStrategy():
  """Distributed embedding strategy

  Args:
    embeddings (list of Embedding): list of unbuilt Embedding layers globally
    world_size (int): number of model parallel workers.
    strategy (str): A string that indicates how embedding tables are distributed.
        Choices are ["basic", "memory_balanced", "memory_optimized"]. Default "basic".
    input_table_map (list or None): A list of table ids mapping each input to a table, i.e.,
        `input[i]` maps to `table[input_table_map[i]]`. None means there are the same number of
        inputs and tables and `input[i]` maps to `table[i]`. Default None.
    column_slice_threshold (int or None): desired upper bound on the element count in each
        column slice. Default None.
    row_slice_threshold (int or None): tables with more elements than this are row sliced onto
        all workers. Default None.
    data_parallel_threshold (int or None): tables with no more elements than this run data
        parallel. Default None.

  Attributes:
    strategy: string indicating how embedding tables are distributed.
    column_slice_threshold: desired upper bound on the element count in each column slice.
    row_slice_threshold: desired lower bound; tables larger than this are row sliced.
    sliced_out_ranges: list of [start, end) output ranges used to merge column-sliced outputs.
    table_groups: lists of table ids. group 0 runs data parallel, group 1 runs column slice with
                  table parallel, group 2 runs row slice across all workers.
    input_groups: input ids corresponding to the table grouping.
    map_groups: input-to-table map within each group.
    rev_group_ids: list used to reorder grouped outputs back into flat input order.
    row_sliced_configs: configs that have been row sliced.
    row_inputs_offsets: offsets added to row slice inputs so that indices outside the local
                        shard's range go out of bounds (OOB).
    input_ids_list: nested list containing input index lists in rank order. Used when
                    dp_input == False.
    local_maps: nested list containing per-rank local input-to-table maps.
    local_configs: nested list containing per-rank lists of local table configs.
    local_input_offsets: nested list containing per-rank offsets used to form inputs for concat
                         embedding.
    local_weight_offsets: nested list containing per-rank internal weight offsets for
                          get/set_weights.
    local_group_list: nested list containing per-rank concat grouping info for get/set_weights.
    table_ids: nested list containing per-rank table ids for get/set_weights.
    widths_list_flat: list of all output widths, before merging slices, in worker order.
    rev_tp_ids: list used to reorder table parallel outputs back into flat tp input order.
  """

  def __init__(self,
               embeddings,
               world_size,
               strategy="basic",
               input_table_map=None,
               column_slice_threshold=None,
               row_slice_threshold=None,
               data_parallel_threshold=None):
    # code in DMP that skips hvd calls in the single process case may assume "basic"
    self.strategy = "basic" if world_size == 1 else strategy
    # column_slice can be used to enable more table concat, so keep it in single process
    self.column_slice_threshold = column_slice_threshold
    self.row_slice_threshold = row_slice_threshold
    self.data_parallel_threshold = data_parallel_threshold
    self.global_configs = [e.get_config() for e in embeddings]
    # Insert layer type information to config dicts
    for config, embedding in zip(self.global_configs, embeddings):
      config['layer_type'] = type(embedding)
    if input_table_map is None:
      input_table_map = list(range(len(embeddings)))

    # separate table ids into groups for different strategies
    self.table_groups = self.init_table_groups(self.global_configs)
    # input ids and maps. rev_group_ids is used to restore grouped outputs back to input order
    self.input_groups, self.map_groups, self.rev_group_ids = self.init_input_and_map_groups(
        self.table_groups, input_table_map)

    # 1. handle data parallel
    self.dp_configs = [self.global_configs[idx] for idx in self.table_groups[0]]

    # 2. handle row slicing
    if self.table_groups[2]:
      self.row_sliced_configs, self.row_inputs_offsets = self.create_row_sliced_configs(
          [self.global_configs[idx] for idx in self.table_groups[2]], world_size)

    # 3. handle column slicing and table parallel
    if not self.table_groups[1]:
      return
    # Create (maybe) sliced configs
    sliced_configs, self.sliced_out_ranges = self.create_col_sliced_configs(
        [self.global_configs[idx] for idx in self.table_groups[1]], world_size,
        self.column_slice_threshold, self.map_groups[1])

    # Apply strategy and save nested list containing table indices by rank
    table_ids = self.apply_strategy(self.strategy, world_size, sliced_configs)

    # Following are ALL nested lists holding info for distributing embeddings, ordered by rank
    self.input_ids_list = []
    self.local_maps = []
    self.local_configs = []
    self.local_input_offsets = []
    self.local_weight_offsets = []
    self.local_group_list = []
    self.table_ids = []

    # Each worker loops over all ranks to get a global view of the strategy
    for rank_table_ids in table_ids:
      # first merge different shards of the same table that end up on the same rank
      rank_table_ids, rank_configs = self._merge_slices(rank_table_ids, sliced_configs)
      self.table_ids.append(rank_table_ids)

      # calculate local input ids and map from this rank's table_ids and global input map
      rank_input_ids, rank_input_map = [], []
      for m, table_idx in enumerate(rank_table_ids):
        for k, mapped_idx in enumerate(self.map_groups[1]):
          if table_idx == mapped_idx:
            rank_input_ids.append(k)
            rank_input_map.append(m)

      # concat eligible tables then adjust local config and map
      rank_configs, rank_input_map, input_offsets, group, weight_offsets = self._create_concat(
          rank_configs, rank_input_map)

      # save results to global nested list
      self.input_ids_list.append(rank_input_ids)
      self.local_configs.append(rank_configs)
      self.local_maps.append(rank_input_map)
      self.local_input_offsets.append(input_offsets)
      self.local_group_list.append(group)
      self.local_weight_offsets.append(weight_offsets)

    # create a flat list containing table widths, in worker order, used to slice after alltoall
    # This is fast, but we may stop relying on it in order to support non-2D output (no local combiner)
    self.widths_list_flat = []
    for config, input_map in zip(self.local_configs, self.local_maps):
      self.widths_list_flat += [config[m]['output_dim'] for m in input_map]

    # List of indices to shuffle worker ordered embedding outputs back to original order
    worker_order_input_ids = [item for sublist in self.input_ids_list for item in sublist]
    self.rev_tp_ids = [
        index
        for _, index in sorted(zip(worker_order_input_ids, range(len(worker_order_input_ids))))
    ]
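    # Illustrative note (not in the original source): if worker_order_input_ids is
    # [2, 0, 3, 1] (inputs grouped by owning rank), rev_tp_ids becomes [1, 3, 0, 2],
    # so indexing worker-ordered outputs with rev_tp_ids restores the original flat
    # input order.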

  # below are the methods that divide tables into groups and adjust inputs/input maps accordingly
  def init_table_groups(self, configs):
    # We want to support data parallel, table parallel, column slice and row slice
    # A couple of assumptions:
    # - strategies are applied by size in the above order: small tables run data parallel and
    #   large tables are row sliced
    # - currently only one of the above is applied to any table; mixing row/column slice may make
    #   sense in the future
    # - because the communication pattern differs, we run 3 separate grouped calls
    # - non-symmetric table parallel is only applied to the column sliced group
    num_elems = [config['input_dim'] * config['output_dim'] for config in configs]
    dp, col, row = [], [], []
    for i, num_elem in enumerate(num_elems):
      if self.data_parallel_threshold and num_elem <= self.data_parallel_threshold:
        dp.append(i)
      elif self.row_slice_threshold and num_elem >= self.row_slice_threshold:
        row.append(i)
      else:
        col.append(i)
    return [dp, col, row]
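    # Worked example (illustrative, assuming the thresholds below): three tables with
    # (input_dim, output_dim) of (1_000, 16), (100_000, 64) and (10_000_000, 128), with
    # data_parallel_threshold=100_000 and row_slice_threshold=1_000_000_000.
    # Table 0 (16K elements) goes data parallel, table 2 (1.28B elements) is row sliced,
    # and table 1 falls through to the column-slice/table-parallel group:
    # the returned groups are [[0], [1], [2]].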

  def init_input_and_map_groups(self, table_groups, input_table_map):
    dp, col, row = table_groups
    # pick out inputs for each group
    dp_in, col_in, row_in = [], [], []
    dp_map, col_map, row_map = [], [], []
    for i, idx in enumerate(input_table_map):
      if idx in dp:
        dp_in.append(i)
        dp_map.append(dp.index(idx))
      elif idx in col:
        col_in.append(i)
        col_map.append(col.index(idx))
      elif idx in row:
        row_in.append(i)
        row_map.append(row.index(idx))
      else:
        raise ValueError("Wrong input initializing input/map groups.")
    flat_input_ids = dp_in + col_in + row_in
    reverse_ids = [index for _, index in sorted(zip(flat_input_ids, range(len(flat_input_ids))))]
    return [dp_in, col_in, row_in], [dp_map, col_map, row_map], reverse_ids

  def maybe_slice_table_column(self, orig_config, column_slice_threshold, world_size):
    """Column slice a embedding config if size exceed column_slice_threshold.
    Assume N is smallest power of 2 so that when evenly slice original table into N tables,
    each have less than column_slice_threshold elements.
    So final number of slices will be min(N, world_size, table_width).
    Args:
      orig_config (dict): embedding layer config to create slices from
      column_slice_threshold (int or None): desired upper bound of elements count in each slice
      world_size (int): number of total model parallel worker
    Returns:
      sliced_config (list): list of embedding layer config that concat into original config
    """
    if column_slice_threshold is None:
      column_slice_threshold = float('inf')
    table_size = orig_config['input_dim'] * orig_config['output_dim']
    num_slices = 1
    while table_size > column_slice_threshold:
      num_slices *= 2
      table_size /= 2
    if num_slices == 1:
      return [orig_config.copy()]
    num_slices = min(num_slices, world_size, orig_config['output_dim'])
    column_per_slice = orig_config['output_dim'] // num_slices
    remainder = orig_config['output_dim'] % num_slices
    sliced_config = []
    for i in range(num_slices):
      config = orig_config.copy()
      config['output_dim'] = column_per_slice
      if i < remainder:
        config['output_dim'] += 1
      sliced_config.append(config)
    return sliced_config
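    # Worked example (illustrative): a 2_000_000 x 16 table with
    # column_slice_threshold=10_000_000 holds 32M elements, so num_slices doubles to 4
    # (32M -> 16M -> 8M <= 10M). With world_size=8 and output_dim=16, min(4, 8, 16) = 4
    # slices are produced, with widths [4, 4, 4, 4].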

  def create_col_sliced_configs(self, global_col_configs, world_size, column_slice_threshold,
                                input_table_map):
    """Create column sliced configs from global configs.
    This function also calculates the ranges of data parallel outputs that need to be
    concatenated due to this slicing.
    Args:
      global_col_configs (list): selected configs for doing column slice
      world_size (int): number of model parallel workers
      column_slice_threshold (int or None): desired upper bound of elements count in each slice
      input_table_map (list): A list of table ids mapping each input to a table
    Returns:
      sliced_configs (list): same length as global configs. Each element is a list representing
    the sliced form of the global config at the same position.
      sliced_out_ranges (list): each element is a [start, end) pair of output positions that need
    to be concatenated to re-form the output due to the above slicing.
    """
    # fewer tables than workers: try our best to slice into worker-count slices (may go over)
    if column_slice_threshold is None:
      table_sizes = [config['input_dim'] * config['output_dim'] for config in global_col_configs]
      while world_size > len(table_sizes):
        table_sizes.sort()
        column_slice_threshold = table_sizes[-1] - 1
        cur_max_size = table_sizes.pop(-1)
        table_sizes += [cur_max_size // 2, cur_max_size // 2]

    sliced_configs = []
    for col_config in global_col_configs:
      maybe_sliced_config = self.maybe_slice_table_column(col_config, column_slice_threshold,
                                                          world_size)
      sliced_configs.append(maybe_sliced_config)
    # figure out ranges of output that needs concat
    # this needs to be in output order, otherwise range modification would fail
    sliced_out_ranges = []
    for input_id, table_id in enumerate(input_table_map):
      if len(sliced_configs[table_id]) > 1:
        sliced_out_ranges.append([input_id, input_id + len(sliced_configs[table_id])])
    return sliced_configs, sliced_out_ranges
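    # Illustrative note: with input_table_map=[0, 1, 2] and table 1 sliced into 2 pieces,
    # sliced_out_ranges becomes [[1, 3]], i.e. outputs at positions 1 and 2 (after
    # reordering to input order) are concatenated back into a single embedding output.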

  def create_row_sliced_configs(self, global_row_configs, world_size):
    # initial test code; corner cases are not yet considered
    sliced_configs, offsets = [], []
    for orig_config in global_row_configs:
      sliced_config, offset = [], []
      cur_offset = 0
      row_per_slice = orig_config['input_dim'] // world_size
      remainder = orig_config['input_dim'] % world_size
      for i in range(world_size):
        config = orig_config.copy()
        config['input_dim'] = row_per_slice
        if i < remainder:
          config['input_dim'] += 1
        sliced_config.append(config)
        offset.append(cur_offset)
        cur_offset -= config['input_dim']
      sliced_configs.append(sliced_config)
      offsets.append(offset)
    # re-divide lists by rank
    sliced_configs = [list(rank_configs) for rank_configs in zip(*sliced_configs)]
    offsets = [list(rank_offsets) for rank_offsets in zip(*offsets)]
    return sliced_configs, offsets
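    # Worked example (illustrative): a table with input_dim=10 on world_size=4 is split
    # into row counts [3, 3, 2, 2] with per-rank input offsets [0, -3, -6, -8], so after
    # adding the offset only indices owned by the local shard fall inside
    # [0, local_input_dim) and all other lookups go out of bounds.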

  # pylint: disable=missing-param-doc,missing-type-doc,missing-raises-doc
  def apply_strategy(self, mode, world_size, sliced_configs):
    """Distribute tables to workers from sliced configs, a nested list.
    Returns:
      divided_ids (list): world_size-length list. Each element is the list of
    sliced table ids distributed to the rank at the corresponding position.
    """
    global_ids = []
    table_sizes = []
    for i, sliced_config in enumerate(sliced_configs):
      for config in sliced_config:
        global_ids.append(i)
        table_sizes.append(config['input_dim'] * config['output_dim'])

    # Round-robin distribute tables onto workers
    if mode == 'basic':
      divided_ids = [global_ids[i::world_size] for i in range(world_size)]
    # Distribute tables so that memory is balanced while table counts remain even
    elif mode == 'memory_balanced':
      sorted_ids = [idx for _, idx in sorted(zip(table_sizes, global_ids), reverse=True)]
      divided_ids = [
          sorted_ids[i::2 * world_size] + sorted_ids[(2 * world_size - 1 - i)::2 * world_size]
          for i in range(world_size)
      ]
    # Try to optimize total memory first. After sorting by size, tables are distributed one by one
    # to the worker with the lowest total size. Memory usage is more even but table counts may not be.
    elif mode == 'memory_optimized':
      sorted_pairs = list(sorted(zip(table_sizes, global_ids)))
      res = [[0, []] for _ in range(world_size)]
      while sorted_pairs:
        cur = sorted_pairs.pop()
        res[0][0] += cur[0]
        res[0][1].append(cur[1])
        res = sorted(res)
      divided_ids = [r[1] for r in res]
    else:
      raise ValueError(F"Unsupported strategy {strategy}")
    return divided_ids
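    # Illustrative example: 4 sliced tables with sizes [8, 4, 2, 1] on world_size=2.
    #   'basic'            -> round robin: [[table 0, table 2], [table 1, table 3]]
    #   'memory_balanced'  -> sort by size and pair largest with smallest:
    #                         [[table 0, table 3], [table 1, table 2]]
    #   'memory_optimized' -> greedy on running totals: [[table 1, table 2, table 3],
    #                         [table 0]] (total sizes 7 vs 8)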

  # Concat tables so they become one shared embedding. XLA does the rest of the optimization.
  def _create_concat(self, table_configs, input_maps):
    # first get local table id into groups
    grouped_table_ids, concat_configs = [], []
    for table_id, config in enumerate(table_configs):
      for group, concat_config in zip(grouped_table_ids, concat_configs):
        if config['output_dim'] == concat_config['output_dim'] and config.get(
            'combiner') == concat_config.get('combiner'):
          group.append(table_id)
          concat_config['input_dim'] += config['input_dim']
          concat_config['input_dims'].append(config['input_dim'])
          concat_config['offsets'].append(concat_config['offsets'][-1] + config['input_dim'])
          break
      else:  # can't merge with any group, create a new one
        grouped_table_ids.append([table_id])
        config['input_dims'] = [config['input_dim']]
        config['offsets'] = [0, config['input_dim']]
        concat_configs.append(config)

    # adjust input map and create according offset map
    new_input_map, input_offsets = [], []
    for input_map in input_maps:
      for gid, (group, concat_config) in enumerate(zip(grouped_table_ids, concat_configs)):
        if input_map in group:
          new_input_map.append(gid)
          input_offsets.append(concat_config['offsets'][group.index(input_map)])
          break

    # switch to concat initializer to keep behavior associated with shape
    for concat_config in concat_configs:
      input_dims = concat_config.pop('input_dims')
      if len(input_dims) > 1:
        # TODO(deyuf): we don't really need serialize and can just get from original class
        if 'embeddings_initializer' in concat_config:
          orig_initializer = initializers.deserialize(concat_config['embeddings_initializer'])
          concat_config['embeddings_initializer'] = ConcatInitializer(orig_initializer, input_dims)

    # record weight offsets for get/set.
    weight_offsets = [concat_config.pop('offsets', None) for concat_config in concat_configs]
    return concat_configs, new_input_map, input_offsets, grouped_table_ids, weight_offsets
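    # Illustrative example: two local tables with output_dim=16 and the same combiner and
    # input_dims [100, 50] merge into one concatenated table with input_dim=150. Inputs
    # that used to address table 1 get offset 100 added so they index the second block of
    # rows, and weight_offsets [0, 100, 150] let get/set_weights split the concatenated
    # variable back into the original per-table shapes.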

  # Helper function to re-merge slices of the same table in case they end up on the same worker
  def _merge_slices(self, rank_table_ids, sliced_configs):
    merged_table_ids, rank_configs = [], []
    for table_idx in rank_table_ids:
      # this id has been seen on this rank before, merge it with earlier shard
      if table_idx in merged_table_ids:
        config_to_merge = sliced_configs[table_idx].pop(0)
        index_to_merge = merged_table_ids.index(table_idx)
        rank_configs[index_to_merge]['output_dim'] += config_to_merge['output_dim']
        # modify output concat ranges
        for out_range in self.sliced_out_ranges:
          if out_range[0] == table_idx:
            out_range[-1] -= 1
      else:
        merged_table_ids.append(table_idx)
        rank_configs.append(sliced_configs[table_idx].pop(0))
    return merged_table_ids, rank_configs


class DistributedEmbedding(tf.keras.layers.Layer):
  """Distributed embedding wrapper

  This class is a hybrid parallel wrapper around embedding. It handles the all-to-all
  communication of the embedding forward and backward passes.

  Args:
    embeddings (list of keras Embedding layers): embedding tables to be distributed
    strategy (str): A string that indicates how embedding tables are distributed.
        Choices are ["basic", "memory_balanced", "memory_optimized"]. Default "basic".
    column_slice_threshold (int or None): If None, column slicing only happens when there are
        more workers than tables, in which case column_slice_threshold is chosen automatically
        so that each worker receives at least one slice. If not None, embedding tables with
        more elements than column_slice_threshold are divided into N even pieces along the
        embedded width dimension, where N is the smallest power of 2 that makes each slice
        smaller than column_slice_threshold. Default None.
    row_slice_threshold (int or None): Embeddings larger than this are evenly row sliced onto
        all workers.
    dp_input (bool): If True, takes data parallel input, i.e. in shape
        [local_batch_size x global_num_embeddings]. Otherwise takes model parallel input in
        shape [global_batch_size x local_num_embeddings]. Default True.
    input_table_map (list or None): A list with the same length as inputs, mapping `input[i]`
        to `table[input_table_map[i]]`. None means there are the same number of inputs and
        tables and `input[i]` maps to `table[i]`. Default None.
    data_parallel_threshold (int or None): Embeddings with no more elements than this run data
        parallel. Default None.
  """

  def __init__(self,
               embeddings,
               strategy="basic",
               column_slice_threshold=None,
               row_slice_threshold=None,
               dp_input=True,
               input_table_map=None,
               data_parallel_threshold=None,
               **kwargs):
    super().__init__(**kwargs)
    if strategy not in ['basic', 'memory_balanced', 'memory_optimized']:
      raise ValueError(F"Unsupported shard strategy {strategy}")

    # Currently assume data parallel ranks == model parallel ranks
    # TODO(deyuf): add more control over this with newly added hvd process_set api
    if not hvd.is_initialized():
      hvd.init()
    self.world_size = hvd.size()
    self.rank = hvd.rank()

    # single worker case falls back to no dp and no row slice.
    # ideally we could fall back to dp, but do mp for mp_input backward compatibility
    self.dp_input = dp_input
    self.column_slice_threshold = column_slice_threshold
    if self.world_size > 1:
      self.row_slice_threshold = row_slice_threshold if dp_input else None
      self.data_parallel_threshold = data_parallel_threshold if dp_input else None
    else:
      self.row_slice_threshold = None
      self.data_parallel_threshold = None

    # get model parallel distribution strategy
    self.strategy = DistEmbeddingStrategy(embeddings,
                                          self.world_size,
                                          strategy,
                                          input_table_map=input_table_map,
                                          column_slice_threshold=column_slice_threshold,
                                          row_slice_threshold=self.row_slice_threshold,
                                          data_parallel_threshold=self.data_parallel_threshold)

    # Here we make sure empty lists exist
    # create data parallel layers
    self.dp_layers = []
    if self.strategy.table_groups[0]:
      for config in self.strategy.dp_configs:
        self.dp_layers.append(self._create_layer_from_config(config))

    # create (maybe) column sliced embeddings and table parallel.
    self.local_embedding_layers = []
    self.col_inputs_offsets = []
    if self.strategy.table_groups[1]:
      # Handle explicit threshold or corner cases, in which a worker may receive no configs
      # Column slice still needs to span all gpus, otherwise alltoall fails
      if not all(rank_configs for rank_configs in self.strategy.local_configs):
        raise ValueError("Not enough tables after slicing to run on all workers. "
                         "Try decreasing column_slice_threshold or the worker count")
      for config in self.strategy.local_configs[self.rank]:
        self.local_embedding_layers.append(self._create_layer_from_config(config))
      self.col_inputs_offsets = [
          None if offset == 0 else tf.constant([offset], dtype=tf.int64)
          for offset in self.strategy.local_input_offsets[self.rank]
      ]

    # create row sliced embeddings.
    self.row_layers = []
    self.row_inputs_offsets = []
    if self.strategy.table_groups[2]:
      for config in self.strategy.row_sliced_configs[self.rank]:
        self.row_layers.append(self._create_layer_from_config(config))
      self.row_inputs_offsets = [
          None if offset == 0 else tf.constant([offset], dtype=tf.int64)
          for offset in self.strategy.row_inputs_offsets[self.rank]
      ]

  def _create_layer_from_config(self, config):
    # For stock keras Embedding, we switch to our underlying layer for better performance
    # If inputs are custom layers, the original layer will be used
    layer_type = config.pop('layer_type')
    layer_type = Embedding if layer_type == tf.keras.layers.Embedding else layer_type
    return layer_type.from_config(config)

  def _call_data_parallel(self, inputs):
    outputs = [self.dp_layers[m](inp) for m, inp in zip(self.strategy.map_groups[0], inputs)]
    return outputs

  def _call_table_parallel(self, inputs):  # pylint: disable=missing-param-doc,missing-type-doc
    """Call function that does embeddings and communication

    Currently, it requires the same batch_size on all workers.
    """
    # get model parallel input from data parallel
    if self.dp_input:
      if self.world_size > 1:
        comm_dtype = tf.int32
        for inp in inputs:
          if inp.dtype == tf.int64:
            comm_dtype = tf.int64
        inputs = [tf.cast(inp, comm_dtype) for inp in inputs]
        local_shapes, local_splits, global_splits, flat_inputs = [], [], [], []
        for rank_input_ids in self.strategy.input_ids_list:
          rank_inputs = [inputs[index] for index in rank_input_ids]
          local_shapes.append([_get_shape(inp) for inp in rank_inputs])
          rank_inputs = [tf.reshape(inp, [-1]) for inp in rank_inputs]
          local_splits.append([_get_shape(inp)[0] for inp in rank_inputs])
          global_splits.append(sum(local_splits[-1]))
          flat_inputs += rank_inputs
        inputs = tf.concat(flat_inputs, 0)
        inputs, _ = hvd.alltoall(inputs, splits=global_splits, name='inp_dp_to_mp')
        inputs = tf.reshape(inputs, [self.world_size, -1])
        inputs = tf.split(inputs, local_splits[self.rank], 1)
        inputs = [
            tf.reshape(inp, [self.world_size * shape[0]] + shape[1:])
            for inp, shape in zip(inputs, local_shapes[self.rank])
        ]
      else:
        # expected input order may still change in case of single process
        inputs = [inputs[idx] for idx in self.strategy.input_ids_list[0]]

    if len(inputs) != len(self.strategy.local_maps[self.rank]):
      raise ValueError(F"Expect {self.strategy.local_maps[self.rank]} inputs, got {len(inputs)}.")

    # offset inputs
    inputs = [
        inp if offset is None else tf.cast(inp, tf.int64) + offset
        for inp, offset in zip(inputs, self.col_inputs_offsets)
    ]

    # do embedding
    mp_outs = [
        self.local_embedding_layers[m](inp)
        for m, inp in zip(self.strategy.local_maps[self.rank], inputs)
    ]
    mp_outs = [tf.cast(output, self.compute_dtype) for output in mp_outs]

    if self.world_size > 1:
      # TODO(deyuf): currently assumes 2D output with the same batch size for all outputs,
      # ideally should support the general case
      mp_outs = [tf.reshape(mp_out, [self.world_size, -1]) for mp_out in mp_outs]
      mp_outs = tf.reshape(tf.concat(mp_outs, axis=1), [-1])
      dp_outs = hvd.alltoall(mp_outs, name='out_mp_to_dp')
      batch_size = tf.shape(
          inputs[0], out_type=tf.int32)[0] if inputs[0].shape[0] is None else inputs[0].shape[0]
      local_bs = batch_size // self.world_size
      num_elements = [local_bs * width for width in self.strategy.widths_list_flat]
      split_outs = tf.split(dp_outs, num_elements)
      mp_outs = [tf.reshape(split_out, [local_bs, -1]) for split_out in split_outs]

    # reorder outputs to be in the same order as inputs
    result = [mp_outs[index] for index in self.strategy.rev_tp_ids]
    # Concat sliced outputs from column slicing back together
    for start, end in self.strategy.sliced_out_ranges:
      result[start:end] = [tf.concat(result[start:end], axis=-1)]
    return result

  def _call_row_slice(self, inputs):
    # initial version: just allgather inputs, do lookups and reduce-scatter outputs
    # for lookups that do not exist on this worker (OOB), a zero vector is added in
    inputs = hvd.grouped_allgather(inputs)
    # offset inputs
    inputs = [
        inp if offset is None else tf.cast(inp, tf.int64) + offset
        for inp, offset in zip(inputs, self.row_inputs_offsets)
    ]
    # do embedding
    outputs = [self.row_layers[m](inp) for m, inp in zip(self.strategy.map_groups[2], inputs)]
    outputs = grouped_reducescatter_unscaled(outputs)
    return outputs

  def set_col_slice_weights(self, weights):
    if not weights:
      return []
    if self.world_size == 1:
      if isinstance(weights[0], str):
        weights = [np.load(file=path, mmap_mode='r') for path in weights]
    else:
      slice_info = [[rank_table_id.count(table_id) for rank_table_id in self.strategy.table_ids]
                    for table_id in range(len(weights))]
      local_info = [slice_info[index] for index in self.strategy.table_ids[self.rank]]
      weights = [weights[index] for index in self.strategy.table_ids[self.rank]]
      if isinstance(weights[0], str):
        weights = [np.load(file=path, mmap_mode='r') for path in weights]

      def _slice_weight_for_rank(weight, info, global_rank):
        num_columns = weight.shape[1]
        num_slices = sum(info)
        column_per_slice = num_columns // num_slices
        remainder = num_columns % num_slices
        rank = sum(info[:global_rank])
        start = column_per_slice * rank + min(rank, remainder)
        rank += 1
        end = column_per_slice * rank + min(rank, remainder)
        return weight[:, start:end]

      weights = [
          _slice_weight_for_rank(weight, info, self.rank)
          for weight, info in zip(weights, local_info)
      ]
    # now we have weights distributed, need to concat
    concat_weights = []
    for group in self.strategy.local_group_list[self.rank]:
      to_concat = [weights[idx] for idx in group]
      concat_weights.append(np.concatenate(to_concat))
    return concat_weights

  def set_row_slice_weights(self, weights):
    # we make sure no table is in this group in the single worker case
    if not weights:
      return []
    if isinstance(weights[0], str):
      weights = [np.load(file=path, mmap_mode='r') for path in weights]
    local_info = [[1 for _ in range(self.world_size)] for _ in weights]

    def _slice_weight_for_rank(weight, info, global_rank):
      num_columns = weight.shape[0]
      num_slices = sum(info)
      column_per_slice = num_columns // num_slices
      remainder = num_columns % num_slices
      rank = sum(info[:global_rank])
      start = column_per_slice * rank + min(rank, remainder)
      rank += 1
      end = column_per_slice * rank + min(rank, remainder)
      return weight[start:end, :]

    weights = [
        _slice_weight_for_rank(weight, info, self.rank)
        for weight, info in zip(weights, local_info)
    ]
    return weights

  def set_weights(self, weights, chunk=134217728, use_lock=False):
    """Sets the weights of the layer, from NumPy arrays.

    Args:
      weights (list): list containing global weights for all tables. Items in the list can be
          either numpy arrays or file paths to load from.
      chunk (int): max number of elements per chunk when setting weights on GPU by chunks.
          This will be rounded to a number of rows based on the weight shape.
      use_lock (bool): If true, set weights rank by rank in lock step to avoid OOM. Default False.
    Raises:
      ValueError: If length of weights does not match length of expected weights.
    """
    if use_lock:
      for _ in range(self.rank):
        hvd.broadcast_object(0)

    dp_weights = [weights[idx] for idx in self.strategy.table_groups[0]]
    col_weights = [weights[idx] for idx in self.strategy.table_groups[1]]
    row_weights = [weights[idx] for idx in self.strategy.table_groups[2]]
    col_weights = self.set_col_slice_weights(col_weights)
    row_weights = self.set_row_slice_weights(row_weights)
    weights = dp_weights + col_weights + row_weights

    # variable.assign and copy-on-write create an extra copy of the weight that causes OOM
    # so here we scatter update by ~128M element chunks instead of just doing
    # super().set_weights(weights)
    if len(self.weights) != len(weights):
      raise ValueError(
          F"You called `set_weights(weights)` on layer {self.name} with a weight list of "
          F"length {len(weights)}, but the layer was expecting {len(self.weights)} weights.")
    for weight, arr in zip(self.weights, weights):
      if arr.size <= chunk:
        weight.assign(arr)
      else:
        chunk_size_dim0 = chunk // weight.shape[1]
        num_chunks = math.ceil(weight.shape[0] / chunk_size_dim0)
        last_size = weight.shape[0] - chunk_size_dim0 * (num_chunks - 1)
        chunk_sizes = [chunk_size_dim0] * (num_chunks - 1) + [last_size]
        for i in range(num_chunks):
          start = i * chunk_size_dim0
          end = start + chunk_sizes[i]
          indices = tf.range(start=start, limit=end, dtype=tf.int64)
          update = tf.IndexedSlices(values=arr[start:end],
                                    indices=indices,
                                    dense_shape=weight.shape)
          weight.scatter_update(sparse_delta=update)
    del weights

    if use_lock:
      for _ in range(self.world_size - self.rank):
        hvd.broadcast_object(0)

  # 1d split that works beyond the 32-bit indexing limit TF supports
  def _split_1d(self, tensor, lengths):
    # choose a number close to the int32 limit as maximum chunk size
    # This will handle tensors with size up to the square of int32_max
    chunking_threshold = 2147483646
    if tensor.shape[0] <= chunking_threshold:
      return tf.split(tensor, lengths)
    num_chunks = math.ceil(tensor.shape[0] / chunking_threshold)
    padding_len = math.ceil(tensor.shape[0] / num_chunks) * num_chunks - tensor.shape[0]
    padded_tensor = tf.concat([tensor, tf.zeros(padding_len, tensor.dtype)], axis=0)
    tensor_list = tf.unstack(tf.reshape(padded_tensor, [num_chunks, -1]))
    result = []
    for length in lengths:
      this_slice = []
      while length > 0:
        if length > tensor_list[0].shape[0]:
          this_slice.append(tensor_list.pop(0))
        else:
          this_slice.append(tensor_list[0][:length])
          tensor_list[0] = tensor_list[0][length:]
        length -= this_slice[-1].shape[0]
      result.append(tf.concat(this_slice, axis=0))
    return result

  def get_row_sliced_weights(self, weights):
    if not weights:
      return []
    # weights are already selected with group info
    # assume row slice runs on all workers, so allgather conveniently stitches them together
    weights = hvd.grouped_allgather(weights)
    return [w.numpy() for w in weights]

  def get_col_sliced_weights(self, local_weights, all_ranks=False):
    if not local_weights:
      return []
    # TODO(deyuf): undo concat locally first. this requires we save the original local config
    if self.world_size == 1:
      concat_weights = [w.numpy() for w in local_weights]
      res = [item for sublist in self.strategy.local_group_list[0] for item in sublist]
      for offsets, f_w, group in zip(self.strategy.local_weight_offsets[0], concat_weights,
                                     self.strategy.local_group_list[0]):
        for i in range(len(offsets) - 1):
          res[group[i]] = f_w[offsets[i]:offsets[i + 1]]
      return res

    # mpi segfaults on indices over the 32-bit range, so we gather weights chunk by chunk here
    # choose a number not very close to the int32 limit as maximum chunk size just to be safe
    chunking_threshold = 2000000000
    num_chunks = 1
    for local_config in self.strategy.local_configs:
      total_elements = sum([c['input_dim'] * c['output_dim'] for c in local_config])
      num_chunks = max(num_chunks,
                       math.ceil(self.world_size * total_elements / chunking_threshold))

    with tf.device('CPU:0'):
      local_weights = tf.concat([tf.reshape(w, [-1]) for w in local_weights], axis=0)
      chunk_size = local_weights.shape[0] // num_chunks
      last_size = local_weights.shape[0] - chunk_size * (num_chunks - 1)
      chunk_sizes = [chunk_size] * (num_chunks - 1) + [last_size]
      local_weights = self._split_1d(local_weights, chunk_sizes)
      # communicate chunk sizes
      all_sizes = hvd.allgather(chunk_sizes)

      # collect all chunks and split to reverse allgather concat
      chunks = []
      for i, w in enumerate(local_weights):
        w = hvd.allgather(w)
        if all_ranks or self.rank == 0:
          chunks += self._split_1d(w, all_sizes[i::num_chunks])
      if not chunks:
        return []

      # re-construct all local weights from chunks
      local_weights = []
      for i in range(self.world_size):
        local_weights.append(tf.concat(chunks[i::self.world_size], axis=0))
      del chunks

      # split flat local weights into correct sizes
      weights = []
      for local_weight, local_config, weight_offsets, local_groups in zip(
          local_weights, self.strategy.local_configs, self.strategy.local_weight_offsets,
          self.strategy.local_group_list):
        local_shapes = [[c['input_dim'], c['output_dim']] for c in local_config]
        local_sizes = [shape[0] * shape[1] for shape in local_shapes]
        flat_weights = self._split_1d(local_weight, local_sizes)
        concat_weights = [
            tf.reshape(weight, shape) for weight, shape in zip(flat_weights, local_shapes)
        ]
        # split concat embedding weights
        res = [item for sublist in local_groups for item in sublist]
        for offsets, f_w, group in zip(weight_offsets, concat_weights, local_groups):
          for i in range(len(offsets) - 1):
            res[group[i]] = f_w[offsets[i]:offsets[i + 1]]
        weights += res

      # restore original table order
      # flatten self.strategy.table_ids
      worker_order_table_ids = [item for sublist in self.strategy.table_ids for item in sublist]
      # Shuffle worker ordered embedding weights (sliced) back to original order.
      ids_and_weights = sorted(zip(worker_order_table_ids, weights), key=lambda x: x[0])
      # concat sliced weights
      result = []
      cur_id = 0
      cur_list = []
      while ids_and_weights:
        cur = ids_and_weights.pop(0)
        if cur[0] == cur_id:
          cur_list.append(cur[1])
        else:
          result.append(tf.concat(cur_list, axis=1).numpy())
          cur_id = cur[0]
          cur_list = [cur[1]]
      result.append(tf.concat(cur_list, axis=1).numpy())
      return result

  def get_weights(self, all_ranks=False):
    """Returns the current weights of the layer, as NumPy arrays.

    This override outputs global weights for all tables.

    Args:
      all_ranks (bool): If true, return weights in all ranks, otherwise only in rank 0.
          Default False.

    Returns:
      result (list): List of weight tensors.
    """
    # avoid copy-on-read on dense access, assume order here for code simplicity
    weights = [read_var_no_copy(w) for w in self.weights]
    num_dp, num_col = len(self.dp_layers), len(self.local_embedding_layers)
    dp_weights = weights[:num_dp]
    col_weights = weights[num_dp:num_dp + num_col]
    row_weights = weights[num_dp + num_col:]
    col_weights = self.get_col_sliced_weights(col_weights, all_ranks)
    row_weights = self.get_row_sliced_weights(row_weights)
    weights = dp_weights + col_weights + row_weights
    group_order_table_ids = [idx for group in self.strategy.table_groups for idx in group]
    weights = [w for _, w in sorted(zip(group_order_table_ids, weights))]
    return weights

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    if input_shape is not None and None not in input_shape[0]:
      # Do some checks to detect cases that are not supported
      if not isinstance(input_shape, (list, tuple)):
        input_shape = [input_shape]
      batch_sizes = [shape[0] for shape in input_shape]
      batch_sizes = hvd.allgather(batch_sizes).numpy().tolist()
      if len(set(batch_sizes)) > 1:
        raise ValueError(F"All inputs need to have the same batch size. Got {set(batch_sizes)}.")
      if not self.dp_input:
        if batch_sizes[0] % self.world_size > 0:
          raise ValueError(
              F"Global batchsize {batch_sizes[0]} not divisible by worker count {self.world_size}.")

    # build data parallel tables
    for layer in self.dp_layers:
      layer.build(input_shape[0] if input_shape else None)
      # set built flag to prevent the above build from triggering again and the flag falling off
      layer.built = True

    # build both col and row slice tables
    for layer in self.local_embedding_layers + self.row_layers:
      layer.build(input_shape[0] if input_shape else None)
      for var in layer.trainable_weights:
        # Mark local (model parallel) variables. use prefix de (distributed embeddings) to avoid conflicts.
        var.de_local = True
      # set built flag to prevent the above build from triggering again and the flag falling off
      layer.built = True

    self.built = True

  def call(self, inputs):  # pylint: disable=missing-function-docstring
    # call data parallel tables
    dp_in = [inputs[idx] for idx in self.strategy.input_groups[0]]
    dp_out = self._call_data_parallel(dp_in) if dp_in else []
    # call col slice tables
    col_in = [inputs[idx] for idx in self.strategy.input_groups[1]] if self.dp_input else inputs
    col_out = self._call_table_parallel(col_in) if col_in else []
    # call row slice tables
    row_in = [inputs[idx] for idx in self.strategy.input_groups[2]]
    row_out = self._call_row_slice(row_in) if row_in else []
    # now we have outputs from all groups, reorder them into input order
    outputs = dp_out + col_out + row_out
    outputs = [outputs[idx] for idx in self.strategy.rev_group_ids]
    return outputs


# Monkey patch horovod bcast/tape so we can handle mp/dp vars differently in single backward
# pylint: disable=protected-access, missing-any-param-doc, invalid-name
def broadcast_variables(model_vars, root_rank=0):
  """Broadcasts variables from root rank to all other processes in a process set

  Replace horovod's broadcast_variables when running hybrid parallel
  See https://horovod.readthedocs.io/en/stable/api.html for more details
  """
  dp_vars = []
  mp_vars = []
  for var in model_vars:
    if hasattr(var, 'de_local'):
      mp_vars.append(var)
    else:
      dp_vars.append(var)
  # modify broadcast to ignore_name_scope by default
  # TODO(deyuf): make it not positional
  _broadcast_defaults = list(hvd.broadcast.__defaults__)
  _broadcast_defaults[1] = True
  hvd.broadcast.__defaults__ = tuple(_broadcast_defaults)
  hvd.broadcast_variables(dp_vars, root_rank=root_rank)


def DistributedGradientTape(*args, **kwargs):
  """Gradient tape that supports hybrid parallel

  Replace horovod's DistributedGradientTape when running hybrid parallel
  See https://horovod.readthedocs.io/en/stable/api.html for more details
  """

  def _gradient(self, target, sources, *args, **kwargs):
    # Overwrite use_generic_names to always be True
    kwargs["use_generic_names"] = True
    gradients = self.raw_gradient(target, sources, *args, **kwargs)
    return gradients

  if horovod.__version__ < '0.27.0':
    raise NotImplementedError(
        "DistributedGradientTape is only compatible with horovod 0.27 or newer.")
  tape = hvd.DistributedGradientTape(sparse_as_dense=True, *args, **kwargs)
  for var in tape.watched_variables():
    if hasattr(var, 'de_local'):
      tape.register_local_source(var)
  tape.raw_gradient = tape.gradient
  tape.gradient = types.MethodType(_gradient, tape)
  return tape


def DistributedOptimizer(*args, **kwargs):
  """Distributed optimizer that supports hybrid parallel

  Replace horovod's DistributedOptimizer when running hybrid parallel
  See https://horovod.readthedocs.io/en/stable/api.html for more details
  """

  # might be correct to patch get/aggregate gradient, but those seem already messy
  def _register_then_allreduce(self, grads, model_vars):
    if not self.local_var_registed:
      for var in model_vars:
        if hasattr(var, 'de_local'):
          self.register_local_var(var)
      self.local_var_registed = True
    return self.raw_allreduce(grads, model_vars)

  if horovod.__version__ < '0.27.0':
    raise NotImplementedError("Distributed Optimizer is only compatible with horovod 0.27 or newer")
  opt = hvd_keras.DistributedOptimizer(sparse_as_dense=True, *args, **kwargs)
  opt.local_var_registed = False
  opt.raw_allreduce = opt._allreduce
  opt._allreduce = types.MethodType(_register_then_allreduce, opt)

  # need to patch internal allreduce call with use_generic_names
  def _named_allreduce_grads(self, grads, variables):
    return self.raw_allreduce_grads(grads, variables, use_generic_names=True)

  opt.raw_allreduce_grads = opt._allreduce_grads
  opt._allreduce_grads = types.MethodType(_named_allreduce_grads, opt)
  return opt


def BroadcastGlobalVariablesCallback(*args, **kwargs):
  """Broadcast callback that supports hybrid parallel

  Replace horovod's BroadcastGlobalVariablesCallback when running hybrid parallel
  See https://horovod.readthedocs.io/en/stable/api.html for more details
  """

  def _on_batch_end(self, batch, logs=None):
    if not self.local_var_registed:
      for var in self.model.variables:
        if hasattr(var, 'de_local'):
          self.register_local_var(var)
      self.local_var_registed = True
    return self.raw_on_batch_end(batch, logs)

  if horovod.__version__ < '0.27.0':
    raise NotImplementedError(
        "BroadcastGlobalVariablesCallback is only compatible with horovod 0.27 or newer.")
  bcb = hvd_keras.callbacks.BroadcastGlobalVariablesCallback(*args, **kwargs)
  bcb.local_var_registed = False
  bcb.raw_on_batch_end = bcb.on_batch_end
  bcb.on_batch_end = types.MethodType(_on_batch_end, bcb)
  return bcb


# pylint: enable=protected-access, missing-any-param-doc, invalid-name
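
# Minimal end-to-end usage sketch (illustrative only, not part of this module). The inputs,
# loss and optimizer below are hypothetical placeholders; see the distributed-embeddings
# examples for authoritative usage.
#
#   import horovod.tensorflow as hvd
#   import tensorflow as tf
#   from distributed_embeddings.python.layers import dist_model_parallel as dmp
#
#   hvd.init()
#   tables = [tf.keras.layers.Embedding(1000, 16), tf.keras.layers.Embedding(100000, 64)]
#   embedding = dmp.DistributedEmbedding(tables, strategy="memory_balanced")
#
#   with tf.GradientTape() as tape:
#     embedding_outputs = embedding(inputs)   # list of per-feature lookup results
#     loss = compute_loss(embedding_outputs)  # user-defined
#   tape = dmp.DistributedGradientTape(tape)  # allreduce dp grads, keep mp grads local
#   grads = tape.gradient(loss, embedding.trainable_variables)
#   optimizer.apply_gradients(zip(grads, embedding.trainable_variables))
#   dmp.broadcast_variables(embedding.variables, root_rank=0)  # after the first step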