# Source code for carpedm.util.registry

#
# Copyright (C) 2018 Neal Digre.
#
# This software may be modified and distributed under the terms
# of the MIT license. See the LICENSE file for details.
#
#
# Portions of this module are taken or lightly modified from the
# Tensor2Tensor registry module, so here is their license:
#
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Registry for models and tasks.

Define a new model by subclassing models.Model and register it:

.. code-block:: python

    @registry.register_model
    class MyModel(models.Model):
        ...

Access by snake-cased name: ``registry.model("my_model")``.

See all the models registered: ``registry.list_models()``.

References:
    1. Lightly modified `Tensor2Tensor registry`_.

..  _Tensor2Tensor registry: https://github.com/tensorflow/
    tensor2tensor/blob/master/tensor2tensor/util/registry.py

"""
import re


# Registered model classes, keyed by registration name (filled by
# register_model, read by model / list_models).
_MODELS = {}
# Registered task classes, keyed by registration name (filled by
# register_task, read by task / list_tasks).
_TASKS = {}


# Camel case to snake case util.
# Breaks before a capitalized word ("HTTPServer" -> "HTTP_Server").
_first_cap_re = re.compile("(.)([A-Z][a-z0-9]+)")
# Breaks between a lowercase letter or digit and a following capital.
_all_cap_re = re.compile("([a-z0-9])([A-Z])")


def _convert_camel_to_snake(name):
    """Return ``name`` converted from CamelCase to snake_case.

    Both patterns are applied in turn, each inserting an underscore
    between its two captured groups; the result is then lower-cased.

    Args:
        name: string to convert, e.g. ``"MyModel"``.

    Returns:
        The snake-cased string, e.g. ``"my_model"``.
    """
    snake = name
    for pattern in (_first_cap_re, _all_cap_re):
        snake = pattern.sub(r"\1_\2", snake)
    return snake.lower()


def _reset():
    """Empty the model and task registries in place."""
    _MODELS.clear()
    _TASKS.clear()


def default_name(obj_class):
    """Convert a class to the registry's default name for it.

    Args:
        obj_class: a class object; its ``__name__`` is snake-cased.

    Returns:
        The registry's default name for the class.
    """
    class_name = obj_class.__name__
    return _convert_camel_to_snake(class_name)
def default_object_name(obj):
    """Convert an object to the registry's default name for its class.

    Args:
        obj: an object instance.

    Returns:
        The registry's default name for ``obj.__class__``.
    """
    instance_class = obj.__class__
    return default_name(instance_class)
def register_model(name=None):
    """Register a model. ``name`` defaults to class name snake-cased.

    Usable both bare (``@register_model``) and with an argument
    (``@register_model("my_name")`` or ``@register_model()``).

    Args:
        name: registration name, or the class itself when the decorator
            is applied without parentheses.

    Returns:
        The registered class when used bare, otherwise a decorator that
        registers the class it is applied to and returns it.

    Raises:
        LookupError: if the resulting name is already registered.
    """

    def _register(model_cls, registration_name=None):
        """Registers & returns model_cls."""
        if registration_name:
            model_name = registration_name
        else:
            model_name = default_name(model_cls)
        if model_name in _MODELS:
            raise LookupError("Model %s already registered." % model_name)
        # Tag the class with the name it was registered under.
        model_cls.REGISTERED_NAME = model_name
        _MODELS[model_name] = model_cls
        return model_cls

    if not callable(name):
        # Used with parens: ``name`` is the requested name (or None).
        def _with_name(model_cls):
            return _register(model_cls, name)

        return _with_name

    # Used without parens: ``name`` is actually the class being decorated.
    model_cls = name
    return _register(model_cls, registration_name=default_name(model_cls))
def model(name):
    """Retrieve a model by name.

    Args:
        name: registration name of the model.

    Returns:
        The class registered under ``name``.

    Raises:
        LookupError: if ``name`` was never registered; the message lists
            the registered model names.
    """
    if name in _MODELS:
        return _MODELS[name]
    available = "\n".join(list_models())
    raise LookupError(
        "Model %s never registered. Available models:\n %s"
        % (name, available)
    )
def list_models():
    """Return the registered model names as a sorted list."""
    # sorted() already builds a new list from the dict's keys.
    return sorted(_MODELS)
def register_task(name=None):
    """Register a Task. ``name`` defaults to cls name snake-cased.

    Usable both bare (``@register_task``) and with an argument
    (``@register_task("my_name")`` or ``@register_task()``).

    Args:
        name: registration name, or the class itself when the decorator
            is applied without parentheses.

    Returns:
        The registered class when used bare, otherwise a decorator that
        registers the class it is applied to and returns it.

    Raises:
        LookupError: if the resulting name is already registered.
    """

    def _register(t_cls, registration_name=None):
        """Registers & returns t_cls with registration_name or default."""
        if registration_name:
            t_name = registration_name
        else:
            t_name = default_name(t_cls)
        if t_name in _TASKS:
            raise LookupError("Task %s already registered." % t_name)
        _TASKS[t_name] = t_cls
        # Tag the class with the name it was registered under.
        t_cls.name = t_name
        return t_cls

    if not callable(name):
        # Used with parens: ``name`` is the requested name (or None).
        def _with_name(t_cls):
            return _register(t_cls, name)

        return _with_name

    # Used without parens: ``name`` is actually the class being decorated.
    t_cls = name
    return _register(t_cls, registration_name=default_name(t_cls))
def task(name):
    """Retrieve a task by name.

    Args:
        name: registration name of the task.

    Returns:
        The class registered under ``name``.

    Raises:
        LookupError: if ``name`` is not registered; the message lists
            the registered task names in sorted order, one per line.
    """
    if name in _TASKS:
        return _TASKS[name]
    known_tasks = sorted(list_tasks())
    error_lines = ["%s not in the set of supported tasks:" % name]
    error_lines.extend(known_tasks)
    raise LookupError("\n * ".join(error_lines))
def list_tasks():
    """Return the registered task names as a list (not sorted here)."""
    task_names = list(_TASKS)
    return task_names
def display_list_by_prefix(names_list, starting_spaces=0):
    """Creates a help string for ``names_list`` grouped by prefix.

    The prefix of a name is everything before its first underscore (the
    whole name when it has none). Each group is rendered as a
    ``prefix:`` header line followed by one bullet line per name.

    Args:
        names_list: iterable of name strings.
        starting_spaces: number of spaces to indent every output line.

    Returns:
        The lines joined with newlines; an empty string when
        ``names_list`` is empty.
    """
    indent = " " * starting_spaces
    # Sort on (prefix, name) rather than on the name alone. A plain name
    # sort can place e.g. "resnet50" between "resnet" and "resnet_v2"
    # (digits sort before "_"), which would split the "resnet" group and
    # print its header twice.
    keyed_names = sorted((name.split("_", 1)[0], name) for name in names_list)

    result_lines = []
    cur_prefix = None
    for prefix, name in keyed_names:
        if prefix != cur_prefix:
            result_lines.append(indent + prefix + ":")
            cur_prefix = prefix
        result_lines.append(indent + " * " + name)
    return "\n".join(result_lines)
def help_string():
    """Generate help string with contents of registry.

    Returns:
        A multi-line string with a "Models" and a "Tasks" section, each
        followed by the registered names grouped by prefix (as produced
        by ``display_list_by_prefix`` with a four-space indent).
    """
    # Keep each heading and each listing on its own lines: the listings
    # are themselves multi-line, so a single-line template would run the
    # headings and entries together.
    help_str = "\n".join([
        "",
        "Registry contents:",
        "------------------",
        "",
        "  Models:",
        "%s",
        "",
        "  Tasks:",
        "%s",
        "",
    ])
    m, tasks = [
        display_list_by_prefix(entries, starting_spaces=4)
        for entries in [list_models(), list_tasks()]
    ]
    return help_str % (m, tasks)