# Source code for merlin.dag.node

#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections.abc
from typing import List, Union

from merlin.dag.base_operator import BaseOperator
from merlin.dag.ops import ConcatColumns, SelectionOp, SubsetColumns, SubtractionOp
from merlin.dag.selector import ColumnSelector
from merlin.schema import Schema


class Node:
    """A Node is a group of columns that you want to apply the same transformations to.

    Nodes can be transformed by shifting operators on to them, which returns a new
    Node with the transformations applied. This lets you define a graph of operations
    that makes up your workflow.

    Parameters
    ----------
    selector: ColumnSelector
        Defines which columns to select from the input Dataset using column names and tags.
    """

    def __init__(self, selector=None):
        # Graph links. `dependencies` may contain nested lists of nodes, because
        # `add_dependency` appends whatever `construct_from` returns as one entry.
        self.parents = []
        self.children = []
        self.dependencies = []

        self.op = None
        self.input_schema = None
        self.output_schema = None

        if isinstance(selector, list):
            selector = ColumnSelector(selector)

        # Compare against None explicitly so that falsy values of the wrong type
        # (e.g. "" or 0) are rejected instead of being wrapped in a SelectionOp.
        if selector is not None and not isinstance(selector, ColumnSelector):
            raise TypeError("The selector argument must be a list or a ColumnSelector")

        if selector is not None:
            self.op = SelectionOp(selector)

        self.selector = selector

    @property
    def selector(self):
        """The selector attached to this node (may be None)."""
        return self._selector

    @selector.setter
    def selector(self, sel):
        # Plain lists are accepted for convenience and wrapped in a ColumnSelector.
        if isinstance(sel, list):
            sel = ColumnSelector(sel)
        self._selector = sel

    # These methods must maintain grouping

    def add_dependency(
        self,
        dep: Union[
            str,
            List[str],
            ColumnSelector,
            "Node",
            List[Union[str, List[str], "Node", ColumnSelector]],
        ],
    ):
        """
        Adding a dependency node to this node

        Parameters
        ----------
        dep : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Dependency to be added
        """
        dep_node = Node.construct_from(dep)

        if not isinstance(dep_node, list):
            dep_nodes = [dep_node]
        else:
            dep_nodes = dep_node

        for node in dep_nodes:
            node.children.append(self)

        # Append (not extend): a list result is stored as a single nested entry.
        self.dependencies.append(dep_node)

    def add_parent(
        self,
        parent: Union[
            str,
            List[str],
            ColumnSelector,
            "Node",
            List[Union[str, List[str], "Node", ColumnSelector]],
        ],
    ):
        """
        Adding a parent node to this node

        Parameters
        ----------
        parent : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Parent to be added
        """
        parent_nodes = Node.construct_from(parent)

        if not isinstance(parent_nodes, list):
            parent_nodes = [parent_nodes]

        for parent_node in parent_nodes:
            parent_node.children.append(self)

        self.parents.extend(parent_nodes)

    def add_child(
        self,
        child: Union[
            str,
            List[str],
            ColumnSelector,
            "Node",
            List[Union[str, List[str], "Node", ColumnSelector]],
        ],
    ):
        """
        Adding a child node to this node

        Parameters
        ----------
        child : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Child to be added
        """
        child_nodes = Node.construct_from(child)

        if not isinstance(child_nodes, list):
            child_nodes = [child_nodes]

        for child_node in child_nodes:
            child_node.parents.append(self)

        self.children.extend(child_nodes)

    def remove_child(
        self,
        child: Union[
            str,
            List[str],
            ColumnSelector,
            "Node",
            List[Union[str, List[str], "Node", ColumnSelector]],
        ],
    ):
        """
        Removing a child node from this node

        Parameters
        ----------
        child : Union[str, ColumnSelector, Node, List[Union[str, Node, ColumnSelector]]]
            Child to be removed
        """
        child_nodes = Node.construct_from(child)

        if not isinstance(child_nodes, list):
            child_nodes = [child_nodes]

        for child_node in child_nodes:
            # Both directions of the link are removed, each only if present.
            if self in child_node.parents:
                child_node.parents.remove(self)
            if child_node in self.children:
                self.children.remove(child_node)

    def compute_schemas(self, root_schema: Schema, preserve_dtypes: bool = False):
        """
        Defines the input and output schema

        Parameters
        ----------
        root_schema : Schema
            Schema of the input dataset
        preserve_dtypes : bool, optional
            `True` if we don't want to override dtypes in the current schema, by default False
        """
        parents_schema = _combine_schemas(self.parents)
        deps_schema = _combine_schemas(self.dependencies)

        parents_selector = _combine_selectors(self.parents)
        dependencies_selector = _combine_selectors(self.dependencies)

        # If parent is an addition or selection node, we may need to
        # propagate grouping unless this node already has a selector
        if len(self.parents) == 1 and isinstance(self.parents[0].op, (ConcatColumns, SelectionOp)):
            parents_selector = self.parents[0].selector
            if not self.selector and self.parents[0].selector and (self.parents[0].selector.names):
                self.selector = parents_selector

        self.input_schema = self.op.compute_input_schema(
            root_schema, parents_schema, deps_schema, self.selector
        )

        self.selector = self.op.compute_selector(
            self.input_schema, self.selector, parents_selector, dependencies_selector
        )

        # Only hand the previous output schema to the operator when asked to.
        prev_output_schema = self.output_schema if preserve_dtypes else None
        self.output_schema = self.op.compute_output_schema(
            self.input_schema, self.selector, prev_output_schema
        )

    def validate_schemas(self, root_schema: Schema, strict_dtypes: bool = False):
        """
        Check if this Node's input schema matches the output schemas of parents and dependencies

        Parameters
        ----------
        root_schema : Schema
            Schema of the input dataset
        strict_dtypes : bool, optional
            If an error should be raised when column dtypes don't match, by default False

        Raises
        ------
        ValueError
            If parents and dependencies don't provide an expected column based on
            the input schema
        ValueError
            If the dtype of a column from parents and dependencies doesn't match the
            expected dtype based on the input schema
        """
        parents_schema = _combine_schemas(self.parents)
        deps_schema = _combine_schemas(self.dependencies)
        ancestors_schema = root_schema + parents_schema + deps_schema

        for col_name, col_schema in self.input_schema.column_schemas.items():
            source_col_schema = ancestors_schema.get(col_name)

            if not source_col_schema:
                raise ValueError(
                    f"Missing column '{col_name}' at the input to '{self.op.__class__.__name__}'."
                )

            # Dtypes are compared when explicitly requested, or when the operator
            # does not declare dynamic dtypes.
            if strict_dtypes or not self.op.dynamic_dtypes:
                if source_col_schema.dtype != col_schema.dtype:
                    raise ValueError(
                        f"Mismatched dtypes for column '{col_name}' provided to "
                        f"'{self.op.__class__.__name__}': "
                        f"ancestor nodes provided dtype '{source_col_schema.dtype}', "
                        f"expected dtype '{col_schema.dtype}'."
                    )

        self.op.validate_schemas(
            parents_schema, deps_schema, self.input_schema, self.output_schema, strict_dtypes
        )

    def __rshift__(self, operator):
        """Transforms this Node by applying a BaseOperator

        Parameters
        -----------
        operator: BaseOperator or BaseOperator subclass
            The operator to apply; a class is instantiated with no arguments.

        Returns
        -------
        Node

        Raises
        ------
        ValueError
            If `operator` is neither a BaseOperator instance nor a BaseOperator subclass
        """
        if isinstance(operator, type) and issubclass(operator, BaseOperator):
            # handle case where an operator class is passed
            operator = operator()

        if not isinstance(operator, BaseOperator):
            raise ValueError(f"Expected operator or callable, got {operator.__class__}")

        child = type(self)()
        child.op = operator
        child.add_parent(self)

        dependencies = operator.dependencies

        if dependencies:
            # A bare string is itself a Sequence; without the explicit `str` check a
            # single column name would be split into one dependency per character.
            if isinstance(dependencies, str) or not isinstance(
                dependencies, collections.abc.Sequence
            ):
                dependencies = [dependencies]

            for dependency in dependencies:
                child.add_dependency(dependency)

        return child

    def __add__(self, other):
        """Adds columns from this Node with another to return a new Node

        Parameters
        -----------
        other: Node or str or list of str

        Returns
        -------
        Node
        """
        if isinstance(self.op, ConcatColumns):
            # This node is already a `+` node, so it is reused (and mutated).
            child = self
        else:
            # Create a child node
            child = type(self)()
            child.op = ConcatColumns(label="+")
            child.add_parent(self)

        # The right operand becomes a dependency
        other_nodes = Node.construct_from(other)
        other_nodes = [other_nodes]

        for other_node in other_nodes:
            # If the other node is a `+` node, we want to collapse it into this `+` node to
            # avoid creating a cascade of repeated `+`s that we'd need to optimize out by
            # re-combining them later in order to clean up the graph
            if not isinstance(other_node, list) and isinstance(other_node.op, ConcatColumns):
                child.dependencies += other_node.grouped_parents_with_dependencies
            else:
                child.add_dependency(other_node)

        return child

    # handle the "column_name" + Node case
    __radd__ = __add__

    def __sub__(self, other):
        """Removes columns from this Node with another to return a new Node

        Parameters
        -----------
        other: Node or str or list of str
            Columns to remove

        Returns
        -------
        Node
        """
        other_nodes = Node.construct_from(other)

        if not isinstance(other_nodes, list):
            other_nodes = [other_nodes]

        child = type(self)()
        child.add_parent(self)
        child.op = SubtractionOp()

        for other_node in other_nodes:
            # A selection node with no parents or dependencies is folded into the
            # selectors; anything else becomes a dependency of the new node.
            if isinstance(other_node.op, SelectionOp) and not other_node.parents_with_dependencies:
                child.selector += other_node.selector
                child.op.selector += child.selector
            else:
                child.add_dependency(other_node)

        return child

    def __rsub__(self, other):
        """Handles `other - node`: `other` becomes the parent, this Node is subtracted

        Parameters
        -----------
        other: Node or str or list of str
            Left operand of the subtraction

        Returns
        -------
        Node
        """
        left_operand = Node.construct_from(other)
        right_operand = self

        if not isinstance(left_operand, list):
            left_operand = [left_operand]

        child = type(self)()
        child.add_parent(left_operand)
        child.op = SubtractionOp()

        if (
            isinstance(right_operand.op, SelectionOp)
            and not right_operand.parents_with_dependencies
        ):
            child.selector += right_operand.selector
            child.op.selector += child.selector
        else:
            child.add_dependency(right_operand)

        return child

    def __getitem__(self, columns):
        """Selects certain columns from this Node, and returns a new Node with only
        those columns

        Parameters
        -----------
        columns: str or list of str
            Columns to select

        Returns
        -------
        Node
        """
        col_selector = ColumnSelector(columns)
        child = type(self)(col_selector)
        columns = [columns] if not isinstance(columns, list) else columns
        child.op = SubsetColumns(label=str(list(columns)))
        child.add_parent(self)
        return child

    def __repr__(self):
        # Nodes without children are marked as outputs.
        output = " output" if not self.children else ""
        return f"<Node {self.label}{output}>"

    def remove_inputs(self, input_cols: List[str]) -> List[str]:
        """
        Remove input columns and all output columns that depend on them.

        Parameters
        ----------
        input_cols : List[str]
            The input columns to remove

        Returns
        -------
        List[str]
            The output columns that were removed
        """
        removed_outputs = _derived_output_cols(input_cols, self.column_mapping)

        self.input_schema = self.input_schema.without(input_cols)
        self.output_schema = self.output_schema.without(removed_outputs)

        if self.selector:
            self.selector = self.selector.filter_columns(ColumnSelector(input_cols))

        return removed_outputs

    def exportable(self, backend: str = None):
        """
        Report whether this node's operator can be exported to `backend`

        Parameters
        ----------
        backend : str, optional
            Name looked up in the operator's `exportable_backends`, by default None

        Returns
        -------
        bool
            True if the operator has an `export` attribute and `backend` is listed
            in its `exportable_backends` (treated as empty when missing)
        """
        backends = getattr(self.op, "exportable_backends", [])

        return hasattr(self.op, "export") and backend in backends

    @property
    def parents_with_dependencies(self):
        """Parents and dependencies with one level of list nesting flattened."""
        nodes = []
        for node in self.parents + self.dependencies:
            if isinstance(node, list):
                nodes.extend(node)
            else:
                nodes.append(node)

        return nodes

    @property
    def grouped_parents_with_dependencies(self):
        """Parents and dependencies concatenated, keeping nested lists as they are."""
        return self.parents + self.dependencies

    @property
    def input_columns(self):
        """Selector for this node's input columns.

        Raises
        ------
        RuntimeError
            If the input schema has not been computed yet
        """
        if self.input_schema is None:
            raise RuntimeError(
                "The input columns aren't computed until the workflow "
                "is fit to a dataset or input schema."
            )

        if (
            self.selector
            and not self.selector.tags
            and all(not selector.tags for selector in self.selector.subgroups)
        ):
            # To maintain column groupings
            return self.selector
        else:
            return ColumnSelector(self.input_schema.column_names)

    @property
    def output_columns(self):
        """Selector built from the column names of the output schema.

        Raises
        ------
        RuntimeError
            If the output schema has not been computed yet
        """
        if self.output_schema is None:
            raise RuntimeError(
                "The output columns aren't computed until the workflow "
                "is fit to a dataset or input schema."
            )

        return ColumnSelector(self.output_schema.column_names)

    @property
    def column_mapping(self):
        """The operator's column mapping for this node's selector.

        Falls back to a selector over the input schema's column names when the
        node has no (truthy) selector.
        """
        selector = self.selector or ColumnSelector(self.input_schema.column_names)
        return self.op.column_mapping(selector)

    @property
    def dependency_columns(self):
        """Selector over the column names of the combined dependency schemas."""
        return ColumnSelector(_combine_schemas(self.dependencies).column_names)

    @property
    def label(self):
        """Short human-readable name for this node, used by `__repr__` and graphs."""
        if self.op and hasattr(self.op, "label"):
            return self.op.label
        elif self.op:
            return str(type(self.op))
        elif not self.parents:
            return f"input cols=[{self._cols_repr}]"
        else:
            return "??"

    @property
    def _cols_repr(self):
        """Comma-separated preview of up to three column names, with '...' if truncated."""
        if self.input_schema:
            columns = self.input_schema.column_names
        elif self.selector:
            columns = self.selector.names
        else:
            columns = []

        cols_repr = ", ".join(map(str, columns[:3]))
        if len(columns) > 3:
            cols_repr += "..."

        return cols_repr

    @property
    def graph(self):
        """GraphViz representation of the graph ending at this node."""
        return _to_graphviz(self)

    # Anything that `construct_from` knows how to turn into a Node (or list of Nodes).
    Nodable = Union[
        "Node", str, List[str], ColumnSelector, List[Union["Node", str, List[str], ColumnSelector]]
    ]

    @classmethod
    def construct_from(
        cls,
        nodable: Nodable,
    ):
        """
        Convert Node-like objects to a Node or list of Nodes.

        Parameters
        ----------
        nodable : Nodable
            Node-like objects to convert to a Node or list of Nodes.

        Returns
        -------
        Union["Node", List["Node"]]
            New Node(s) corresponding to the Node-like input objects

        Raises
        ------
        TypeError
            If supplied input cannot be converted to a Node or list of Nodes
        """
        if isinstance(nodable, str):
            return Node(ColumnSelector([nodable]))

        if isinstance(nodable, ColumnSelector):
            return Node(nodable)
        elif isinstance(nodable, Node):
            return nodable
        elif isinstance(nodable, list):
            if all(isinstance(elem, str) for elem in nodable):
                return Node(nodable)
            else:
                nodes = [Node.construct_from(node) for node in nodable]
                # Elements with a truthy selector are merged into a single selection
                # node placed after all the remaining elements.
                non_selection_nodes = [
                    node for node in nodes if not (hasattr(node, "selector") and node.selector)
                ]
                selection_nodes = [
                    node.selector for node in nodes if (hasattr(node, "selector") and node.selector)
                ]
                selection_nodes = (
                    [Node(_combine_selectors(selection_nodes))] if selection_nodes else []
                )
                return non_selection_nodes + selection_nodes

        else:
            raise TypeError(
                "Unsupported type: Cannot convert object " f"of type {type(nodable)} to Node."
            )
def iter_nodes(nodes):
    """Yield the given nodes and their ancestors (parents and dependencies).

    `nodes` must be a list; nested lists are expanded. A node is only skipped
    while it is still waiting in the queue, so the same node can be yielded
    more than once.
    """
    queue = nodes[:]
    while queue:
        current = queue.pop()
        if isinstance(current, list):
            queue.extend(current)
        else:
            yield current
            for node in current.parents_with_dependencies:
                if node not in queue:
                    queue.append(node)


# output node (bottom) -> selection leaf nodes (top)
def preorder_iter_nodes(nodes):
    """Yield each reachable node once, starting from the given node(s) and moving
    up through parents and dependencies."""
    queue = []
    if not isinstance(nodes, list):
        nodes = [nodes]

    def traverse(current_nodes):
        for node in current_nodes:
            # Avoid creating duplicate nodes in the queue
            if node in queue:
                queue.remove(node)

            queue.append(node)

        for node in current_nodes:
            traverse(node.parents_with_dependencies)

    traverse(nodes)
    yield from queue


# selection leaf nodes (top) -> output node (bottom)
def postorder_iter_nodes(nodes):
    """Yield each reachable node once, with a node's parents and dependencies
    yielded before the node itself."""
    queue = []
    if not isinstance(nodes, list):
        nodes = [nodes]

    def traverse(current_nodes):
        for node in current_nodes:
            traverse(node.parents_with_dependencies)
            if node not in queue:
                queue.append(node)

    traverse(nodes)
    yield from queue


def _filter_by_type(elements, type_):
    """Collect, in order, every instance of `type_` found in `elements`,
    descending into nested lists."""
    results = []

    for elem in elements:
        if isinstance(elem, type_):
            results.append(elem)
        elif isinstance(elem, list):
            results += _filter_by_type(elem, type_)

    return results


def _combine_schemas(elements):
    """Add together the schemas of nodes, selectors and nested lists of either.

    Nodes contribute their output schema and selectors a schema built from their
    names; elements of any other type are ignored.
    """
    combined = Schema()
    for elem in elements:
        if isinstance(elem, Node):
            combined += elem.output_schema
        elif isinstance(elem, ColumnSelector):
            combined += Schema(elem.names)
        elif isinstance(elem, list):
            combined += _combine_schemas(elem)
    return combined


def _combine_selectors(elements):
    """Add together the selectors of nodes, selectors, strings and nested lists.

    Nested lists are added as a subgroup of the combined selector.
    """
    combined = ColumnSelector()
    for elem in elements:
        if isinstance(elem, Node):
            # Preference order: the node's own selector, then its output schema,
            # then its input schema; otherwise an empty selector.
            if elem.selector:
                selector = elem.op.output_column_names(elem.selector)
            elif elem.output_schema:
                selector = ColumnSelector(elem.output_schema.column_names)
            elif elem.input_schema:
                selector = ColumnSelector(elem.input_schema.column_names)
                selector = elem.op.output_column_names(selector)
            else:
                selector = ColumnSelector()

            combined += selector
        elif isinstance(elem, ColumnSelector):
            combined += elem
        elif isinstance(elem, str):
            combined += ColumnSelector(elem)
        elif isinstance(elem, list):
            combined += ColumnSelector(subgroups=_combine_selectors(elem))

    return combined


def _to_selector(value):
    """Wrap `value` in a ColumnSelector unless it already is a selector or a Node."""
    if not isinstance(value, (ColumnSelector, Node)):
        return ColumnSelector(value)
    else:
        return value


def _strs_to_selectors(elements):
    """Apply `_to_selector` to every element."""
    return [_to_selector(elem) for elem in elements]


def _to_graphviz(output_node):
    """Converts a Node to a GraphViz DiGraph object useful for display in notebooks"""
    # Imported lazily so graphviz is only required when a graph is requested.
    from graphviz import Digraph

    graph = Digraph()

    # get all the nodes from parents of this columngroup
    # and add edges between each of them
    allnodes = list(set(iter_nodes([output_node])))
    node_ids = {v: str(k) for k, v in enumerate(allnodes)}
    for node, nodeid in node_ids.items():
        graph.node(nodeid, node.label)
        for parent in node.parents_with_dependencies:
            graph.edge(node_ids[parent], nodeid)

        if node.selector and node.selector.names:
            selector_id = f"{nodeid}_selector"
            graph.node(selector_id, str(node.selector.names))
            graph.edge(selector_id, nodeid)

    # add a single node representing the final state
    final_node_id = str(len(allnodes))
    final_string = "output cols"
    if output_node._cols_repr:
        final_string += f"=[{output_node._cols_repr}]"
    graph.node(final_node_id, final_string)
    graph.edge(node_ids[output_node], final_node_id)
    return graph


def _convert_col(col):
    """Return strings and tuples unchanged and lists as tuples.

    Raises
    ------
    ValueError
        If `col` is not a str, tuple or list
    """
    if isinstance(col, (str, tuple)):
        return col
    elif isinstance(col, list):
        return tuple(col)
    else:
        raise ValueError(f"Invalid column value for Node: {col}")


def _derived_output_cols(input_cols, column_mapping):
    """Return the output columns that depend on any of `input_cols`.

    Parameters
    ----------
    input_cols : iterable of hashable
        Input column names being removed
    column_mapping : dict
        Maps each output column name to the collection of input columns it is
        derived from

    Returns
    -------
    list
        Matching output column names, each listed once, in `column_mapping` order
    """
    # De-duplicate the inputs while keeping their order. Iterating a plain set made
    # the result order depend on hash ordering, and an output derived from several
    # removed inputs used to be reported once per input.
    removed = list(dict.fromkeys(input_cols))

    return [
        output_col_name
        for output_col_name, input_col_list in column_mapping.items()
        if any(input_col in input_col_list for input_col in removed)
    ]