#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Union
import nvtabular
from nvtabular.tags import Tags
[docs]class ColumnSelector:
"""A ColumnSelector describes a group of columns to be transformed by Operators in a
Workflow. Operators can be applied to the selected columns by shifting (>>) operators
on to the ColumnSelector, which returns a new WorkflowNode with the transformations
applied. This lets you define a graph of operations that makes up your Workflow.
Parameters
----------
names: list of (str or tuple of str)
The columns to select from the input Dataset. The elements of this list are strings
indicating the column names in most cases, but can also be tuples of strings
for feature crosses.
subgroups, optional: list of ColumnSelector objects
This provides an alternate syntax for grouping column names together (instead
of nesting tuples inside the list of names)
"""
def __init__(
self,
names: List[str] = None,
subgroups: List["ColumnSelector"] = None,
tags: List[Union[Tags, str]] = None,
):
self._names = names if names is not None else []
self._tags = tags if tags is not None else []
self.subgroups = subgroups if subgroups is not None else []
if isinstance(self._names, nvtabular.WorkflowNode):
raise TypeError("ColumnSelectors can not contain WorkflowNodes")
if isinstance(self._names, str):
self._names = [self._names]
if isinstance(self.subgroups, ColumnSelector):
self.subgroups = [self.subgroups]
plain_names = []
for name in self._names:
if isinstance(name, str):
plain_names.append(name)
elif isinstance(name, nvtabular.WorkflowNode):
raise ValueError("ColumnSelectors can not contain WorkflowNodes")
elif isinstance(name, ColumnSelector):
self.subgroups.append(name)
else:
self.subgroups.append(ColumnSelector(name))
self._names = plain_names
self._nested_check()
@property
def tags(self):
return list(dict.fromkeys(self._tags).keys())
@property
def names(self):
names = []
names += self._names
for subgroup in self.subgroups:
names += subgroup.names
# Only return unique column names
return list(dict.fromkeys(names).keys())
@property
def grouped_names(self):
names = []
names += self._names
for subgroup in self.subgroups:
names.append(tuple(subgroup.names))
# Only return unique grouped column names
return list(dict.fromkeys(names).keys())
def _nested_check(self, nests=0):
if nests > 1:
raise AttributeError("Too many nested subgroups")
for col_sel0 in self.subgroups:
col_sel0._nested_check(nests=nests + 1)
def __add__(self, other):
if other is None:
return self
elif isinstance(other, nvtabular.WorkflowNode):
return other + self
elif isinstance(other, ColumnSelector):
return ColumnSelector(
self._names + other._names,
self.subgroups + other.subgroups,
tags=self._tags + other._tags,
)
elif isinstance(other, Tags):
return ColumnSelector(self._names, self.subgroups, tags=self._tags + [other])
else:
if isinstance(other, str):
other = [other]
return ColumnSelector(self._names + other, self.subgroups)
def __radd__(self, other):
return self + other
def __rshift__(self, other):
# Create a selection node to shift onto, then shift onto it
return nvtabular.WorkflowNode(self) >> other
def __eq__(self, other):
if not isinstance(other, ColumnSelector):
return False
return other._names == self._names and other.subgroups == self.subgroups