#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections.abc
from dask.core import flatten
from nvtabular.ops import LambdaOp, Operator
[docs]class ColumnGroup:
"""A ColumnGroup is a group of columns that you want to apply the same transformations to.
ColumnGroup's can be transformed by shifting operators on to them, which returns a new
ColumnGroup with the transformations applied. This lets you define a graph of operations
that makes up your workflow
Parameters
----------
columns: list of (str or tuple of str)
The columns to select from the input Dataset. The elements of this list are strings
indicating the column names in most cases, but can also be tuples of strings
for feature crosses.
"""
def __init__(self, columns):
self.parents = []
self.children = []
self.op = None
self.kind = None
self.dependencies = None
if isinstance(columns, str):
columns = [columns]
# if any of the values we're passed are a columngroup
# we have to ourselves as a childnode in the graph.
if any(isinstance(col, ColumnGroup) for col in columns):
self.columns = []
self.kind = "[...]"
for col in columns:
if not isinstance(col, ColumnGroup):
col = ColumnGroup(col)
else:
# we can't handle nesting arbitrarily deep here
# only accept non-nested (str) columns here
if any(not isinstance(c, str) for c in col.columns):
raise ValueError("Can't handle more than 1 level of nested columns")
col.children.append(self)
self.parents.append(col)
self.columns.append(tuple(col.columns))
else:
self.columns = [_convert_col(col) for col in columns]
[docs] def __rshift__(self, operator):
"""Transforms this ColumnGroup by applying an Operator
Parameters
-----------
operators: Operator or callable
Returns
-------
ColumnGroup
"""
if isinstance(operator, type) and issubclass(operator, Operator):
# handle case where an operator class is passed
operator = operator()
elif callable(operator):
# implicit lambdaop conversion.
operator = LambdaOp(operator)
if not isinstance(operator, Operator):
raise ValueError(f"Expected operator or callable, got {operator.__class__}")
child = ColumnGroup(operator.output_column_names(self.columns))
child.parents = [self]
self.children.append(child)
child.op = operator
dependencies = operator.dependencies()
if dependencies:
child.dependencies = set()
if not isinstance(dependencies, collections.abc.Sequence):
dependencies = [dependencies]
for dependency in dependencies:
if not isinstance(dependency, ColumnGroup):
dependency = ColumnGroup(dependency)
dependency.children.append(child)
child.parents.append(dependency)
child.dependencies.add(dependency)
return child
[docs] def __add__(self, other):
"""Adds columns from this ColumnGroup with another to return a new ColumnGroup
Parameters
-----------
other: ColumnGroup or str or list of str
Returns
-------
ColumnGroup
"""
if isinstance(other, str):
other = ColumnGroup([other])
elif isinstance(other, collections.abc.Sequence):
other = ColumnGroup(other)
# check if there are any columns with the same name in both column groups
overlap = set(self.columns).intersection(other.columns)
if overlap:
raise ValueError(f"duplicate column names found: {overlap}")
child = ColumnGroup(self.columns + other.columns)
child.parents = [self, other]
child.kind = "+"
self.children.append(child)
other.children.append(child)
return child
# handle the "column_name" + ColumnGroup case
__radd__ = __add__
[docs] def __sub__(self, other):
"""Removes columns from this ColumnGroup with another to return a new ColumnGroup
Parameters
-----------
other: ColumnGroup or str or list of str
Columns to remove
Returns
-------
ColumnGroup
"""
if isinstance(other, ColumnGroup):
to_remove = set(other.columns)
elif isinstance(other, str):
to_remove = {other}
elif isinstance(other, collections.abc.Sequence):
to_remove = set(other)
else:
raise ValueError(f"Expected ColumnGroup, str, or list of str. Got {other.__class__}")
new_columns = [c for c in self.columns if c not in to_remove]
child = ColumnGroup(new_columns)
child.parents = [self]
self.children.append(child)
child.kind = f"- {list(to_remove)}"
return child
[docs] def __getitem__(self, columns):
"""Selects certain columns from this ColumnGroup, and returns a new Columngroup with only
those columns
Parameters
-----------
columns: str or list of str
Columns to select
Returns
-------
ColumnGroup
"""
if isinstance(columns, str):
columns = [columns]
child = ColumnGroup(columns)
child.parents = [self]
self.children.append(child)
child.kind = str(columns)
return child
def __repr__(self):
output = " output" if not self.children else ""
return f"<ColumnGroup {self.label}{output}>"
@property
def flattened_columns(self):
return list(flatten(self.columns, container=tuple))
@property
def input_column_names(self):
"""Returns the names of columns in the main chain"""
dependencies = self.dependencies or set()
return [
col for parent in self.parents for col in parent.columns if parent not in dependencies
]
@property
def label(self):
if self.op:
return self.op.label
elif self.kind:
return self.kind
elif not self.parents:
return f"input cols=[{self._cols_repr}]"
else:
return "??"
@property
def _cols_repr(self):
cols = ", ".join(map(str, self.columns[:3]))
if len(self.columns) > 3:
cols += "..."
return cols
@property
def graph(self):
return _to_graphviz(self)
def iter_nodes(nodes):
queue = nodes[:]
while queue:
current = queue.pop()
yield current
# TODO: deduplicate nodes?
for parent in current.parents:
queue.append(parent)
[docs]def _to_graphviz(column_group):
"""Converts a ColumnGroup to a GraphViz DiGraph object useful for display in notebooks"""
from graphviz import Digraph
column_group = _merge_add_nodes(column_group)
graph = Digraph()
# get all the nodes from parents of this columngroup
# and add edges between each of them
allnodes = list(set(iter_nodes([column_group])))
node_ids = {v: str(k) for k, v in enumerate(allnodes)}
for node, nodeid in node_ids.items():
graph.node(nodeid, node.label)
for parent in node.parents:
graph.edge(node_ids[parent], nodeid)
# add a single 'output' node representing the final state
output_node_id = str(len(allnodes))
graph.node(output_node_id, f"output cols=[{column_group._cols_repr}]")
graph.edge(node_ids[column_group], output_node_id)
return graph
[docs]def _merge_add_nodes(graph):
"""Merges repeat '+' nodes, leading to nicer looking outputs"""
# lets take a copy to avoid mutating the input
import copy
graph = copy.copy(graph)
queue = [graph]
while queue:
current = queue.pop()
if current.kind == "+":
changed = True
while changed:
changed = False
parents = []
for i, parent in enumerate(current.parents):
if parent.kind == "+" and len(parent.children) == 1:
changed = True
# disconnect parent, point all the grandparents at current instead
parents.extend(parent.parents)
for grandparent in parent.parents:
grandparent.children = [
current if child == parent else child
for child in grandparent.children
]
else:
parents.append(parent)
current.parents = parents
queue.extend(current.parents)
return graph
def _convert_col(col):
if isinstance(col, (str, tuple)):
return col
elif isinstance(col, list):
return tuple(col)
else:
raise ValueError(f"Invalid column value for ColumnGroup: {col}")