# Source code for merlin.dag.node

# Copyright (c) 2022, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Union

from merlin.dag.operator import Operator
from merlin.dag.ops import ConcatColumns, GroupingOp, SelectionOp, SubsetColumns, SubtractionOp
from merlin.dag.ops.udf import UDF
from merlin.dag.selector import ColumnSelector
from merlin.schema import Schema

# Anything that Node.construct_from can convert into a Node: a column name,
# a list of column names, a ColumnSelector, an Operator, an existing Node,
# or a list mixing any of these.
Nodable = Union[
    "Node",
    str,
    List[str],
    ColumnSelector,
    Operator,
    List[Union["Node", Operator, str, List[str], ColumnSelector]],
]

class Node:
    """A Node is a group of columns that you want to apply the same transformations to.

    Nodes can be transformed by shifting operators on to them, which returns a new
    Node with the transformations applied. This lets you define a graph of operations
    that makes up your workflow.

    Parameters
    ----------
    selector: ColumnSelector
        Defines which columns to select from the input Dataset using column names
        and tags. A plain list of column names is also accepted and is converted
        to a ColumnSelector.
    """

    def __init__(self, selector=None):
        self.parents = []
        self.children = []
        self.dependencies = []
        self.op = None
        self.input_schema = None
        self.output_schema = None

        if isinstance(selector, list):
            selector = ColumnSelector(selector)

        if selector and not isinstance(selector, ColumnSelector):
            raise TypeError("The selector argument must be a list or a ColumnSelector")

        # A node built from a selector starts out as a selection node
        if selector is not None:
            self.op = SelectionOp(selector)

        self.selector = selector

    @property
    def selector(self):
        """The ColumnSelector for this node, or None when no selector was given."""
        return self._selector

    @selector.setter
    def selector(self, sel):
        # Normalize plain lists of column names into a ColumnSelector
        if isinstance(sel, list):
            sel = ColumnSelector(sel)
        self._selector = sel

    # These methods must maintain grouping
    def add_dependency(
        self,
        dep: Nodable,
    ):
        """
        Adding a dependency node to this node

        Parameters
        ----------
        dep : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Dependency to be added
        """
        dep_node = Node.construct_from(dep)

        if not isinstance(dep_node, list):
            dep_nodes = [dep_node]
        else:
            dep_nodes = dep_node

        for node in dep_nodes:
            node.children.append(self)

        # The converted dependency is stored as a single entry, so a list
        # (if one were returned) stays grouped together
        self.dependencies.append(dep_node)

    def add_parent(
        self,
        parent: Nodable,
    ):
        """
        Adding a parent node to this node

        Parameters
        ----------
        parent : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Parent to be added
        """
        parent_nodes = Node.construct_from(parent)

        if not isinstance(parent_nodes, list):
            parent_nodes = [parent_nodes]

        for parent_node in parent_nodes:
            parent_node.children.append(self)

        self.parents.extend(parent_nodes)

    def add_child(
        self,
        child: Nodable,
    ):
        """
        Adding a child node to this node

        Parameters
        ----------
        child : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Child to be added
        """
        child_nodes = Node.construct_from(child)

        if not isinstance(child_nodes, list):
            child_nodes = [child_nodes]

        for child_node in child_nodes:
            child_node.parents.append(self)

        self.children.extend(child_nodes)

    def remove_child(
        self,
        child: Nodable,
    ):
        """
        Removing a child node from this node

        Parameters
        ----------
        child : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Child to be removed
        """
        child_nodes = Node.construct_from(child)

        if not isinstance(child_nodes, list):
            child_nodes = [child_nodes]

        # Unlink in both directions, tolerating links that are already absent
        for child_node in child_nodes:
            if self in child_node.parents:
                child_node.parents.remove(self)
            if child_node in self.children:
                self.children.remove(child_node)

    def compute_schemas(self, root_schema: Schema, preserve_dtypes: bool = False):
        """
        Defines the input and output schema

        Parameters
        ----------
        root_schema : Schema
            Schema of the input dataset
        preserve_dtypes : bool, optional
            `True` if we don't want to override dtypes in the current schema, by default False
        """
        parents_schema = _combine_schemas(self.parents)
        deps_schema = _combine_schemas(self.dependencies)
        parents_selector = _combine_selectors(self.parents)
        dependencies_selector = _combine_selectors(self.dependencies)

        # If parent is an addition or selection node, we may need to
        # propagate grouping unless this node already has a selector
        if len(self.parents) == 1 and isinstance(self.parents[0].op, (ConcatColumns, SelectionOp)):
            parents_selector = self.parents[0].selector

            if not self.selector and self.parents[0].selector and (self.parents[0].selector.names):
                self.selector = parents_selector

        self.input_schema = self.op.compute_input_schema(
            root_schema, parents_schema, deps_schema, self.selector
        )

        self.selector = self.op.compute_selector(
            self.input_schema, self.selector, parents_selector, dependencies_selector
        )

        # Only hand the previous output schema to the operator when dtypes should be kept
        prev_output_schema = self.output_schema if preserve_dtypes else None

        self.output_schema = self.op.compute_output_schema(
            self.input_schema, self.selector, prev_output_schema
        )

    def validate_schemas(self, root_schema: Schema, strict_dtypes: bool = False):
        """
        Check if this Node's input schema matches the output schemas of parents and dependencies

        Parameters
        ----------
        root_schema : Schema
            Schema of the input dataset
        strict_dtypes : bool, optional
            If an error should be raised when column dtypes don't match, by default False

        Raises
        ------
        ValueError
            If parents and dependencies don't provide an expected column based on the input schema
        ValueError
            If the dtype of a column from parents and dependencies doesn't match the expected dtype
            based on the input schema
        """
        parents_schema = _combine_schemas(self.parents)
        deps_schema = _combine_schemas(self.dependencies)
        ancestors_schema = root_schema + parents_schema + deps_schema

        for col_name, col_schema in self.input_schema.column_schemas.items():
            source_col_schema = ancestors_schema.get(col_name)

            if not source_col_schema:
                raise ValueError(
                    f"Missing column '{col_name}' at the input to '{self.op.__class__.__name__}'."
                )

            # Operators with dynamic dtypes are only dtype-checked in strict mode
            if strict_dtypes or not self.op.dynamic_dtypes:
                if source_col_schema.dtype.without_shape != col_schema.dtype.without_shape:
                    raise ValueError(
                        f"Mismatched dtypes for column '{col_name}' provided to "
                        f"'{self.op.__class__.__name__}': "
                        f"ancestor nodes provided dtype '{source_col_schema.dtype}', "
                        f"expected dtype '{col_schema.dtype}'."
                    )

        # Let the operator run its own checks as well
        self.op.validate_schemas(
            parents_schema,
            deps_schema,
            self.input_schema,
            self.output_schema,
            strict_dtypes,
        )

    def __rshift__(self, operator):
        """Transforms this Node by applying an Operator

        Parameters
        -----------
        operator: Operator, Operator subclass, or callable
            A plain callable is wrapped in a UDF operator; an Operator class is
            instantiated with no arguments.

        Returns
        -------
        Node

        Raises
        ------
        ValueError
            If `operator` is neither an Operator nor a callable
        """
        if callable(operator) and not (
            isinstance(operator, type) and issubclass(operator, Operator)
        ):
            # implicit lambdaop conversion.
            operator = UDF(operator)

        if isinstance(operator, type) and issubclass(operator, Operator):
            # handle case where an operator class is passed
            operator = operator()

        if not isinstance(operator, Operator):
            raise ValueError(f"Expected operator or callable, got {operator.__class__}")

        child = type(self)()
        child.op = operator
        child.add_parent(self)

        dependencies = operator.dependencies

        if dependencies:
            # A single dependency may be given without a wrapping list
            if not isinstance(dependencies, list):
                dependencies = [dependencies]

            for dependency in dependencies:
                child.add_dependency(dependency)

        return child

    def __add__(self, other):
        """Adds columns from this Node with another to return a new Node

        Parameters
        -----------
        other: Node or str or list of str

        Returns
        -------
        Node
        """
        if not other:
            return self

        # Chained `+` operations reuse (and extend) an existing concat node
        if isinstance(self.op, ConcatColumns):
            child = self
        else:
            # Create a child node
            child = type(self)()
            child.op = ConcatColumns(label="+")
            child.add_parent(self)

        # The right operand becomes a dependency
        other_nodes = Node.construct_from(other)
        other_nodes = [other_nodes]

        for other_node in other_nodes:
            # If the other node is a `[]`, we want to maintain grouping
            # so create a selection node that we can use to do that
            if isinstance(other_node, list):
                grouped_node = Node.construct_from(GroupingOp())
                for node in other_node:
                    grouped_node.add_parent(node)
                child.add_dependency(grouped_node)
            # If the other node is a `+` node, we want to collapse it into this `+` node to
            # avoid creating a cascade of repeated `+`s that we'd need to optimize out by
            # re-combining them later in order to clean up the graph
            elif isinstance(other_node.op, ConcatColumns):
                # NOTE(review): the collapsed nodes' `children` lists are not updated
                # to include `child` here; confirm nothing downstream relies on them.
                child.dependencies += other_node.grouped_parents_with_dependencies
            else:
                child.add_dependency(other_node)

        return child

    def __radd__(self, other):
        other_node = Node.construct_from(other)
        return other_node.__add__(self)

    def __sub__(self, other):
        """Removes columns from this Node with another to return a new Node

        Parameters
        -----------
        other: Node or str or list of str
            Columns to remove

        Returns
        -------
        Node
        """
        other_nodes = Node.construct_from(other)

        if not isinstance(other_nodes, list):
            other_nodes = [other_nodes]

        child = type(self)()
        child.add_parent(self)
        child.op = SubtractionOp()

        for other_node in other_nodes:
            # Pure selections (no upstream graph) are folded into the op's selector;
            # anything else is tracked as a dependency
            if isinstance(other_node.op, SelectionOp) and not other_node.parents_with_dependencies:
                child.selector += other_node.selector
                child.op.selector += child.selector
            else:
                child.add_dependency(other_node)

        return child

    def __rsub__(self, other):
        left_operand = Node.construct_from(other)
        right_operand = self

        if not isinstance(left_operand, list):
            left_operand = [left_operand]

        child = type(self)()
        child.add_parent(left_operand)
        child.op = SubtractionOp()

        if (
            isinstance(right_operand.op, SelectionOp)
            and not right_operand.parents_with_dependencies
        ):
            child.selector += right_operand.selector
            child.op.selector += child.selector
        else:
            child.add_dependency(right_operand)

        return child

    def __getitem__(self, columns):
        """Selects certain columns from this Node, and returns a new Node with only
        those columns

        Parameters
        -----------
        columns: str or list of str
            Columns to select

        Returns
        -------
        Node
        """
        col_selector = ColumnSelector(columns)
        child = type(self)(col_selector)
        columns = [columns] if not isinstance(columns, list) else columns
        child.op = SubsetColumns(label=str(list(columns)))
        child.add_parent(self)
        return child

    def __repr__(self):
        # Nodes without children are the graph's outputs
        output = " output" if not self.children else ""
        return f"<Node {self.label}{output}>"

    def remove_inputs(self, input_cols: List[str]) -> List[str]:
        """
        Remove input columns and all output columns that depend on them.

        Parameters
        ----------
        input_cols : List[str]
            The input columns to remove

        Returns
        -------
        List[str]
            The output columns that were removed
        """
        removed_outputs = _derived_output_cols(input_cols, self.column_mapping)

        self.input_schema = self.input_schema.without(input_cols)
        self.output_schema = self.output_schema.without(removed_outputs)

        if self.selector:
            self.selector = self.selector.filter_columns(ColumnSelector(input_cols))

        return removed_outputs

    def exportable(self, backend: Union[str, None] = None):
        """Return True if this node's operator has an `export` method and lists
        `backend` in its `exportable_backends` attribute."""
        backends = getattr(self.op, "exportable_backends", [])

        return hasattr(self.op, "export") and backend in backends

    def export(
        self,
        output_path: Union[str, os.PathLike],
        node_id: Union[int, None] = None,
        version: int = 1,
    ):
        """
        Export a directory for this node, containing the required artifacts
        to run in the target context.

        Parameters
        ----------
        output_path : Union[str, os.PathLike]
            The base path to write this node's export directory.
        node_id : int, optional
            The id of this node in a larger graph (for disambiguation), by default None.
        version : int, optional
            The version of the node to use for this export, by default 1.
        """
        return self.op.export(
            output_path,
            self.input_schema,
            self.output_schema,
            node_id=node_id,
            version=version,
        )

    @property
    def export_name(self):
        """
        Name for the exported node directory.

        Returns
        -------
        str
            Name supplied by this node's operator.
        """
        return self.op.export_name

    @property
    def parents_with_dependencies(self):
        """Parents followed by dependencies, with any nested lists flattened one level."""
        nodes = []
        for node in self.parents + self.dependencies:
            if isinstance(node, list):
                nodes.extend(node)
            else:
                nodes.append(node)
        return nodes

    @property
    def grouped_parents_with_dependencies(self):
        """Parents followed by dependencies, keeping any nested lists intact."""
        return self.parents + self.dependencies

    @property
    def input_columns(self):
        if self.input_schema is None:
            raise RuntimeError(
                "The input columns aren't computed until the workflow "
                "is fit to a dataset or input schema."
            )

        if (
            self.selector
            and not self.selector.tags
            and all(not selector.tags for selector in self.selector.subgroups)
        ):
            # To maintain column groupings
            return self.selector
        else:
            return ColumnSelector(self.input_schema.column_names)

    @property
    def output_columns(self):
        if self.output_schema is None:
            raise RuntimeError(
                "The output columns aren't computed until the workflow "
                "is fit to a dataset or input schema."
            )

        return ColumnSelector(self.output_schema.column_names)

    @property
    def column_mapping(self):
        # Fall back to all input columns when no selector is set
        selector = self.selector or ColumnSelector(self.input_schema.column_names)
        return self.op.column_mapping(selector)

    @property
    def dependency_columns(self):
        return ColumnSelector(_combine_schemas(self.dependencies).column_names)

    @property
    def label(self):
        if self.op and hasattr(self.op, "label"):
            return self.op.label
        elif self.op:
            return str(type(self.op))
        elif not self.parents:
            return f"input cols=[{self._cols_repr}]"
        else:
            return "??"

    @property
    def _cols_repr(self):
        """Comma-separated preview of at most three column names, with "..." if truncated."""
        if self.input_schema:
            columns = self.input_schema.column_names
        elif self.selector:
            columns = self.selector.names
        else:
            columns = []

        cols_repr = ", ".join(map(str, columns[:3]))
        if len(columns) > 3:
            cols_repr += "..."

        return cols_repr

    @property
    def graph(self):
        return _to_graphviz(self)

    @classmethod
    def construct_from(
        cls,
        nodable: Nodable,
    ):
        """
        Convert Node-like objects to a Node.

        Parameters
        ----------
        nodable : Nodable
            Node-like objects to convert to a Node.

        Returns
        -------
        Node
            A Node corresponding to the Node-like input. A list of names becomes a
            single selection Node; any other list becomes a grouping Node whose
            parents are the converted elements (elements with selectors are merged
            into one selection Node).

        Raises
        ------
        TypeError
            If supplied input cannot be converted to a Node
        """
        if isinstance(nodable, str):
            return Node(ColumnSelector([nodable]))
        elif isinstance(nodable, ColumnSelector):
            return Node(nodable)
        elif isinstance(nodable, Operator):
            node = Node()
            node.op = nodable
            return node
        elif isinstance(nodable, Node):
            return nodable
        elif isinstance(nodable, list):
            if all(isinstance(elem, str) for elem in nodable):
                return Node(nodable)

            nodes = [Node.construct_from(node) for node in nodable]
            non_selection_nodes = [
                node for node in nodes if not (hasattr(node, "selector") and node.selector)
            ]
            selectors = [
                node.selector for node in nodes if (hasattr(node, "selector") and node.selector)
            ]
            # All selector-bearing elements collapse into one selection node
            selection_nodes = [Node(_combine_selectors(selectors))] if selectors else []

            group_node = Node.construct_from(GroupingOp())
            for node in non_selection_nodes + selection_nodes:
                group_node.add_parent(node)
            return group_node
        else:
            raise TypeError(
                "Unsupported type: Cannot convert object " f"of type {type(nodable)} to Node."
            )
def iter_nodes(nodes, flatten_subgraphs=False): queue = nodes[:] while queue: current = queue.pop(0) if flatten_subgraphs and current.op.is_subgraph: new_nodes = iter_nodes([current.op.graph.output_node]) for node in new_nodes: if node not in queue: queue.append(node) if isinstance(current, list): queue.extend(current) else: yield current for node in current.parents_with_dependencies: if node not in queue: queue.append(node) # output node (bottom) -> selection leaf nodes (top) def preorder_iter_nodes(nodes, flatten_subgraphs=False): queue = [] if not isinstance(nodes, list): nodes = [nodes] def traverse(current_nodes): for node in current_nodes: # Avoid creating duplicate nodes in the queue if node in queue: queue.remove(node) if flatten_subgraphs and node.op.is_subgraph: queue.extend(list(preorder_iter_nodes(node.op.graph.output_node))) queue.append(node) for node in current_nodes: traverse(node.parents_with_dependencies) traverse(nodes) for node in queue: yield node # selection leaf nodes (top) -> output node (bottom) def postorder_iter_nodes(nodes, flatten_subgraphs=False): queue = [] if not isinstance(nodes, list): nodes = [nodes] def traverse(current_nodes): for node in current_nodes: traverse(node.parents_with_dependencies) if node not in queue: queue.append(node) if flatten_subgraphs and node.op.is_subgraph: queue.extend(list(postorder_iter_nodes(node.op.graph.output_node))) traverse(nodes) for node in queue: yield node def _filter_by_type(elements, type_): results = [] for elem in elements: if isinstance(elem, type_): results.append(elem) elif isinstance(elem, list): results += _filter_by_type(elem, type_) return results def _combine_schemas(elements, input_schemas=False): combined = Schema() for elem in elements: if isinstance(elem, Node): if input_schemas: combined += elem.input_schema else: combined += elem.output_schema elif isinstance(elem, ColumnSelector): combined += Schema(elem.names) elif isinstance(elem, list): combined += _combine_schemas(elem, 
input_schemas=input_schemas) return combined def _combine_selectors(elements): combined = ColumnSelector() for elem in elements: if isinstance(elem, Node): if isinstance(elem.op, GroupingOp): selector = elem.selector elif elem.selector: selector = elem.op.output_column_names(elem.selector) elif elem.output_schema: selector = ColumnSelector(elem.output_schema.column_names) elif elem.input_schema: selector = ColumnSelector(elem.input_schema.column_names) selector = elem.op.output_column_names(selector) else: selector = ColumnSelector() combined += selector elif isinstance(elem, ColumnSelector): combined += elem elif isinstance(elem, str): combined += ColumnSelector(elem) return combined def _to_selector(value): if not isinstance(value, (ColumnSelector, Node)): return ColumnSelector(value) else: return value def _strs_to_selectors(elements): return [_to_selector(elem) for elem in elements] def _to_graphviz(output_node): """Converts a Node to a GraphViz DiGraph object useful for display in notebooks""" from graphviz import Digraph graph = Digraph() # get all the nodes from parents of this columngroup # and add edges between each of them allnodes = list(set(iter_nodes([output_node]))) node_ids = {v: str(k) for k, v in enumerate(allnodes)} for node, nodeid in node_ids.items(): graph.node(nodeid, node.label) for parent in node.parents_with_dependencies: graph.edge(node_ids[parent], nodeid) if node.selector and node.selector.names: selector_id = f"{nodeid}_selector" graph.node(selector_id, str(node.selector.names)) graph.edge(selector_id, nodeid) # add a single node representing the final state final_node_id = str(len(allnodes)) final_string = "output cols" if output_node._cols_repr: final_string += f"=[{output_node._cols_repr}]" graph.node(final_node_id, final_string) graph.edge(node_ids[output_node], final_node_id) return graph def _convert_col(col): if isinstance(col, (str, tuple)): return col elif isinstance(col, list): return tuple(col) else: raise 
ValueError(f"Invalid column value for Node: {col}") def _derived_output_cols(input_cols, column_mapping): outputs = [] for input_col in set(input_cols): for output_col_name, input_col_list in column_mapping.items(): if input_col in input_col_list: outputs.append(output_col_name) return outputs