#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from collections import deque
from typing import Dict, Optional
from merlin.dag.node import (
Node,
_combine_schemas,
iter_nodes,
postorder_iter_nodes,
preorder_iter_nodes,
)
from merlin.schema import Schema
# Logger registered under the package-wide "merlin" name (not this module's __name__).
# NOTE(review): LOG is not referenced in this module as shown — confirm before removing.
LOG = logging.getLogger("merlin")
class Graph:
    """
    Represents a DAG composed of Nodes, each of which contains an operator that
    transforms dataframes or dataframe-like data
    """

    def __init__(self, output_node: Node, subgraphs: Optional[Dict[str, Node]] = None):
        """
        Parameters
        ----------
        output_node : Node
            The final node of the DAG; the rest of the graph is discovered by
            traversing upstream from it
        subgraphs : Dict[str, Node], optional
            Named nodes of this DAG that can later be extracted with ``subgraph(name)``

        Raises
        ------
        ValueError
            If a registered subgraph node is not part of the graph ending at ``output_node``
        """
        self.output_node = output_node
        self.subgraphs = subgraphs or {}

        if self.subgraphs:
            # Build a fresh list instead of appending to the list returned by
            # ``parents_with_dependencies`` (which could be shared with the node).
            # The postorder traversal also admits subgraphs that end deeper in the
            # DAG than the output node's direct parents/dependencies.
            graph_nodes = [output_node, *output_node.parents_with_dependencies]
            graph_nodes.extend(postorder_iter_nodes(output_node))
            for name, sg in self.subgraphs.items():
                if sg not in graph_nodes:
                    raise ValueError(
                        f"The output node of subgraph {name} does not exist in the provided graph."
                    )

    def subgraph(self, name: str) -> "Graph":
        """
        Return a new Graph whose output node is the subgraph registered as ``name``

        Raises
        ------
        ValueError
            If no subgraph with that name was registered
        """
        if name not in self.subgraphs:
            raise ValueError(f"No subgraph named {name}. Options are: {self.subgraphs.keys()}")
        return Graph(self.subgraphs[name])

    @staticmethod
    def _schema_dtypes(schema):
        """Map column name -> dtype for ``schema``; an empty/falsy schema yields ``{}``."""
        if not schema:
            return {}
        return {name: col_schema.dtype for name, col_schema in schema.column_schemas.items()}

    @property
    def input_dtypes(self):
        """Column name -> dtype for the graph's input schema (``{}`` if there is none)."""
        # ``input_schema`` walks the whole graph on every access, so read it once.
        return self._schema_dtypes(self.input_schema)

    @property
    def output_dtypes(self):
        """Column name -> dtype for the graph's output schema (``{}`` if there is none)."""
        return self._schema_dtypes(self.output_schema)

    @property
    def column_mapping(self):
        """
        Map each output column of the graph to the columns it is derived from upstream.

        Starts from the output node's own mapping. Then, for every other node in
        preorder, each input column is replaced by that node's inputs for it.
        Columns a node does not map are passed through unchanged.
        """
        nodes = list(preorder_iter_nodes(self.output_node))
        # Copy so that rewriting entries below never mutates the node's own mapping.
        column_mapping = dict(self.output_node.column_mapping)
        # The first preorder node is the output node, whose mapping is the seed above.
        for node in nodes[1:]:
            node_map = node.column_mapping
            for output_col, input_cols in column_mapping.items():
                early_inputs = []
                for input_col in input_cols:
                    early_inputs += node_map.get(input_col, [input_col])
                # Re-assigning an existing key while iterating is safe (no size change).
                column_mapping[output_col] = early_inputs
        return column_mapping

    def construct_schema(self, root_schema: Schema, preserve_dtypes=False) -> "Graph":
        """
        Given the schema of a dataset to transform, determine the output schema of the graph

        Parameters
        ----------
        root_schema : Schema
            The schema of a dataset to be transformed with this DAG
        preserve_dtypes : bool, optional
            Whether to keep any dtypes that may already be present in the schemas, by default False

        Returns
        -------
        Graph
            This DAG after the schemas have been filled in
        """
        nodes = list(postorder_iter_nodes(self.output_node))

        self._compute_node_schemas(root_schema, nodes, preserve_dtypes)
        # ``preserve_dtypes`` doubles as the ``strict_dtypes`` switch for validation.
        self._validate_node_schemas(root_schema, nodes, preserve_dtypes)

        return self

    def _compute_node_schemas(self, root_schema, nodes, preserve_dtypes=False):
        """Ask every node in ``nodes`` (in the given order) to compute its schemas."""
        for node in nodes:
            node.compute_schemas(root_schema, preserve_dtypes=preserve_dtypes)

    def _validate_node_schemas(self, root_schema, nodes, strict_dtypes=False):
        """Ask every node in ``nodes`` to validate its schemas against ``root_schema``."""
        for node in nodes:
            node.validate_schemas(root_schema, strict_dtypes=strict_dtypes)

    @property
    def input_schema(self):
        """Combined schema of all leaf (parentless) nodes."""
        # leaf_node input and output schemas are the same (aka selection)
        return _combine_schemas(self.leaf_nodes)

    @property
    def leaf_nodes(self):
        """Nodes reachable from the output node (postorder) that have no parents."""
        return [node for node in postorder_iter_nodes(self.output_node) if not node.parents]

    @property
    def output_schema(self):
        """The output schema of the graph's output node."""
        return self.output_node.output_schema

    def _input_columns(self):
        """
        Return the column names the graph consumes from outside itself.

        For each node, these are the input columns that none of its parents or
        dependencies produce. The result is deduplicated in first-seen order.
        """
        input_cols = []
        for node in iter_nodes([self.output_node]):
            upstream_output_cols = set()
            for upstream_node in node.parents_with_dependencies:
                upstream_output_cols.update(upstream_node.output_columns.names)
            # Filter instead of using a set difference so the order is deterministic
            # (set iteration order for strings depends on the hash seed).
            input_cols += [
                col for col in node.input_columns.names if col not in upstream_output_cols
            ]
        return _get_unique(input_cols)

    @classmethod
    def get_nodes_by_op_type(cls, nodes, op_type):
        """Return the set of nodes from ``iter_nodes(nodes)`` whose ``op`` is an ``op_type``."""
        return {node for node in iter_nodes(nodes) if isinstance(node.op, op_type)}
def _get_schemaless_nodes(nodes):
    """Return the set of nodes yielded by ``iter_nodes(nodes)`` whose input schema is unset (None)."""
    return {candidate for candidate in iter_nodes(nodes) if candidate.input_schema is None}
def _get_unique(cols):
# Need to preserve order in unique-column list
return list({x: x for x in cols}.keys())