#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections.abc
from typing import List, Union
from merlin.dag.base_operator import BaseOperator
from merlin.dag.ops import ConcatColumns, SelectionOp, SubsetColumns, SubtractionOp
from merlin.dag.selector import ColumnSelector
from merlin.schema import Schema
class Node:
    """A group of columns that you want to apply the same transformations to.

    Nodes are transformed by shifting operators onto them (``node >> operator``),
    which returns a new Node with the transformation applied. Chaining these
    operations defines the graph of operations that makes up a workflow.

    Parameters
    ----------
    selector : ColumnSelector or list of str, optional
        Defines which columns to select from the input Dataset using column names
        and tags. When provided, this Node's operator is a ``SelectionOp`` over it.

    Raises
    ------
    TypeError
        If ``selector`` is truthy but is neither a list nor a ColumnSelector.
    """

    def __init__(self, selector=None):
        # Graph links. Entries of ``dependencies`` may themselves be lists of nodes
        # (see ``add_dependency``), which preserves how dependencies were grouped.
        self.parents = []
        self.children = []
        self.dependencies = []
        self.op = None
        # Populated by ``compute_schemas``.
        self.input_schema = None
        self.output_schema = None

        if isinstance(selector, list):
            selector = ColumnSelector(selector)

        if selector and not isinstance(selector, ColumnSelector):
            raise TypeError("The selector argument must be a list or a ColumnSelector")

        if selector is not None:
            self.op = SelectionOp(selector)

        self.selector = selector

    @property
    def selector(self):
        """The ColumnSelector for this Node, or None when it has none."""
        return self._selector

    @selector.setter
    def selector(self, sel):
        # Accept a plain list of column names and normalize it to a ColumnSelector.
        if isinstance(sel, list):
            sel = ColumnSelector(sel)
        self._selector = sel

    # These methods must maintain grouping

    def add_dependency(
        self,
        dep: Union[
            str,
            List[str],
            ColumnSelector,
            "Node",
            List[Union[str, List[str], "Node", ColumnSelector]],
        ],
    ):
        """Add a dependency to this node.

        ``dep`` is converted with ``Node.construct_from``. When that yields a list,
        the list is stored as ONE entry of ``self.dependencies`` so the grouping is
        preserved, while every node in it records this node as a child.

        Parameters
        ----------
        dep : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Dependency to be added
        """
        dep_node = Node.construct_from(dep)
        dep_nodes = dep_node if isinstance(dep_node, list) else [dep_node]

        for node in dep_nodes:
            node.children.append(self)

        self.dependencies.append(dep_node)

    def add_parent(
        self,
        parent: Union[
            str,
            List[str],
            ColumnSelector,
            "Node",
            List[Union[str, List[str], "Node", ColumnSelector]],
        ],
    ):
        """Add one or more parents to this node.

        Unlike dependencies, parents are stored flat: every node produced by
        ``Node.construct_from`` is appended to ``self.parents`` individually.

        Parameters
        ----------
        parent : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Parent to be added
        """
        parent_nodes = Node.construct_from(parent)
        if not isinstance(parent_nodes, list):
            parent_nodes = [parent_nodes]

        for parent_node in parent_nodes:
            parent_node.children.append(self)

        self.parents.extend(parent_nodes)

    def add_child(
        self,
        child: Union[
            str,
            List[str],
            ColumnSelector,
            "Node",
            List[Union[str, List[str], "Node", ColumnSelector]],
        ],
    ):
        """Add one or more children to this node, linking them back to it as parent.

        Parameters
        ----------
        child : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Child to be added
        """
        child_nodes = Node.construct_from(child)
        if not isinstance(child_nodes, list):
            child_nodes = [child_nodes]

        for child_node in child_nodes:
            child_node.parents.append(self)

        self.children.extend(child_nodes)

    def remove_child(
        self,
        child: Union[
            str,
            List[str],
            ColumnSelector,
            "Node",
            List[Union[str, List[str], "Node", ColumnSelector]],
        ],
    ):
        """Remove one or more children from this node, unlinking both directions.

        Links that do not exist are silently skipped.

        Parameters
        ----------
        child : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Child to be removed
        """
        child_nodes = Node.construct_from(child)
        if not isinstance(child_nodes, list):
            child_nodes = [child_nodes]

        for child_node in child_nodes:
            if self in child_node.parents:
                child_node.parents.remove(self)
            if child_node in self.children:
                self.children.remove(child_node)

    def compute_schemas(self, root_schema: Schema, preserve_dtypes: bool = False):
        """Compute and store this Node's input schema, selector and output schema.

        Parameters
        ----------
        root_schema : Schema
            Schema of the input dataset
        preserve_dtypes : bool, optional
            `True` if we don't want to override dtypes in the current schema, by
            default False. When True, the previously computed ``output_schema`` is
            handed to the operator when computing the new one.

        Notes
        -----
        Reads the ``output_schema`` of parents and dependencies, so those are
        presumably computed first (TODO confirm ordering at the call site).
        """
        parents_schema = _combine_schemas(self.parents)
        deps_schema = _combine_schemas(self.dependencies)
        parents_selector = _combine_selectors(self.parents)
        dependencies_selector = _combine_selectors(self.dependencies)

        # If parent is an addition or selection node, we may need to
        # propagate grouping unless this node already has a selector
        if len(self.parents) == 1 and isinstance(self.parents[0].op, (ConcatColumns, SelectionOp)):
            parents_selector = self.parents[0].selector
            if not self.selector and self.parents[0].selector and (self.parents[0].selector.names):
                self.selector = parents_selector

        self.input_schema = self.op.compute_input_schema(
            root_schema, parents_schema, deps_schema, self.selector
        )

        self.selector = self.op.compute_selector(
            self.input_schema, self.selector, parents_selector, dependencies_selector
        )

        prev_output_schema = self.output_schema if preserve_dtypes else None
        self.output_schema = self.op.compute_output_schema(
            self.input_schema, self.selector, prev_output_schema
        )

    def validate_schemas(self, root_schema: Schema, strict_dtypes: bool = False):
        """Check that this Node's input schema is satisfied by its ancestors.

        Columns are looked up in ``root_schema`` plus the combined output schemas of
        parents and dependencies. The operator's own ``validate_schemas`` runs last.

        Parameters
        ----------
        root_schema : Schema
            Schema of the input dataset
        strict_dtypes : bool, optional
            If an error should be raised when column dtypes don't match, by default
            False. Dtypes are also checked when the operator does not declare
            ``dynamic_dtypes``.

        Raises
        ------
        ValueError
            If parents and dependencies don't provide an expected column based on
            the input schema
        ValueError
            If the dtype of a column from parents and dependencies doesn't match
            the expected dtype based on the input schema
        """
        parents_schema = _combine_schemas(self.parents)
        deps_schema = _combine_schemas(self.dependencies)
        ancestors_schema = root_schema + parents_schema + deps_schema

        for col_name, col_schema in self.input_schema.column_schemas.items():
            source_col_schema = ancestors_schema.get(col_name)

            if not source_col_schema:
                raise ValueError(
                    f"Missing column '{col_name}' at the input to '{self.op.__class__.__name__}'."
                )

            if strict_dtypes or not self.op.dynamic_dtypes:
                # Shapes are ignored here; only the base dtype must agree.
                if source_col_schema.dtype.without_shape != col_schema.dtype.without_shape:
                    raise ValueError(
                        f"Mismatched dtypes for column '{col_name}' provided to "
                        f"'{self.op.__class__.__name__}': "
                        f"ancestor nodes provided dtype '{source_col_schema.dtype}', "
                        f"expected dtype '{col_schema.dtype}'."
                    )

        self.op.validate_schemas(
            parents_schema, deps_schema, self.input_schema, self.output_schema, strict_dtypes
        )

    def __rshift__(self, operator):
        """Transform this Node by applying an operator.

        Parameters
        -----------
        operator : BaseOperator or type
            An operator instance, or a ``BaseOperator`` subclass, which is then
            instantiated with no arguments.

        Returns
        -------
        Node
            A new child Node whose op is ``operator``, with this Node as its parent
            and the operator's ``dependencies`` attached as dependencies.

        Raises
        ------
        ValueError
            If ``operator`` is neither a BaseOperator instance nor a subclass.
        """
        if isinstance(operator, type) and issubclass(operator, BaseOperator):
            # handle case where an operator class is passed
            operator = operator()

        if not isinstance(operator, BaseOperator):
            raise ValueError(f"Expected operator or callable, got {operator.__class__}")

        child = type(self)()
        child.op = operator
        child.add_parent(self)

        dependencies = operator.dependencies

        if dependencies:
            # A str is itself a Sequence, so it must be wrapped explicitly; otherwise a
            # single column name would be split into one dependency per character.
            if isinstance(dependencies, str) or not isinstance(
                dependencies, collections.abc.Sequence
            ):
                dependencies = [dependencies]

            for dependency in dependencies:
                child.add_dependency(dependency)

        return child

    def __add__(self, other):
        """Add columns from another Node (or column names) to this one.

        Parameters
        -----------
        other: Node or str or list of str

        Returns
        -------
        Node
            A ``ConcatColumns`` node. NOTE: when this Node already is a
            ``ConcatColumns`` node it is extended in place and returned, so
            ``a + b`` then mutates ``a``.
        """
        # Falsy operands (e.g. an empty list, or the 0 that ``sum()`` starts from when
        # reaching here through ``__radd__``) leave this node unchanged.
        if not other:
            return self

        if isinstance(self.op, ConcatColumns):
            child = self
        else:
            # Create a child node
            child = type(self)()
            child.op = ConcatColumns(label="+")
            child.add_parent(self)

        # The right operand becomes a dependency. A list result is kept as one
        # grouped dependency rather than being split up.
        other_node = Node.construct_from(other)

        # If the other node is a `+` node, we want to collapse it into this `+` node to
        # avoid creating a cascade of repeated `+`s that we'd need to optimize out by
        # re-combining them later in order to clean up the graph
        if not isinstance(other_node, list) and isinstance(other_node.op, ConcatColumns):
            child.dependencies += other_node.grouped_parents_with_dependencies
        else:
            child.add_dependency(other_node)

        return child

    # handle the "column_name" + Node case
    __radd__ = __add__

    def __sub__(self, other):
        """Remove columns from this Node, returning a new Node.

        Bare column selections (a ``SelectionOp`` node with no parents or
        dependencies) are subtracted through the child's selector; anything else
        becomes a dependency of the ``SubtractionOp`` child.

        Parameters
        -----------
        other: Node or str or list of str
            Columns to remove

        Returns
        -------
        Node
        """
        other_nodes = Node.construct_from(other)
        if not isinstance(other_nodes, list):
            other_nodes = [other_nodes]

        child = type(self)()
        child.add_parent(self)
        child.op = SubtractionOp()

        for other_node in other_nodes:
            if isinstance(other_node.op, SelectionOp) and not other_node.parents_with_dependencies:
                child.selector += other_node.selector
                # Add only this node's selector: adding the accumulated child.selector
                # would re-add columns contributed by earlier iterations.
                child.op.selector += other_node.selector
            else:
                child.add_dependency(other_node)

        return child

    def __rsub__(self, other):
        """Handle ``other - node``: remove this Node's columns from ``other``."""
        left_operand = Node.construct_from(other)
        right_operand = self

        if not isinstance(left_operand, list):
            left_operand = [left_operand]

        child = type(self)()
        child.add_parent(left_operand)
        child.op = SubtractionOp()

        if (
            isinstance(right_operand.op, SelectionOp)
            and not right_operand.parents_with_dependencies
        ):
            child.selector += right_operand.selector
            child.op.selector += child.selector
        else:
            child.add_dependency(right_operand)

        return child

    def __getitem__(self, columns):
        """Select certain columns from this Node.

        Parameters
        -----------
        columns: str or list of str
            Columns to select

        Returns
        -------
        Node
            A new ``SubsetColumns`` child Node containing only those columns.
        """
        col_selector = ColumnSelector(columns)
        child = type(self)(col_selector)
        columns = [columns] if not isinstance(columns, list) else columns
        child.op = SubsetColumns(label=str(list(columns)))
        child.add_parent(self)
        return child

    def __repr__(self):
        # Nodes without children are marked as outputs.
        output = " output" if not self.children else ""
        return f"<Node {self.label}{output}>"

    def exportable(self, backend: str = None):
        """Return True if the op has an ``export`` method and lists ``backend`` in
        its ``exportable_backends``."""
        backends = getattr(self.op, "exportable_backends", [])

        return hasattr(self.op, "export") and backend in backends

    @property
    def parents_with_dependencies(self):
        """Parents followed by dependencies, with grouped dependency lists flattened
        one level."""
        nodes = []
        for node in self.parents + self.dependencies:
            if isinstance(node, list):
                nodes.extend(node)
            else:
                nodes.append(node)

        return nodes

    @property
    def grouped_parents_with_dependencies(self):
        """Parents followed by dependencies, keeping grouped dependency lists intact."""
        return self.parents + self.dependencies

    @property
    def input_columns(self):
        """Columns this Node reads.

        Returns the selector itself (to keep column groupings) when neither it nor
        its subgroups use tags; otherwise the input schema's column names.

        Raises
        ------
        RuntimeError
            If ``input_schema`` has not been computed yet.
        """
        if self.input_schema is None:
            raise RuntimeError(
                "The input columns aren't computed until the workflow "
                "is fit to a dataset or input schema."
            )

        if (
            self.selector
            and not self.selector.tags
            and all(not selector.tags for selector in self.selector.subgroups)
        ):
            # To maintain column groupings
            return self.selector
        else:
            return ColumnSelector(self.input_schema.column_names)

    @property
    def output_columns(self):
        """Column names of ``output_schema`` as a ColumnSelector.

        Raises
        ------
        RuntimeError
            If ``output_schema`` has not been computed yet.
        """
        if self.output_schema is None:
            raise RuntimeError(
                "The output columns aren't computed until the workflow "
                "is fit to a dataset or input schema."
            )

        return ColumnSelector(self.output_schema.column_names)

    @property
    def column_mapping(self):
        """The op's column mapping for this Node's selector (or its input columns)."""
        selector = self.selector or ColumnSelector(self.input_schema.column_names)
        return self.op.column_mapping(selector)

    @property
    def dependency_columns(self):
        """Column names from the combined output schemas of the dependencies."""
        return ColumnSelector(_combine_schemas(self.dependencies).column_names)

    @property
    def label(self):
        """Human-readable label: the op's label, its type, or an input/unknown marker."""
        if self.op and hasattr(self.op, "label"):
            return self.op.label
        elif self.op:
            return str(type(self.op))
        elif not self.parents:
            return f"input cols=[{self._cols_repr}]"
        else:
            return "??"

    @property
    def _cols_repr(self):
        # Up to three column names, with "..." appended when there are more.
        if self.input_schema:
            columns = self.input_schema.column_names
        elif self.selector:
            columns = self.selector.names
        else:
            columns = []

        cols_repr = ", ".join(map(str, columns[:3]))
        if len(columns) > 3:
            cols_repr += "..."

        return cols_repr

    @property
    def graph(self):
        """A graphviz Digraph of this Node and everything it depends on."""
        return _to_graphviz(self)

    Nodable = Union[
        "Node", str, List[str], ColumnSelector, List[Union["Node", str, List[str], ColumnSelector]]
    ]

    @classmethod
    def construct_from(
        cls,
        nodable: Nodable,
    ):
        """
        Convert Node-like objects to a Node or list of Nodes.

        For a mixed list, every element that converts to a Node with a truthy
        selector is merged into a single new selection Node, which is placed after
        the remaining (selector-less) nodes.

        Parameters
        ----------
        nodable : Nodable
            Node-like objects to convert to a Node or list of Nodes.

        Returns
        -------
        Union["Node", List["Node"]]
            New Node(s) corresponding to the Node-like input objects

        Raises
        ------
        TypeError
            If supplied input cannot be converted to a Node or list of Nodes
        """
        if isinstance(nodable, str):
            return Node(ColumnSelector([nodable]))

        if isinstance(nodable, ColumnSelector):
            return Node(nodable)
        elif isinstance(nodable, Node):
            return nodable
        elif isinstance(nodable, list):
            if all(isinstance(elem, str) for elem in nodable):
                return Node(nodable)
            else:
                nodes = [Node.construct_from(node) for node in nodable]
                non_selection_nodes = [
                    node for node in nodes if not (hasattr(node, "selector") and node.selector)
                ]
                selection_nodes = [
                    node.selector for node in nodes if (hasattr(node, "selector") and node.selector)
                ]
                selection_nodes = (
                    [Node(_combine_selectors(selection_nodes))] if selection_nodes else []
                )
                return non_selection_nodes + selection_nodes
        else:
            raise TypeError(
                "Unsupported type: Cannot convert object " f"of type {type(nodable)} to Node."
            )
def iter_nodes(nodes):
    """Yield every node reachable from ``nodes`` through parents/dependencies, once each.

    The walk uses a LIFO stack of pending items. Lists (grouped dependencies) found
    on the stack are flattened into it.

    Parameters
    ----------
    nodes : Node or list
        A starting node, or a (possibly nested) list/tuple of starting nodes.

    Yields
    ------
    Node
        Each reachable node exactly once, including the starting nodes. Shared
        ancestors (diamonds) are not repeated, so the work stays linear in the
        graph size.
    """
    if isinstance(nodes, (list, tuple)):
        pending = list(nodes)
    else:
        # A single node; slicing it would invoke Node.__getitem__ instead of copying.
        pending = [nodes]

    # Track identities, which works even for items that are unhashable or define __eq__.
    seen = set()

    while pending:
        current = pending.pop()
        if isinstance(current, list):
            pending.extend(current)
            continue

        if id(current) in seen:
            continue
        seen.add(id(current))
        yield current

        for parent in current.parents_with_dependencies:
            if id(parent) not in seen and parent not in pending:
                pending.append(parent)
# output node (bottom) -> selection leaf nodes (top)
def preorder_iter_nodes(nodes):
    """Yield nodes ordered from the output node(s) up toward the leaf selection nodes.

    Each level of ``parents_with_dependencies`` is appended after the level that
    reached it. A node that is reached again is moved to the end of the order, so in
    an acyclic graph every node comes after all nodes that depend on it.

    NOTE(review): shared ancestors are re-walked for every path that reaches them,
    so the cost can grow quickly on graphs with many stacked diamonds.

    Parameters
    ----------
    nodes : Node or list of Node
        Starting (output) node(s).
    """
    ordered = []
    roots = nodes if isinstance(nodes, list) else [nodes]

    def _visit(level):
        # Place the whole level first, moving any repeat to the end...
        for item in level:
            if item in ordered:
                ordered.remove(item)
            ordered.append(item)
        # ...then descend into each member's parents and dependencies.
        for item in level:
            _visit(item.parents_with_dependencies)

    _visit(roots)
    yield from ordered
# selection leaf nodes (top) -> output node (bottom)
def postorder_iter_nodes(nodes):
    """Yield nodes ordered from the leaf selection nodes down to the output node(s).

    Depth-first post-order over ``parents_with_dependencies``. Every node is
    emitted once, after all of its parents and dependencies.

    Parameters
    ----------
    nodes : Node or list of Node
        Starting (output) node(s).
    """
    queue = []
    # Identities of nodes already in ``queue``, for O(1) membership tests.
    queued = set()

    if not isinstance(nodes, list):
        nodes = [nodes]

    def traverse(current_nodes):
        for node in current_nodes:
            if id(node) in queued:
                # A queued node's ancestors were all queued before it, so walking
                # them again adds nothing. Skipping them avoids exponential
                # re-traversal on graphs with shared ancestors.
                continue
            traverse(node.parents_with_dependencies)
            queued.add(id(node))
            queue.append(node)

    traverse(nodes)

    for node in queue:
        yield node
def _filter_by_type(elements, type_):
results = []
for elem in elements:
if isinstance(elem, type_):
results.append(elem)
elif isinstance(elem, list):
results += _filter_by_type(elem, type_)
return results
def _combine_schemas(elements):
    """Add together (``+``) the schemas contributed by ``elements``.

    A Node contributes its ``output_schema``, a ColumnSelector contributes a Schema
    built from its names, and a list is combined recursively. Any other element is
    ignored.
    """
    result = Schema()
    for item in elements:
        if isinstance(item, Node):
            addition = item.output_schema
        elif isinstance(item, ColumnSelector):
            addition = Schema(item.names)
        elif isinstance(item, list):
            addition = _combine_schemas(item)
        else:
            continue
        result += addition
    return result
def _combine_selectors(elements):
    """Merge Nodes, ColumnSelectors, column names and nested lists into one ColumnSelector.

    Nested lists become subgroups of the result. Element types other than Node,
    ColumnSelector, str and list are ignored.
    """

    def _selector_for_node(node):
        # Prefer the node's selector, then its computed output schema, then its
        # input schema mapped through the op; fall back to an empty selector.
        if node.selector:
            return node.op.output_column_names(node.selector)
        if node.output_schema:
            return ColumnSelector(node.output_schema.column_names)
        if node.input_schema:
            input_selector = ColumnSelector(node.input_schema.column_names)
            return node.op.output_column_names(input_selector)
        return ColumnSelector()

    merged = ColumnSelector()
    for item in elements:
        if isinstance(item, Node):
            merged += _selector_for_node(item)
        elif isinstance(item, ColumnSelector):
            merged += item
        elif isinstance(item, str):
            merged += ColumnSelector(item)
        elif isinstance(item, list):
            merged += ColumnSelector(subgroups=_combine_selectors(item))
    return merged
def _to_selector(value):
    """Return ``value`` unchanged if it is a ColumnSelector or Node, else wrap it."""
    if isinstance(value, (ColumnSelector, Node)):
        return value
    return ColumnSelector(value)
def _strs_to_selectors(elements):
    """Return a new list with ``_to_selector`` applied to each element."""
    return list(map(_to_selector, elements))
def _to_graphviz(output_node):
    """Converts a Node to a GraphViz DiGraph object useful for display in notebooks.

    Every node reachable from ``output_node`` becomes a graph node labelled with its
    ``label``, with an edge from each of its parents/dependencies. Nodes whose
    selector has names get an extra node listing those names. A final
    "output cols" node hangs off ``output_node``.

    ``graphviz`` is imported lazily, only when this function is called.
    """
    from graphviz import Digraph

    graph = Digraph()

    # Get all the nodes from parents of this node, de-duplicated in traversal
    # order. Iterating a set of nodes would follow their identity hashes, making node
    # ids (and graph insertion order) vary from run to run.
    all_nodes = list(dict.fromkeys(iter_nodes([output_node])))
    node_ids = {node: str(index) for index, node in enumerate(all_nodes)}

    for node, node_id in node_ids.items():
        graph.node(node_id, node.label)
        for parent in node.parents_with_dependencies:
            graph.edge(node_ids[parent], node_id)

        if node.selector and node.selector.names:
            selector_id = f"{node_id}_selector"
            graph.node(selector_id, str(node.selector.names))
            graph.edge(selector_id, node_id)

    # add a single node representing the final state
    final_node_id = str(len(all_nodes))
    final_string = "output cols"
    if output_node._cols_repr:
        final_string += f"=[{output_node._cols_repr}]"
    graph.node(final_node_id, final_string)
    graph.edge(node_ids[output_node], final_node_id)

    return graph
def _convert_col(col):
if isinstance(col, (str, tuple)):
return col
elif isinstance(col, list):
return tuple(col)
else:
raise ValueError(f"Invalid column value for Node: {col}")
def _derived_output_cols(input_cols, column_mapping):
outputs = []
for input_col in set(input_cols):
for output_col_name, input_col_list in column_mapping.items():
if input_col in input_col_list:
outputs.append(output_col_name)
return outputs