#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from collections import deque
from typing import Dict
from merlin.dag.base_operator import BaseOperator
from merlin.dag.node import (
Node,
_combine_schemas,
iter_nodes,
postorder_iter_nodes,
preorder_iter_nodes,
)
from merlin.dag.ops.stat_operator import StatOperator
from merlin.schema import Schema
LOG = logging.getLogger("merlin")
class Graph:
    """
    Represents a DAG composed of Nodes, each of which contains an operator that
    transforms dataframes or dataframe-like data

    Parameters
    ----------
    output_node : Node or BaseOperator
        The final node of the DAG. A bare operator is wrapped in a Node first.
    """

    def __init__(self, output_node: "Node"):
        # Accept a bare operator for convenience and promote it to a Node.
        if isinstance(output_node, BaseOperator):
            output_node = Node.construct_from(output_node)

        self.output_node = output_node

        # NOTE(review): the list obtained here is not used again in this method.
        # Whether the append has any lasting effect depends on whether
        # `parents_with_dependencies` hands back the node's own list or a fresh
        # one -- confirm against Node before removing.
        parents_with_deps = self.output_node.parents_with_dependencies
        parents_with_deps.append(output_node)

        # Named subgraphs found below the output node, keyed by operator name.
        self.subgraphs: Dict[str, "Graph"] = {}
        _find_subgraphs(output_node, self.subgraphs)

    def subgraph(self, name: str) -> "Graph":
        """
        Look up a named subgraph of this graph

        Parameters
        ----------
        name : str
            Name the subgraph was registered under

        Returns
        -------
        Graph
            The subgraph stored under ``name``

        Raises
        ------
        ValueError
            If no subgraph with that name has been registered
        """
        if name not in self.subgraphs:
            raise ValueError(f"No subgraph named {name}. Options are: {self.subgraphs.keys()}")
        return self.subgraphs[name]

    @staticmethod
    def _schema_dtypes(schema):
        """Map each column name in ``schema`` to its dtype ({} for a falsy schema)."""
        if not schema:
            return {}
        return {name: col_schema.dtype for name, col_schema in schema.column_schemas.items()}

    @property
    def input_dtypes(self):
        """Dict of column name -> dtype for the graph's input schema ({} if falsy)."""
        # Read the property once: `input_schema` is recomputed on every access.
        return self._schema_dtypes(self.input_schema)

    @property
    def output_dtypes(self):
        """Dict of column name -> dtype for the graph's output schema ({} if falsy)."""
        return self._schema_dtypes(self.output_schema)

    @property
    def column_mapping(self):
        """
        Mapping from each output column of the graph to the list of columns it is
        derived from, obtained by substituting every other node's own mapping
        into the output node's mapping.
        """
        nodes = preorder_iter_nodes(self.output_node)
        # NOTE(review): the dict returned by the output node is updated in place
        # below; that is only safe if the property builds a fresh dict per
        # access -- confirm against Node.column_mapping.
        column_mapping = self.output_node.column_mapping
        # The first node yielded is skipped (presumably the output node itself,
        # whose mapping is the starting point above).
        for node in list(nodes)[1:]:
            node_map = node.column_mapping
            # Only existing keys are reassigned, so the dict's size never changes
            # while it is being iterated.
            for output_col, input_cols in column_mapping.items():
                early_inputs = []
                for input_col in input_cols:
                    # Columns this node does not produce pass through unchanged.
                    early_inputs += node_map.get(input_col, [input_col])
                column_mapping[output_col] = early_inputs

        return column_mapping

    def construct_schema(self, root_schema: "Schema", preserve_dtypes=False) -> "Graph":
        """
        Given the schema of a dataset to transform, determine the output schema of the graph

        Parameters
        ----------
        root_schema : Schema
            The schema of a dataset to be transformed with this DAG
        preserve_dtypes : bool, optional
            Whether to keep any dtypes that may already be present in the schemas, by default False

        Returns
        -------
        Graph
            This DAG after the schemas have been filled in
        """
        nodes = list(postorder_iter_nodes(self.output_node))

        self._compute_node_schemas(root_schema, nodes, preserve_dtypes)
        # `preserve_dtypes` is forwarded positionally and therefore lands in the
        # `strict_dtypes` parameter of the validation pass.
        self._validate_node_schemas(root_schema, nodes, preserve_dtypes)

        return self

    def _compute_node_schemas(self, root_schema, nodes, preserve_dtypes=False):
        """Call ``compute_schemas`` on every node, in the order given."""
        for node in nodes:
            node.compute_schemas(root_schema, preserve_dtypes=preserve_dtypes)

    def _validate_node_schemas(self, root_schema, nodes, strict_dtypes=False):
        """Call ``validate_schemas`` on every node, in the order given."""
        for node in nodes:
            node.validate_schemas(root_schema, strict_dtypes=strict_dtypes)

    @property
    def input_schema(self):
        """Combined schema of all leaf nodes of the graph."""
        # leaf_node input and output schemas are the same (aka selection)
        return _combine_schemas(self.leaf_nodes)

    @property
    def leaf_nodes(self):
        """List of the nodes in this graph that have no parents."""
        return [node for node in postorder_iter_nodes(self.output_node) if not node.parents]

    @property
    def output_schema(self):
        """Output schema of the graph's output node."""
        return self.output_node.output_schema

    def _input_columns(self):
        """
        Names of the columns that nodes consume but that none of their upstream
        nodes (parents and dependencies) produce, de-duplicated in first-seen
        order.
        """
        input_cols = []
        for node in iter_nodes([self.output_node]):
            upstream_output_cols = set()
            for upstream_node in node.parents_with_dependencies:
                upstream_output_cols.update(upstream_node.output_columns.names)

            # Filter instead of taking a set difference: a set difference yields
            # the names in arbitrary order, which would leak into the result.
            input_cols += [
                name for name in node.input_columns.names if name not in upstream_output_cols
            ]

        return _get_unique(input_cols)

    @classmethod
    def get_nodes_by_op_type(cls, nodes, op_type):
        """
        Collect the nodes whose operator is an instance of ``op_type``

        Parameters
        ----------
        nodes : list
            Nodes handed to ``iter_nodes`` as the starting points of the search
        op_type : type or tuple of types
            Operator class(es) to match with ``isinstance``

        Returns
        -------
        set
            The matching nodes
        """
        return {node for node in iter_nodes(nodes) if isinstance(node.op, op_type)}

    def clear_stats(self):
        """Removes calculated statistics from each node in the graph

        See Also
        --------
        StatOperator.clear
        """
        for stat in Graph.get_nodes_by_op_type([self.output_node], StatOperator):
            stat.op.clear()
def _get_schemaless_nodes(nodes):
    """Return the set of nodes yielded by ``iter_nodes(nodes)`` whose input schema is None."""
    return {candidate for candidate in iter_nodes(nodes) if candidate.input_schema is None}
def _get_unique(cols):
    """Return the distinct items of ``cols`` as a list, in first-occurrence order.

    Parameters
    ----------
    cols : iterable of hashable
        Column names (or any hashable items), possibly with repeats

    Returns
    -------
    list
        Each distinct item once, ordered by where it first appeared
    """
    # Need to preserve order in unique-column list: dict keys keep insertion
    # order, unlike a set.
    return list(dict.fromkeys(cols))
def _find_subgraphs(output_node, subgraphs):
    """Register every named subgraph found below ``output_node`` in ``subgraphs``.

    Parameters
    ----------
    output_node : Node
        Node whose graph is walked with ``postorder_iter_nodes``
    subgraphs : dict
        Mapping of subgraph name -> graph; filled in place, including entries
        for subgraphs nested inside other subgraphs

    Raises
    ------
    ValueError
        If two subgraph operators share the same name
    """
    for visited in postorder_iter_nodes(output_node):
        op = visited.op

        # Only operators that wrap a subgraph *and* carry a name are registered.
        if not (op.is_subgraph and op.name):
            continue

        if op.name in subgraphs:
            raise ValueError(
                f"Found two subgraphs called {op.name}. "
                "Each subgraph must have a unique name."
            )

        subgraphs[op.name] = op.graph
        # Descend into the subgraph so nested subgraphs land in the same dict.
        _find_subgraphs(op.graph.output_node, subgraphs)