#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import warnings
try:
    import cudf
except ImportError:
    cudf = None
import dask.dataframe as dd
import pandas as pd
from merlin.core.dispatch import (
    DataFrameType,
    ExtData,
    arange,
    convert_data,
    create_merlin_dataset,
    detect_format,
    to_host,
)
from merlin.schema import Schema
from .operator import ColumnSelector, Operator
class JoinExternal(Operator):
    """
    Join each dataset partition to an external table. For performance
    reasons, only "left" and "inner" join transformations are supported.

    Example usage::

        # Load dataset which should be joined to the main dataset
        df_ext = cudf.read_parquet('external.parquet')

        # Use JoinExternal to define a NVTabular workflow
        joined = ColumnSelector(columns_left) >> nvt.ops.JoinExternal(
            df_ext,
            on=['key1', 'key2'],
            on_ext=['key1_ext', 'key2_ext'],
            how='left',
            columns_ext=['key1_ext', 'key2_ext', 'cat1', 'cat2', 'num1'],
            cache='device'
        ) >> ...
        processor = nvt.Workflow(joined)

    Parameters
    ----------
    df_ext : DataFrame, pyarrow.Table, Dataset, dd.DataFrame, or file path(s)
        The external table to join to each partition of the dataset. Note
        that the join must be a partition-wise transformation. Therefore,
        if ``df_ext`` is a multi-partition Dask collection, it will need to
        be broadcasted to every partition.
    on : str or list(str)
        Column name(s) to merge on
    how : {"left", "inner"}; default "left"
        Type of join operation to perform.
    on_ext : str or list(str); Optional
        Column name(s) on external table to join on. By default,
        we assume ``on_ext`` is the same as ``on``.
    columns_ext : list(str); Optional
        Subset of columns to select from external table before join.
    drop_duplicates_ext : bool; Optional
        Drop duplicates from external table before join. Disabled
        unless a truthy value is passed.
    kind_ext : ExtData; Optional
        Format of ``df_ext``.  If nothing is specified, the format
        will be inferred. Must be an ``ExtData`` instance when given.
    cache : {"device", "host", "disk"}; default "host"
        Where to cache ``df_ext`` between transformations. Only used
        if the data is originally stored on disk. The "host" option
        is also supported when ``df_ext`` is a ``cudf.DataFrame``.
    **kwargs
        Extra keyword arguments; stored on the instance as ``self.kwargs``.

    Raises
    ------
    ValueError
        If ``how`` is neither "left" nor "inner", or if the (given or
        inferred) ``kind_ext`` is not an ``ExtData`` instance.
    """

    def __init__(
        self,
        df_ext,
        on,
        how="left",
        on_ext=None,
        columns_ext=None,
        drop_duplicates_ext=None,
        kind_ext=None,
        cache="host",
        **kwargs,
    ):
        self.on = on
        self.df_ext = create_merlin_dataset(df_ext)
        self.on_ext = on_ext or self.on
        self.how = how
        self.kind_ext = kind_ext or detect_format(self.df_ext)
        self.columns_ext = columns_ext
        self.drop_duplicates_ext = drop_duplicates_ext
        self.cache = cache
        self.kwargs = kwargs
        # Without cudf only the CPU path is possible; otherwise leave the
        # choice undecided (None) at construction time.
        self.cpu = True if cudf is None else None
        self._ext_cache = None

        if self.how not in ("left", "inner"):
            raise ValueError("Only left and inner joins are currently supported.")
        if not isinstance(self.kind_ext, ExtData):
            raise ValueError("kind_ext option not recognized.")

        # Single base-class initialisation. An additional one-argument
        # ``super(JoinExternal).__init__()`` call used to sit at the top of
        # this method; that form creates an unbound super object and never
        # reaches ``Operator.__init__``, so it was dropped.
        super().__init__()

    @property
    def _ext(self):
        """External table prepared for the join.

        On first access the table is read from ``self.df_ext``, restricted
        to ``columns_ext`` (when given), optionally de-duplicated, and
        cached according to ``self.cache``. Later accesses return the
        cached object passed through ``convert_data`` with the current
        ``self.cpu`` setting.
        """
        if self._ext_cache is not None:
            # Return cached result if present
            return convert_data(self._ext_cache, cpu=self.cpu)

        # Use Dataset.to_ddf, after switching the dataset to the CPU or
        # GPU side according to ``self.cpu``.
        _dataset = self.df_ext
        if self.cpu:
            _dataset.to_cpu()
        else:
            _dataset.to_gpu()
        _ext = _check_partition_count(_dataset.to_ddf(columns=self.columns_ext))

        # Take subset of columns if a list is specified
        if self.columns_ext:
            _ext = _ext[self.columns_ext]

        # Drop duplicates if requested
        if self.drop_duplicates_ext:
            if isinstance(_ext, dd.DataFrame):
                # Dask collections are handled by reassignment
                _ext = _ext.drop_duplicates(ignore_index=True)
            else:
                _ext.drop_duplicates(ignore_index=True, inplace=True)

        # Cache and return. With any other ``cache`` value (e.g. "disk")
        # and a parquet/CSV source, nothing is cached and the table is
        # rebuilt on every access.
        if self.cache == "host":
            self._ext_cache = to_host(_ext)
        elif self.cache == "device" or self.kind_ext not in (ExtData.PARQUET, ExtData.CSV):
            self._ext_cache = _ext
        return _ext

    def _merge(self, df, _ext):
        """Merge ``df`` with the external table ``_ext``.

        Uses ``self.on`` / ``self.on_ext`` as the left/right keys and
        ``self.how`` as the join type. When ``_ext`` is a Dask DataFrame,
        ``df`` is wrapped in a single-partition Dask collection and the
        merged result is computed before being returned.
        """
        if isinstance(_ext, dd.DataFrame):
            _ddf = dd.from_pandas(df, npartitions=1)
            return _ddf.merge(_ext, left_on=self.on, right_on=self.on_ext, how=self.how).compute()
        return df.merge(_ext, left_on=self.on, right_on=self.on_ext, how=self.how)

    # NOTE(review): this spot previously held
    # ``transform.__doc__ = Operator.transform.__doc__`` although no
    # ``transform`` method is defined anywhere in this class body, so the
    # statement raised NameError while the class was being created. The
    # statement is dropped here. ``_merge`` has no caller inside this class,
    # so a ``transform`` override presumably belongs at this position --
    # TODO confirm, and restore the method together with the docstring
    # assignment.

    def compute_selector(
        self,
        input_schema: Schema,
        selector: ColumnSelector,
        parents_selector: ColumnSelector,
        dependencies_selector: ColumnSelector,
    ) -> ColumnSelector:
        """Check ``parents_selector`` against ``input_schema`` and return it.

        Validation is delegated to ``self._validate_matching_cols``; the
        ``selector`` and ``dependencies_selector`` arguments are not used.
        """
        self._validate_matching_cols(input_schema, parents_selector, "computing input selector")
        return parents_selector

    def compute_output_schema(self, input_schema, col_selector, prev_output_schema=None):
        """Compute the output schema from the input and external schemas.

        The schema of the external dataset is added to ``input_schema``
        before delegating to the base-class implementation.
        """
        # must load in the schema from the external dataset
        input_schema = input_schema + self.df_ext.schema
        return super().compute_output_schema(input_schema, col_selector, prev_output_schema)

    def column_mapping(self, col_selector):
        """Map every output column name to a one-element list of itself.

        Output columns are the selected column names followed by the
        external columns, with duplicates removed and first-seen order
        kept. When ``columns_ext`` is not set, the external column names
        are read from ``self._ext``, which may load the external table.
        """
        ext_columns = self.columns_ext if self.columns_ext else self._ext.columns
        # dict.fromkeys maintains the order, which set() does not
        ordered_names = dict.fromkeys(col_selector.names + list(ext_columns))
        return {col_name: [col_name] for col_name in ordered_names}

    def _compute_dtype(self, col_schema, input_schema):
        """Resolve the dtype of one output column.

        Columns present in ``input_schema`` use the base-class logic; all
        other columns take their dtype from the external dataset's schema.
        """
        if col_schema.name in input_schema.column_names:
            return super()._compute_dtype(col_schema, input_schema)
        col_dtype = self.df_ext.schema.column_schemas[col_schema.name].dtype
        return col_schema.with_dtype(col_dtype)

    def _compute_tags(self, col_schema, input_schema):
        """Return ``col_schema`` unchanged."""
        return col_schema

    def _compute_properties(self, col_schema, input_schema):
        """Return ``col_schema`` unchanged."""
        return col_schema
def _check_partition_count(df):
    """Inspect the partition count of ``df`` and materialize it if trivial.

    Objects without an ``npartitions`` attribute are returned untouched.
    A collection with exactly one partition is replaced by the result of
    its ``compute()`` call. A collection with more than three partitions
    is returned as-is after a warning about the broadcast merge.
    """
    if not hasattr(df, "npartitions"):
        return df

    partition_total = df.npartitions
    if partition_total == 1:
        # Materialize single-partition collections
        return df.compute()

    if partition_total > 3:
        message = (
            "Joining an external Dask collection with "
            f"{partition_total} partitions. This transformation "
            "requires a broadcast merge, which can be problematic "
            "when the external collection is too large."
        )
        warnings.warn(message)
    return df