Source code for carpedm.util.registry

# Copyright (C) 2018 Neal Digre.
# This software may be modified and distributed under the terms
# of the MIT license. See the LICENSE file for details.
# Portions of this module are taken or lightly modified from the
# Tensor2Tensor registry module, so here is their license:
# Copyright 2018 The Tensor2Tensor Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

"""Registry for models and tasks.

Define a new models by subclassing models.Model and register it:

.. code-block:: python

    class MyModel(models.Model):

Access by snake-cased name: ``registry.model("my_model")``.

See all the models registered: ``registry.list_models()``.

    1. Lightly modified `Tensor2Tensor registry`_.

..  _Tensor2Tensor registry:

import re

_MODELS = {}
_TASKS = {}

# Camel case to snake case util
_first_cap_re = re.compile("(.)([A-Z][a-z0-9]+)")
_all_cap_re = re.compile("([a-z0-9])([A-Z])")

def _convert_camel_to_snake(name):
    s1 = _first_cap_re.sub(r"\1_\2", name)
    return _all_cap_re.sub(r"\1_\2", s1).lower()

def _reset():
    for ctr in [_MODELS, _TASKS]:

[docs]def default_name(obj_class): """Convert class name to the registry's default name for the class. Args: obj_class: the name of a class Returns: The registry's default name for the class. """ return _convert_camel_to_snake(obj_class.__name__)
[docs]def default_object_name(obj): """Convert object to the registry's default name for the object class. Args: obj: an object instance Returns: The registry's default name for the class of the object. """ return default_name(obj.__class__)
[docs]def register_model(name=None): """ Register a models. ``name`` defaults to class name snake-cased.""" def decorator(model_cls, registration_name=None): """Registers & returns model_cls.""" model_name = registration_name or default_name(model_cls) if model_name in _MODELS: raise LookupError("Model %s already registered." % model_name) model_cls.REGISTERED_NAME = model_name _MODELS[model_name] = model_cls return model_cls # Handle if decorator was used without parens if callable(name): model_cls = name return decorator(model_cls, registration_name=default_name(model_cls)) return lambda model_cls: decorator(model_cls, name)
[docs]def model(name): """Retrieve a model by name.""" if name not in _MODELS: raise LookupError("Model %s never registered. Available models:\n %s" % (name, "\n".join(list_models()))) return _MODELS[name]
def list_models(): return list(sorted(_MODELS))
[docs]def register_task(name=None): """Register a Task. ``name`` defaults to cls name snake-cased.""" def decorator(t_cls, registration_name=None): """Registers & returns t_cls with registration_name or default.""" t_name = registration_name or default_name(t_cls) if t_name in _TASKS: raise LookupError("Task %s already registered." % t_name) _TASKS[t_name] = t_cls = t_name return t_cls # Handle if decorator was used without parens if callable(name): t_cls = name return decorator(t_cls, registration_name=default_name(t_cls)) return lambda t_cls: decorator(t_cls, name)
[docs]def task(name): """Retrieve a task by name.""" if name not in _TASKS: all_task_names = sorted(list_tasks()) error_lines = ["%s not in the set of supported tasks:" % name ] + all_task_names error_msg = "\n * ".join(error_lines) raise LookupError(error_msg) return _TASKS[name]
def list_tasks(): return list(_TASKS)
[docs]def display_list_by_prefix(names_list, starting_spaces=0): """Creates a help string for ``names_list`` grouped by prefix.""" cur_prefix, result_lines = None, [] space = " " * starting_spaces for name in sorted(names_list): split = name.split("_", 1) prefix = split[0] if cur_prefix != prefix: result_lines.append(space + prefix + ":") cur_prefix = prefix result_lines.append(space + " * " + name) return "\n".join(result_lines)
[docs]def help_string(): """Generate help string with contents of registry.""" help_str = """ Registry contents: ------------------ Models: %s Tasks: %s """ m, tasks = [ display_list_by_prefix(entries, starting_spaces=4) for entries in [ list_models(), list_tasks() ] ] return help_str % (m, tasks)