Source code for merlin.models.utils.registry

#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import abc
import re

# Camel case to snake case utils
from typing import Generic, TypeVar

# Matches any character followed by a capitalised word ("xFoo" -> "x", "Foo");
# used to insert an underscore before each capitalised word.
_first_cap_re = re.compile(r"(.)([A-Z][a-z0-9]+)")
# Matches a lowercase letter/digit immediately followed by an uppercase letter;
# used to split the remaining lower->upper transitions ("aB" -> "a", "B").
_all_cap_re = re.compile(r"([a-z0-9])([A-Z])")


def camelcase_to_snakecase(name):
    """Convert a CamelCase identifier to snake_case.

    Examples: ``"MyFunc"`` -> ``"my_func"``, ``"HTTPServer"`` -> ``"http_server"``.

    Parameters
    ----------
    name: str
        Identifier to convert.

    Returns
    -------
    str
        The lowercased, underscore-separated form of ``name``.
    """
    # Pass 1: put an underscore in front of every capitalised word.
    split_words = _first_cap_re.sub(r"\1_\2", name)
    # Pass 2: split any remaining lowercase/digit -> uppercase boundaries.
    split_words = _all_cap_re.sub(r"\1_\2", split_words)
    return split_words.lower()
def snakecase_to_camelcase(name):
    """Convert a snake_case identifier to CamelCase.

    The first character of every underscore-separated word is upper-cased.
    The rest of each word is left untouched, so ``"keep_mIxed"`` becomes
    ``"KeepMIxed"``. Empty words contribute nothing; they come from an empty
    ``name`` or from leading, trailing or repeated underscores.

    Parameters
    ----------
    name: str
        Identifier to convert.

    Returns
    -------
    str
        The CamelCase form of ``name``.
    """
    # ``word[:1]`` is "" for an empty word, whereas ``word[0]`` would raise
    # IndexError on inputs such as "", "a__b" or "_private".
    return "".join(word[:1].upper() + word[1:] for word in name.split("_"))
def default_name(class_or_fn):
    """Return the default registration name for a class or function.

    Registries that hold classes or functions use this naming function by
    default. It is the object's ``__name__`` converted from CamelCase to
    snake_case.

    Parameters
    ----------
    class_or_fn:
        Class or function to be named.

    Returns
    -------
    str
        Default name for registration.
    """
    raw_name = class_or_fn.__name__
    return camelcase_to_snakecase(raw_name)
def default_object_name(obj):
    """Return the default registration name for an instance, derived from its class.

    Parameters
    ----------
    obj:
        Object whose class name is used.

    Returns
    -------
    str
        ``default_name`` applied to ``type(obj)``.
    """
    obj_class = type(obj)
    return default_name(obj_class)
class Registry:
    """
    Dict-like class for managing function registrations.

    Example usage::

        my_registry = Registry("custom_name")

        @my_registry.register
        def my_func():
            pass

        @my_registry.register()
        def another_func():
            pass

        @my_registry.register("non_default_name")
        def third_func(x, y, z):
            pass

        def foo():
            pass

        my_registry.register()(foo)
        my_registry.register("baz")(lambda x, y: x + y)
        my_registry.register("bar")

        print(list(my_registry))
        # ["my_func", "another_func", "non_default_name", "foo", "baz"]
        # (insertion order)
        print(my_registry["non_default_name"] is third_func)  # True
        print("third_func" in my_registry)  # False
        print("bar" in my_registry)  # False
        my_registry["non-existent_key"]  # raises KeyError

    Optional validation, an on_set callback and a value transform are also
    supported. A tuple key registers the same value under every element of
    the tuple (see ``register_with_multiple_names``).

    Parameters
    ----------
    registry_name: str
        identifier for the given registry. Used in error msgs.
    default_key_fn: callable, optional
        function mapping value -> key for registration when a key is not
        provided. Defaults to ``default_name``.
    validator: callable, optional
        if given, this is run before setting a given (key, value) pair.
        Accepts (key, value) and should raise if there is a problem.
        Overwriting existing keys is not allowed and is checked separately.
        Values are also checked to be callable separately.
    on_set: callable, optional
        callback function accepting (key, value) pair which is run after an
        item is successfully set.
    value_transformer: callable, optional
        if given, `__getitem__` will return
        value_transformer(key, registered_value). Defaults to returning the
        registered value unchanged.
    """

    def __init__(
        self,
        registry_name,
        default_key_fn=default_name,
        validator=None,
        on_set=None,
        value_transformer=(lambda k, v: v),
    ):
        self._registry = {}
        self._name = registry_name
        self._default_key_fn = default_key_fn
        self._validator = validator
        self._on_set = on_set
        self._value_transformer = value_transformer

    @classmethod
    def class_registry(
        cls, registry_name, default_key_fn=default_name, validator=None, on_set=None
    ):
        """Create a registry whose lookups call the registered value with no arguments.

        For registered classes this means every ``__getitem__`` (and therefore
        ``get``, ``values``, ``items`` and ``parse``) returns a new instance.
        """
        return cls(
            registry_name=registry_name,
            default_key_fn=default_key_fn,
            validator=validator,
            on_set=on_set,
            value_transformer=(lambda k, v: v()),
        )

    def default_key(self, value):
        """Default key used when key not provided. Uses function from __init__."""
        return self._default_key_fn(value)

    @property
    def name(self):
        """The ``registry_name`` this registry was created with."""
        return self._name

    def validate(self, key, value):
        """Validation function run before setting. Uses function from __init__."""
        if self._validator is not None:
            self._validator(key, value)

    def on_set(self, key, value):
        """Callback called on successful set. Uses function from __init__."""
        if self._on_set is not None:
            self._on_set(key, value)

    def __setitem__(self, key, value):
        """Validate, set, and (if successful) call `on_set` for the given item.

        All checks run before anything is stored. If registration fails, no key
        from ``key`` is left in the registry.

        Parameters
        ----------
        key:
            key to store value under, or a tuple of keys to store it under
            each of them. If `None`, `self.default_key(value)` is used.
        value:
            callable stored under the given key(s).

        Raises
        ------
        KeyError:
            if a key is already in the registry, or is repeated in ``key``.
        ValueError:
            if ``value`` is not callable.
        """
        if key is None:
            key = self.default_key(value)
        keys = key if isinstance(key, tuple) else (key,)

        # Check every alias before mutating. Otherwise a clash on a later alias
        # would leave the earlier aliases registered.
        seen = set()
        for k in keys:
            if k in self or k in seen:
                raise KeyError("key %s already registered in registry %s" % (k, self._name))
            seen.add(k)
        if not callable(value):
            raise ValueError("value must be callable")
        for k in keys:
            self.validate(k, value)

        for k in keys:
            self._registry[k] = value
            self.on_set(k, value)

    def register(self, key_or_value=None):
        """Decorator to register a function, or registration itself.

        This is primarily intended for use as a decorator, either with or
        without a key/parentheses.

        Example Usage::

            @my_registry.register('key1')
            def value_fn(x, y, z):
                pass

            @my_registry.register()
            def another_fn(x, y):
                pass

            @my_registry.register
            def third_func():
                pass

        Note if key_or_value is provided as a non-callable, registration only
        occurs once the returned callback is called with a callable as its
        only argument::

            callback = my_registry.register('different_key')
            'different_key' in my_registry  # False
            callback(lambda x, y: x + y)
            'different_key' in my_registry  # True

        Parameters
        ----------
        key_or_value (optional):
            key to access the registered value with, or the function itself.
            If `None` (default), `self.default_key` will be called on `value`
            once the returned callback is called with `value` as the only arg.
            If `key_or_value` is itself callable, it is assumed to be the
            value and the key is given by `self.default_key(value)`.

        Returns
        -------
        The registered value itself when `key_or_value` is callable. Otherwise
        a callback that registers and returns the value it is called with.
        """

        def decorator(value, key):
            self[key] = value
            return value

        # Bare ``@registry.register`` (no parentheses): the argument is the value.
        if callable(key_or_value):
            return decorator(value=key_or_value, key=None)
        return lambda value: decorator(value, key=key_or_value)

    def register_with_multiple_names(self, *names):
        """Return a decorator that registers the decorated value under every name in ``names``."""
        return self.register(names)

    def __getitem__(self, key):
        if key not in self:
            # Stringify keys so a non-string key cannot break the help message
            # and hide the KeyError behind a TypeError/AttributeError.
            raise KeyError(
                "%s never registered with registry %s. Available:\n %s"
                % (key, self.name, display_list_by_prefix(sorted(str(k) for k in self), 4))
            )
        value = self._registry[key]
        return self._value_transformer(key, value)

    def __contains__(self, key):
        return key in self._registry

    def keys(self):
        """Registered keys (a live view of the underlying dict)."""
        return self._registry.keys()

    def values(self):
        """Lazily yield registered values, passed through the value transformer."""
        return (self[k] for k in self)  # complicated because of transformer

    def items(self):
        """Lazily yield ``(key, value)`` pairs, values passed through the transformer."""
        return ((k, self[k]) for k in self)  # complicated because of transformer

    def __iter__(self):
        return iter(self._registry)

    def __len__(self):
        return len(self._registry)

    def _clear(self):
        self._registry.clear()

    def get(self, key, default=None):
        """Return ``self[key]`` if ``key`` is registered, else ``default``."""
        return self[key] if key in self else default

    def parse(self, class_or_str):
        """Look strings up in the registry; return anything else unchanged."""
        if isinstance(class_or_str, str):
            return self[class_or_str]
        return class_or_str
# Type produced by ``RegistryMixin.parse`` for a given subclass.
RegistryClassT = TypeVar("RegistryClassT")


class RegistryMixin(Generic[RegistryClassT], abc.ABC):
    """Abstract mixin that adds a registry-backed ``parse`` classmethod.

    The ``registry`` attribute is only annotated here, not assigned.
    Presumably concrete subclasses assign it; verify against the subclasses.
    """

    # Registry that ``parse`` delegates to; never assigned in this class.
    registry: Registry

    @classmethod
    def parse(cls, class_or_str) -> RegistryClassT:
        """Resolve ``class_or_str`` by delegating to ``cls.registry.parse``."""
        return cls.registry.parse(class_or_str)  # type: ignore
def display_list_by_prefix(names_list, starting_spaces=0):
    """Create a help string for ``names_list`` grouped by prefix.

    A name's prefix is the text before its first underscore, or the whole name
    if it has no underscore. Each group gets one ``"<prefix>:"`` header line
    followed by one ``" * <name>"`` line per member. Groups are ordered by
    prefix, and names are sorted within each group.

    Parameters
    ----------
    names_list: iterable of str
        Names to display.
    starting_spaces: int, optional
        Number of spaces prepended to every output line (default 0).

    Returns
    -------
    str
        The newline-joined help text ("" for an empty ``names_list``).
    """

    def _prefix(name):
        return name.split("_", 1)[0]

    space = " " * starting_spaces
    cur_prefix, result_lines = None, []
    # Sort by (prefix, name), not by name alone. Plain sorting can split a
    # group: "mlp", "mlp2", "mlp_block" sorts "mlp2" between the two "mlp"
    # names, which would print the "mlp:" header twice.
    for name in sorted(names_list, key=lambda n: (_prefix(n), n)):
        prefix = _prefix(name)
        if prefix != cur_prefix:
            result_lines.append(space + prefix + ":")
            cur_prefix = prefix
        result_lines.append(space + " * " + name)
    return "\n".join(result_lines)