merlin.models.utils.registry.Registry

class merlin.models.utils.registry.Registry(registry_name, default_key_fn=<function default_name>, validator=None, on_set=None, value_transformer=<function Registry.<lambda>>)

Bases: object

Dict-like class for managing function registrations.

Example usage:

from merlin.models.utils.registry import Registry

my_registry = Registry("custom_name")

@my_registry.register
def my_func():
    pass

@my_registry.register()
def another_func():
    pass

@my_registry.register("non_default_name")
def third_func(x, y, z):
    pass

def foo():
    pass

my_registry.register()(foo)
my_registry.register("baz")(lambda x, y: x + y)
my_registry.register("bar")  # returns a callback only; nothing is registered yet

print(list(my_registry))
# ["my_func", "another_func", "non_default_name", "foo", "baz"]
# (order may vary)
print(my_registry["non_default_name"] is third_func)  # True
print("third_func" in my_registry)                    # False
print("bar" in my_registry)                           # False
my_registry["non-existent_key"]                       # raises KeyError

Optional validation, an on_set callback, and value transformation are also supported.

Parameters
  • registry_name (str) – identifier for the given registry. Used in error messages.

  • default_key_fn (callable, optional) – function mapping value -> key, used for registration when a key is not provided.

  • validator (callable, optional) – if given, this is run before a (key, value) pair is set. It accepts (key, value) and should raise an exception if there is a problem. Overwriting existing keys is not allowed and is checked separately, as is the requirement that values be callable.

  • on_set (callable, optional) – callback function accepting a (key, value) pair, which is run after an item is successfully set.

  • value_transformer (callable, optional) – if given, __getitem__ will return value_transformer(key, registered_value) instead of the registered value itself.
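Read together, the parameters describe what happens when an item is set and read: a default key is derived if none is given, overwrites and non-callable values are rejected, the validator runs, the value is stored, on_set fires, and lookups pass through value_transformer. The sketch below mirrors that sequence in MiniRegistry, a hypothetical stand-in rather than merlin's implementation; the __name__-based default key, the exception types, and the relative order of the duplicate and callable checks are assumptions.

```python
from typing import Any, Callable, Dict, Optional


class MiniRegistry:
    """Hypothetical stand-in mimicking the documented hooks (not merlin's code)."""

    def __init__(
        self,
        registry_name: str,
        default_key_fn: Callable[[Any], str] = lambda value: value.__name__,  # assumed default
        validator: Optional[Callable[[str, Any], None]] = None,
        on_set: Optional[Callable[[str, Any], None]] = None,
        value_transformer: Callable[[str, Any], Any] = lambda key, value: value,
    ) -> None:
        self._name = registry_name
        self._default_key_fn = default_key_fn
        self._validator = validator
        self._on_set = on_set
        self._value_transformer = value_transformer
        self._registry: Dict[str, Any] = {}

    def __setitem__(self, key: Optional[str], value: Any) -> None:
        if key is None:
            key = self._default_key_fn(value)
        # Overwrites and non-callable values are rejected independently of the validator.
        if key in self._registry:
            raise KeyError(f"{key!r} is already registered in {self._name!r}")
        if not callable(value):
            raise ValueError(f"value registered under {key!r} must be callable")
        if self._validator is not None:
            self._validator(key, value)  # may raise; nothing is stored in that case
        self._registry[key] = value
        if self._on_set is not None:
            self._on_set(key, value)  # runs only after the item has been stored

    def __getitem__(self, key: str) -> Any:
        if key not in self._registry:
            raise KeyError(f"{key!r} is not registered in {self._name!r}")
        # Every lookup passes the stored value through value_transformer.
        return self._value_transformer(key, self._registry[key])

    def __contains__(self, key: object) -> bool:
        return key in self._registry


seen = []
reg = MiniRegistry(
    "activations",
    on_set=lambda key, value: seen.append(key),
    value_transformer=lambda key, value: (key, value),
)


def relu(x):
    return max(x, 0)


reg[None] = relu    # key falls back to default_key_fn(relu), i.e. "relu"
print(seen)         # ['relu']
print(reg["relu"])  # a ("relu", relu) tuple produced by value_transformer
```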

__init__(registry_name, default_key_fn=<function default_name>, validator=None, on_set=None, value_transformer=<function Registry.<lambda>>)

Methods

__init__(registry_name[, default_key_fn, …])

class_registry(registry_name[, …])

default_key(value)

Default key used when no key is provided.

get(key[, default])

items()

keys()

on_set(key, value)

Callback called on successful set.

parse(class_or_str)

register([key_or_value])

Decorator to register a function; can also be called directly to perform the registration.

register_with_multiple_names(*names)

validate(key, value)

Validation function run before setting.

values()

Attributes

name

classmethod class_registry(registry_name, default_key_fn=<function default_name>, validator=None, on_set=None)

default_key(value)

Default key used when no key is provided. Uses the default_key_fn passed to __init__.
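Because default_key delegates to the default_key_fn given to __init__, any callable that maps a value to a string can serve as the key function. The sketch below is a hypothetical snake_case key function (snake_case_key is not part of merlin); whether the built-in default_name normalizes names the same way is an assumption not confirmed here.

```python
import re
from typing import Any


def snake_case_key(value: Any) -> str:
    """Hypothetical default_key_fn: derive a snake_case key from __name__."""
    name = value.__name__
    # Put an underscore before each uppercase letter that follows a lowercase letter or digit.
    name = re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", name)
    return name.lower()


class MyBlock:
    pass


def my_func():
    pass


print(snake_case_key(MyBlock))  # my_block
print(snake_case_key(my_func))  # my_func
```

Such a function could then be supplied as Registry("blocks", default_key_fn=snake_case_key).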

property name
validate(key, value)

Validation function run before setting. Uses the validator passed to __init__, if one was given.
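A validator receives the (key, value) pair and signals a problem by raising. The hypothetical validate_block below, which is not part of merlin, rejects keys that are not lowercase or that contain spaces, and values that are not classes; the specific rules and exception types are illustrative assumptions.

```python
import inspect
from typing import Any


def validate_block(key: str, value: Any) -> None:
    """Hypothetical validator: raise if the (key, value) pair is unacceptable."""
    if not key.islower() or " " in key:
        raise ValueError(f"registry keys must be lowercase without spaces, got {key!r}")
    if not inspect.isclass(value):
        raise TypeError(f"{key!r} must map to a class, got {type(value).__name__}")


class Dense:
    pass


validate_block("dense", Dense)  # passes silently and returns None
```

It would be passed as the validator argument, for example Registry("blocks", validator=validate_block).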

on_set(key, value)

Callback called on successful set. Uses the on_set function passed to __init__, if one was given.
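An on_set callback is a natural place for side effects such as logging or bookkeeping. The sketch below uses a hypothetical factory, make_key_recorder (not part of merlin), that returns a callback together with the list it appends registered keys to.

```python
from typing import Any, Callable, List, Tuple


def make_key_recorder() -> Tuple[Callable[[str, Any], None], List[str]]:
    """Hypothetical on_set factory: returns a callback plus the list it appends to."""
    recorded: List[str] = []

    def on_set(key: str, value: Any) -> None:
        # A Registry runs this only after the (key, value) pair was stored.
        recorded.append(key)

    return on_set, recorded


on_set, recorded = make_key_recorder()
# It would be passed as Registry("activations", on_set=on_set); here it is called directly.
on_set("relu", max)
print(recorded)  # ['relu']
```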

register(key_or_value=None)

Decorator to register a function; can also be called directly to perform the registration. It is primarily intended for use as a decorator, with or without a key and with or without parentheses.

Example usage:

@my_registry.register('key1')
def value_fn(x, y, z):
    pass

@my_registry.register()
def another_fn(x, y):
    pass

@my_registry.register
def third_func():
    pass
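The three forms above can be supported by a single method that checks whether its argument is callable. TinyRegistry below is a hypothetical, stripped-down stand-in written to illustrate that dispatch; it is not merlin's source, and it assumes the key defaults to the function's __name__.

```python
from typing import Any, Callable, Dict, Optional, Union


class TinyRegistry:
    """Hypothetical stand-in showing the three register() calling conventions."""

    def __init__(self) -> None:
        self._registry: Dict[str, Callable[..., Any]] = {}

    def __contains__(self, key: object) -> bool:
        return key in self._registry

    def register(self, key_or_value: Union[str, Callable[..., Any], None] = None) -> Any:
        def decorator(value: Callable[..., Any], key: Optional[str]) -> Callable[..., Any]:
            # Assumed default key: the function's __name__.
            self._registry[key if key is not None else value.__name__] = value
            return value  # the decorated name keeps pointing at the original function

        if callable(key_or_value):
            # Bare @register: the argument is the function itself.
            return decorator(key_or_value, None)
        # @register() or @register("key"): registration happens when the callback is called.
        return lambda value: decorator(value, key_or_value)


reg = TinyRegistry()


@reg.register
def first():
    pass


@reg.register()
def second():
    pass


@reg.register("custom")
def third():
    pass


print("custom" in reg)  # True
print("third" in reg)   # False: the explicit key replaces the function name
```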

Note that if key_or_value is provided as a non-callable, registration only occurs once the returned callback is called with a callable as its only argument:

callback = my_registry.register('different_key')
'different_key' in my_registry  # False
callback(lambda x, y: x + y)
'different_key' in my_registry  # True

Parameters

  • key_or_value (optional) – key to access the registered value with, or the function itself. If None (default), self.default_key is called on the value once the returned callback is called with that value as its only argument. If key_or_value is itself callable, it is assumed to be the value, and the key is given by self.default_key(key_or_value).

Returns

The registered callable itself if key_or_value is callable (bare decorator use); otherwise, a callback that registers the callable passed to it and returns that callable.

register_with_multiple_names(*names)
keys()
values()
items()
get(key, default=None)
parse(class_or_str)