#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections.abc
import os
from typing import List, Union
from merlin.dag.operator import Operator
from merlin.dag.ops import ConcatColumns, GroupingOp, SelectionOp, SubsetColumns, SubtractionOp
from merlin.dag.ops.udf import UDF
from merlin.dag.selector import ColumnSelector
from merlin.schema import Schema
# A single node-like value accepted by ``Node.construct_from``.
_NodableElement = Union["Node", Operator, str, List[str], ColumnSelector]

# Anything the graph-building API accepts in place of a Node: one node-like
# value, or a list of them (which ``Node.construct_from`` turns into one Node).
Nodable = Union[
    _NodableElement,
    List[_NodableElement],
]
class Node:
    """A Node is a group of columns that you want to apply the same transformations to.

    Nodes can be transformed by shifting operators onto them (``node >> op``), which
    returns a new Node with the transformations applied. This lets you define a graph
    of operations that makes up your workflow.

    Parameters
    ----------
    selector: ColumnSelector or list of str, optional
        Defines which columns to select from the input Dataset using column names and
        tags. A list is converted to a ``ColumnSelector``; when a selector is given,
        the node's operator is a ``SelectionOp`` over it.

    Raises
    ------
    TypeError
        If ``selector`` is truthy but is neither a list nor a ``ColumnSelector``.
    """

    def __init__(self, selector=None):
        # Graph edges: ``parents`` and ``dependencies`` feed this node; ``children``
        # are downstream nodes registered via add_parent/add_dependency/add_child.
        self.parents = []
        self.children = []
        self.dependencies = []
        self.op = None
        # Filled in by ``compute_schemas``.
        self.input_schema = None
        self.output_schema = None

        if isinstance(selector, list):
            selector = ColumnSelector(selector)

        if selector and not isinstance(selector, ColumnSelector):
            raise TypeError("The selector argument must be a list or a ColumnSelector")

        if selector is not None:
            self.op = SelectionOp(selector)

        self.selector = selector

    @property
    def selector(self):
        """The ``ColumnSelector`` attached to this node, or ``None``."""
        return self._selector

    @selector.setter
    def selector(self, sel):
        # Plain lists of column names are accepted and wrapped for convenience.
        if isinstance(sel, list):
            sel = ColumnSelector(sel)
        self._selector = sel

    # These methods must maintain grouping

    def add_dependency(
        self,
        dep: Nodable,
    ):
        """
        Add a dependency node to this node.

        ``dep`` is converted with ``Node.construct_from`` and this node is
        registered as a child of the resulting node(s).

        Parameters
        ----------
        dep : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Dependency to be added
        """
        dep_node = Node.construct_from(dep)

        if not isinstance(dep_node, list):
            dep_nodes = [dep_node]
        else:
            dep_nodes = dep_node

        for node in dep_nodes:
            node.children.append(self)

        # Stored exactly as returned by construct_from so any grouping is kept.
        self.dependencies.append(dep_node)

    def add_parent(
        self,
        parent: Nodable,
    ):
        """
        Add a parent node to this node.

        ``parent`` is converted with ``Node.construct_from``; this node is
        registered as a child of every resulting node.

        Parameters
        ----------
        parent : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Parent to be added
        """
        parent_nodes = Node.construct_from(parent)

        if not isinstance(parent_nodes, list):
            parent_nodes = [parent_nodes]

        for parent_node in parent_nodes:
            parent_node.children.append(self)

        self.parents.extend(parent_nodes)

    def add_child(
        self,
        child: Nodable,
    ):
        """
        Add a child node to this node.

        ``child`` is converted with ``Node.construct_from``; this node is
        registered as a parent of every resulting node.

        Parameters
        ----------
        child : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Child to be added
        """
        child_nodes = Node.construct_from(child)

        if not isinstance(child_nodes, list):
            child_nodes = [child_nodes]

        for child_node in child_nodes:
            child_node.parents.append(self)

        self.children.extend(child_nodes)

    def remove_child(
        self,
        child: Nodable,
    ):
        """
        Remove a child node from this node.

        Both directions of the edge are removed when present; missing edges are
        ignored.

        Parameters
        ----------
        child : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Child to be removed
        """
        child_nodes = Node.construct_from(child)

        if not isinstance(child_nodes, list):
            child_nodes = [child_nodes]

        for child_node in child_nodes:
            if self in child_node.parents:
                child_node.parents.remove(self)
            if child_node in self.children:
                self.children.remove(child_node)

    def compute_schemas(self, root_schema: Schema, preserve_dtypes: bool = False):
        """
        Compute and store this node's input schema, selector and output schema.

        Parents and dependencies must already have their output schemas computed.
        Sets ``self.input_schema``, ``self.selector`` and ``self.output_schema``.

        Parameters
        ----------
        root_schema : Schema
            Schema of the input dataset
        preserve_dtypes : bool, optional
            `True` if we don't want to override dtypes in the current schema, by default False
        """
        parents_schema = _combine_schemas(self.parents)
        deps_schema = _combine_schemas(self.dependencies)
        parents_selector = _combine_selectors(self.parents)
        dependencies_selector = _combine_selectors(self.dependencies)

        # If parent is an addition or selection node, we may need to
        # propagate grouping unless this node already has a selector
        if len(self.parents) == 1 and isinstance(self.parents[0].op, (ConcatColumns, SelectionOp)):
            parents_selector = self.parents[0].selector
            if not self.selector and self.parents[0].selector and (self.parents[0].selector.names):
                self.selector = parents_selector

        self.input_schema = self.op.compute_input_schema(
            root_schema, parents_schema, deps_schema, self.selector
        )

        self.selector = self.op.compute_selector(
            self.input_schema, self.selector, parents_selector, dependencies_selector
        )

        # Passing the previous output schema lets the operator keep its dtypes.
        prev_output_schema = self.output_schema if preserve_dtypes else None

        self.output_schema = self.op.compute_output_schema(
            self.input_schema, self.selector, prev_output_schema
        )

    def validate_schemas(self, root_schema: Schema, strict_dtypes: bool = False):
        """
        Check if this Node's input schema matches the output schemas of parents and dependencies

        Parameters
        ----------
        root_schema : Schema
            Schema of the input dataset
        strict_dtypes : bool, optional
            If an error should be raised when column dtypes don't match, by default False

        Raises
        ------
        ValueError
            If parents and dependencies don't provide an expected column based on
            the input schema
        ValueError
            If the dtype of a column from parents and dependencies doesn't match
            the expected dtype based on the input schema
        """
        parents_schema = _combine_schemas(self.parents)
        deps_schema = _combine_schemas(self.dependencies)
        ancestors_schema = root_schema + parents_schema + deps_schema

        for col_name, col_schema in self.input_schema.column_schemas.items():
            source_col_schema = ancestors_schema.get(col_name)

            if not source_col_schema:
                raise ValueError(
                    f"Missing column '{col_name}' at the input to '{self.op.__class__.__name__}'."
                )

            # Operators with dynamic dtypes are only dtype-checked in strict mode.
            if strict_dtypes or not self.op.dynamic_dtypes:
                if source_col_schema.dtype.without_shape != col_schema.dtype.without_shape:
                    raise ValueError(
                        f"Mismatched dtypes for column '{col_name}' provided to "
                        f"'{self.op.__class__.__name__}': "
                        f"ancestor nodes provided dtype '{source_col_schema.dtype}', "
                        f"expected dtype '{col_schema.dtype}'."
                    )

        self.op.validate_schemas(
            parents_schema,
            deps_schema,
            self.input_schema,
            self.output_schema,
            strict_dtypes,
        )

    def __rshift__(self, operator):
        """Transforms this Node by applying an Operator

        Parameters
        -----------
        operators: Operator
            An Operator instance, an Operator subclass (instantiated with no
            arguments), or any other callable (wrapped in a ``UDF``).

        Returns
        -------
        Node
            A new node whose parent is this node.

        Raises
        ------
        ValueError
            If ``operator`` cannot be turned into an Operator.
        """
        if callable(operator) and not (
            isinstance(operator, type) and issubclass(operator, Operator)
        ):
            # implicit lambdaop conversion.
            operator = UDF(operator)

        if isinstance(operator, type) and issubclass(operator, Operator):
            # handle case where an operator class is passed
            operator = operator()

        if not isinstance(operator, Operator):
            raise ValueError(f"Expected operator or callable, got {operator.__class__}")

        child = type(self)()
        child.op = operator
        child.add_parent(self)

        dependencies = operator.dependencies

        if dependencies:
            # A bare string is itself a Sequence; iterating it would split one
            # column name into single-character dependencies, so wrap it too.
            if isinstance(dependencies, str) or not isinstance(
                dependencies, collections.abc.Sequence
            ):
                dependencies = [dependencies]

            for dependency in dependencies:
                child.add_dependency(dependency)

        return child

    def __add__(self, other):
        """Adds columns from this Node with another to return a new Node

        If this node is already a ``+`` (``ConcatColumns``) node it is extended in
        place and returned; otherwise a new ``+`` node with this node as parent is
        created. A falsy ``other`` returns this node unchanged.

        Parameters
        -----------
        other: Node or str or list of str

        Returns
        -------
        Node
        """
        if not other:
            return self

        if isinstance(self.op, ConcatColumns):
            child = self
        else:
            # Create a child node
            child = type(self)()
            child.op = ConcatColumns(label="+")
            child.add_parent(self)

        # The right operand becomes a dependency
        other_nodes = Node.construct_from(other)
        other_nodes = [other_nodes]

        for other_node in other_nodes:
            # If the other node is a `[]`, we want to maintain grouping
            # so create a selection node that we can use to do that
            if isinstance(other_node, list):
                grouped_node = Node.construct_from(GroupingOp())
                for node in other_node:
                    grouped_node.add_parent(node)
                child.add_dependency(grouped_node)
            # If the other node is a `+` node, we want to collapse it into this `+` node to
            # avoid creating a cascade of repeated `+`s that we'd need to optimize out by
            # re-combining them later in order to clean up the graph
            elif isinstance(other_node.op, ConcatColumns):
                child.dependencies += other_node.grouped_parents_with_dependencies
            else:
                child.add_dependency(other_node)

        return child

    def __radd__(self, other):
        """Support ``other + node`` by converting ``other`` to a Node first."""
        other_node = Node.construct_from(other)
        return other_node.__add__(self)

    def __sub__(self, other):
        """Removes columns from this Node with another to return a new Node

        Parameters
        -----------
        other: Node or str or list of str
            Columns to remove

        Returns
        -------
        Node
        """
        other_nodes = Node.construct_from(other)

        if not isinstance(other_nodes, list):
            other_nodes = [other_nodes]

        child = type(self)()
        child.add_parent(self)
        child.op = SubtractionOp()

        for other_node in other_nodes:
            # Plain (parentless) selections are folded into the selectors;
            # anything else becomes a dependency.
            if isinstance(other_node.op, SelectionOp) and not other_node.parents_with_dependencies:
                child.selector += other_node.selector
                # Add only this operand's columns: ``child.selector`` is cumulative,
                # so re-adding it would repeat earlier operands' columns.
                child.op.selector += other_node.selector
            else:
                child.add_dependency(other_node)

        return child

    def __rsub__(self, other):
        """Support ``other - node``: remove this node's columns from ``other``."""
        left_operand = Node.construct_from(other)
        right_operand = self

        if not isinstance(left_operand, list):
            left_operand = [left_operand]

        child = type(self)()
        child.add_parent(left_operand)
        child.op = SubtractionOp()

        if (
            isinstance(right_operand.op, SelectionOp)
            and not right_operand.parents_with_dependencies
        ):
            child.selector += right_operand.selector
            child.op.selector += child.selector
        else:
            child.add_dependency(right_operand)

        return child

    def __getitem__(self, columns):
        """Selects certain columns from this Node, and returns a new Node with only
        those columns

        Parameters
        -----------
        columns: str or list of str
            Columns to select

        Returns
        -------
        Node
        """
        col_selector = ColumnSelector(columns)
        child = type(self)(col_selector)
        columns = [columns] if not isinstance(columns, list) else columns
        # Replaces the SelectionOp installed by the constructor.
        child.op = SubsetColumns(label=str(list(columns)))
        child.add_parent(self)
        return child

    def __repr__(self):
        # Nodes without children are marked as graph outputs.
        output = " output" if not self.children else ""
        return f"<Node {self.label}{output}>"

    def exportable(self, backend: Union[str, None] = None):
        """Return True if this node's operator can be exported for ``backend``.

        The operator must define ``export`` and list ``backend`` in its
        ``exportable_backends`` attribute (treated as empty when absent).
        """
        backends = getattr(self.op, "exportable_backends", [])
        return hasattr(self.op, "export") and backend in backends

    def export(
        self,
        output_path: Union[str, os.PathLike],
        node_id: Union[int, None] = None,
        version: int = 1,
    ):
        """
        Export a directory for this node, containing the required artifacts
        to run in the target context.

        Parameters
        ----------
        output_path : Union[str, os.PathLike]
            The base path to write this node's export directory.
        node_id : int, optional
            The id of this node in a larger graph (for disambiguation), by default None.
        version : int, optional
            The version of the node to use for this export, by default 1.

        Returns
        -------
        Whatever this node's operator ``export`` returns.
        """
        return self.op.export(
            output_path,
            self.input_schema,
            self.output_schema,
            node_id=node_id,
            version=version,
        )

    @property
    def export_name(self):
        """
        Name for the exported node directory.

        Returns
        -------
        str
            Name supplied by this node's operator.
        """
        return self.op.export_name

    @property
    def parents_with_dependencies(self):
        """Parents followed by dependencies, with grouped (list) entries flattened."""
        nodes = []
        for node in self.parents + self.dependencies:
            if isinstance(node, list):
                nodes.extend(node)
            else:
                nodes.append(node)
        return nodes

    @property
    def grouped_parents_with_dependencies(self):
        """Parents followed by dependencies, keeping any list groupings intact."""
        return self.parents + self.dependencies

    @property
    def input_columns(self):
        """Selector for the columns this node reads.

        Returns the node's own selector (preserving its groupings) when neither it
        nor its subgroups use tags; otherwise a selector over the input schema's
        column names.

        Raises
        ------
        RuntimeError
            If the input schema has not been computed yet.
        """
        if self.input_schema is None:
            raise RuntimeError(
                "The input columns aren't computed until the workflow "
                "is fit to a dataset or input schema."
            )

        if (
            self.selector
            and not self.selector.tags
            and all(not selector.tags for selector in self.selector.subgroups)
        ):
            # To maintain column groupings
            return self.selector
        else:
            return ColumnSelector(self.input_schema.column_names)

    @property
    def output_columns(self):
        """Selector over the output schema's column names.

        Raises
        ------
        RuntimeError
            If the output schema has not been computed yet.
        """
        if self.output_schema is None:
            raise RuntimeError(
                "The output columns aren't computed until the workflow "
                "is fit to a dataset or input schema."
            )

        return ColumnSelector(self.output_schema.column_names)

    @property
    def column_mapping(self):
        """The operator's column mapping for this node's selector (or input columns)."""
        selector = self.selector or ColumnSelector(self.input_schema.column_names)
        return self.op.column_mapping(selector)

    @property
    def dependency_columns(self):
        """Selector over the combined output column names of all dependencies."""
        return ColumnSelector(_combine_schemas(self.dependencies).column_names)

    @property
    def label(self):
        """Human-readable label used in ``repr`` and graph rendering."""
        if self.op and hasattr(self.op, "label"):
            return self.op.label
        elif self.op:
            return str(type(self.op))
        elif not self.parents:
            return f"input cols=[{self._cols_repr}]"
        else:
            return "??"

    @property
    def _cols_repr(self):
        # Up to three column names, with "..." appended when there are more.
        if self.input_schema:
            columns = self.input_schema.column_names
        elif self.selector:
            columns = self.selector.names
        else:
            columns = []

        cols_repr = ", ".join(map(str, columns[:3]))
        if len(columns) > 3:
            cols_repr += "..."

        return cols_repr

    @property
    def graph(self):
        """A graphviz ``Digraph`` of this node and its ancestors."""
        return _to_graphviz(self)

    @classmethod
    def construct_from(
        cls,
        nodable: Nodable,
    ) -> Union["Node", List["Node"]]:
        """
        Convert Node-like objects to a Node or list of Nodes.

        Parameters
        ----------
        nodable : Nodable
            Node-like objects to convert to a Node or list of Nodes.

        Returns
        -------
        Union["Node", List["Node"]]
            New Node(s) corresponding to the Node-like input objects. A list of
            strings becomes one selection Node; any other list becomes a single
            ``GroupingOp`` node whose parents are the converted elements (all
            selector-bearing elements merged into one selection node).

        Raises
        ------
        TypeError
            If supplied input cannot be converted to a Node or list of Nodes
        """
        if isinstance(nodable, str):
            return Node(ColumnSelector([nodable]))
        if isinstance(nodable, ColumnSelector):
            return Node(nodable)
        elif isinstance(nodable, Operator):
            node = Node()
            node.op = nodable
            return node
        elif isinstance(nodable, Node):
            return nodable
        elif isinstance(nodable, list):
            if all(isinstance(elem, str) for elem in nodable):
                return Node(nodable)
            else:
                nodes = [Node.construct_from(node) for node in nodable]
                non_selection_nodes = [
                    node for node in nodes if not (hasattr(node, "selector") and node.selector)
                ]
                selection_nodes = [
                    node.selector for node in nodes if (hasattr(node, "selector") and node.selector)
                ]
                selection_nodes = (
                    [Node(_combine_selectors(selection_nodes))] if selection_nodes else []
                )

                group_node = Node.construct_from(GroupingOp())
                all_node = non_selection_nodes + selection_nodes
                for node in all_node:
                    group_node.add_parent(node)
                return group_node
        else:
            raise TypeError(
                "Unsupported type: Cannot convert object " f"of type {type(nodable)} to Node."
            )
def iter_nodes(nodes, flatten_subgraphs=False):
    """Breadth-first walk from ``nodes`` up through parents and dependencies.

    Parameters
    ----------
    nodes : list
        Starting nodes. Entries that are lists are expanded in place.
    flatten_subgraphs : bool, optional
        When True, the nodes of a subgraph operator's inner graph (reached from
        ``op.graph.output_node``) are queued as well. The inner graph is walked
        without ``flatten_subgraphs``, so nested subgraphs are not expanded.

    Yields
    ------
    Node
        Each visited node. A node is only skipped when it is already pending in
        the queue, so a node can be yielded more than once if it is reached again
        after it has been visited.
    """
    queue = nodes[:]
    while queue:
        current = queue.pop(0)

        # Grouped entries have no ``op``; expand them before touching node attributes.
        if isinstance(current, list):
            queue.extend(current)
            continue

        # ``op`` may be None on a bare Node; treat that as "not a subgraph".
        if flatten_subgraphs and getattr(current.op, "is_subgraph", False):
            for node in iter_nodes([current.op.graph.output_node]):
                if node not in queue:
                    queue.append(node)

        yield current
        for node in current.parents_with_dependencies:
            if node not in queue:
                queue.append(node)
# output node (bottom) -> selection leaf nodes (top)
def preorder_iter_nodes(nodes, flatten_subgraphs=False):
    """Yield nodes starting at the output node(s) and moving up toward the inputs.

    Nodes are collected level by level: every node of a level is placed, then
    each node's parents and dependencies are visited in turn. A node reached
    again is moved from its earlier position to the end.

    Parameters
    ----------
    nodes : Node or list of Node
        Starting node(s); a single node is wrapped in a list.
    flatten_subgraphs : bool, optional
        When True, the preorder of a subgraph operator's inner graph is inserted
        just before the subgraph node. The inner walk does not pass
        ``flatten_subgraphs`` on, so nested subgraphs are not expanded.
    """
    ordered = []
    start = nodes if isinstance(nodes, list) else [nodes]

    def visit(level):
        for current in level:
            # Keep only the most recent placement of a node.
            if current in ordered:
                ordered.remove(current)
            if flatten_subgraphs and current.op.is_subgraph:
                inner = list(preorder_iter_nodes(current.op.graph.output_node))
                ordered.extend(inner)
            ordered.append(current)
        for current in level:
            visit(current.parents_with_dependencies)

    visit(start)
    yield from ordered
# selection leaf nodes (top) -> output node (bottom)
def postorder_iter_nodes(nodes, flatten_subgraphs=False):
    """Yield nodes from the inputs down to the output node(s).

    Each node's parents and dependencies are visited before the node itself,
    and a node is recorded only the first time it is reached.

    Parameters
    ----------
    nodes : Node or list of Node
        Starting node(s); a single node is wrapped in a list.
    flatten_subgraphs : bool, optional
        When True, the postorder of a subgraph operator's inner graph is appended
        right after the subgraph node the first time it is recorded. The inner
        walk does not pass ``flatten_subgraphs`` on, so nested subgraphs are not
        expanded.
    """
    ordered = []
    start = nodes if isinstance(nodes, list) else [nodes]

    def visit(level):
        for current in level:
            # Ancestors first.
            visit(current.parents_with_dependencies)
            if current in ordered:
                continue
            ordered.append(current)
            if flatten_subgraphs and current.op.is_subgraph:
                inner = list(postorder_iter_nodes(current.op.graph.output_node))
                ordered.extend(inner)

    visit(start)
    yield from ordered
def _filter_by_type(elements, type_):
results = []
for elem in elements:
if isinstance(elem, type_):
results.append(elem)
elif isinstance(elem, list):
results += _filter_by_type(elem, type_)
return results
def _combine_schemas(elements, input_schemas=False):
    """Sum the schemas of a possibly nested collection of graph elements.

    Nodes contribute their output schema (or input schema when ``input_schemas``
    is true), ColumnSelectors contribute a Schema built from their names, and
    nested lists are combined recursively. Other elements are ignored.
    """
    schema_attr = "input_schema" if input_schemas else "output_schema"
    total = Schema()
    for element in elements:
        if isinstance(element, Node):
            total += getattr(element, schema_attr)
        elif isinstance(element, ColumnSelector):
            total += Schema(element.names)
        elif isinstance(element, list):
            total += _combine_schemas(element, input_schemas=input_schemas)
    return total
def _combine_selectors(elements):
    """Merge Nodes, ColumnSelectors and column-name strings into one ColumnSelector.

    For a Node, the contributed selector is chosen in this order: its own
    selector for GroupingOp nodes; the operator's output names for its selector;
    its output schema's columns; the operator's output names for its input
    schema's columns; otherwise an empty selector. Unrecognized elements
    (including nested lists) are skipped.
    """

    def node_selector(node):
        if isinstance(node.op, GroupingOp):
            return node.selector
        if node.selector:
            return node.op.output_column_names(node.selector)
        if node.output_schema:
            return ColumnSelector(node.output_schema.column_names)
        if node.input_schema:
            from_inputs = ColumnSelector(node.input_schema.column_names)
            return node.op.output_column_names(from_inputs)
        return ColumnSelector()

    merged = ColumnSelector()
    for element in elements:
        if isinstance(element, Node):
            merged += node_selector(element)
        elif isinstance(element, ColumnSelector):
            merged += element
        elif isinstance(element, str):
            merged += ColumnSelector(element)
    return merged
def _to_selector(value):
    """Wrap ``value`` in a ColumnSelector unless it already is one or is a Node."""
    if isinstance(value, (ColumnSelector, Node)):
        return value
    return ColumnSelector(value)
def _strs_to_selectors(elements):
    """Apply ``_to_selector`` to each element, returning a new list."""
    return list(map(_to_selector, elements))
def _to_graphviz(output_node):
    """Converts a Node to a GraphViz DiGraph object useful for display in notebooks

    Every ancestor of ``output_node`` becomes a graph node labelled with its
    ``label``, with edges from each parent/dependency. Nodes with named selectors
    get an extra selector box, and a final "output cols" node is appended.
    Requires the ``graphviz`` package (imported lazily).
    """
    from graphviz import Digraph

    graph = Digraph()

    # Get all the nodes from parents of this node and add edges between them.
    # dict.fromkeys removes duplicates (iter_nodes may revisit a node) while
    # keeping traversal order, so node ids are stable from run to run; a set of
    # identity-hashed nodes would order them by memory address.
    all_nodes = list(dict.fromkeys(iter_nodes([output_node])))
    node_ids = {node: str(index) for index, node in enumerate(all_nodes)}

    for node, node_id in node_ids.items():
        graph.node(node_id, node.label)
        for parent in node.parents_with_dependencies:
            graph.edge(node_ids[parent], node_id)

        if node.selector and node.selector.names:
            selector_id = f"{node_id}_selector"
            graph.node(selector_id, str(node.selector.names))
            graph.edge(selector_id, node_id)

    # add a single node representing the final state
    final_node_id = str(len(all_nodes))
    final_string = "output cols"
    if output_node._cols_repr:
        final_string += f"=[{output_node._cols_repr}]"
    graph.node(final_node_id, final_string)
    graph.edge(node_ids[output_node], final_node_id)
    return graph
def _convert_col(col):
if isinstance(col, (str, tuple)):
return col
elif isinstance(col, list):
return tuple(col)
else:
raise ValueError(f"Invalid column value for Node: {col}")
def _derived_output_cols(input_cols, column_mapping):
outputs = []
for input_col in set(input_cols):
for output_col_name, input_col_list in column_mapping.items():
if input_col in input_col_list:
outputs.append(output_col_name)
return outputs