# Source code for icenet.data.loaders
import inspect
from icenet.data.loaders.base import IceNetBaseDataLoader
import icenet.data.loaders.dask
import icenet.data.loaders.stdlib
[docs]
class IceNetDataLoaderFactory:
    """A factory mapping loader names to their implementation classes.

    Loaders are registered by name with :meth:`add_data_loader` and
    instantiated by name with :meth:`create_data_loader`.

    Attributes:
        _loader_map: A dictionary holding loader names against their
            implementation classes.
    """

    def __init__(self):
        """Initialises the factory and sets up the initial loader map.

        The map is pre-populated with the ``dask``, ``dask_shared`` and
        ``standard`` loader implementations.
        """
        self._loader_map = dict(
            dask=icenet.data.loaders.dask.DaskMultiWorkerLoader,
            dask_shared=icenet.data.loaders.dask.DaskMultiSharingWorkerLoader,
            standard=icenet.data.loaders.stdlib.IceNetDataLoader,
        )

    def add_data_loader(self, loader_name: str, loader_impl: type) -> None:
        """Registers a new loader class under the given name.

        Args:
            loader_name: The name to register the loader under.
            loader_impl: The implementation class of the loader. It must be
                a class that has IceNetBaseDataLoader in its MRO.

        Returns:
            None. Updates the `_loader_map` attribute with the specified
            loader name and implementation.

        Raises:
            RuntimeError: If the loader name already exists, or if
                `loader_impl` is not a class descended from
                IceNetBaseDataLoader.
        """
        if loader_name in self._loader_map:
            raise RuntimeError(
                "Cannot add {} as already in loader map".format(loader_name))

        # inspect.getmro() only accepts classes. Checking isclass() first
        # means a non-class argument gets the documented RuntimeError rather
        # than an AttributeError. MRO membership (not issubclass) is kept
        # so virtual ABC subclasses are still rejected, as before.
        if not (inspect.isclass(loader_impl)
                and IceNetBaseDataLoader in inspect.getmro(loader_impl)):
            raise RuntimeError("{} is not descended from "
                               "IceNetBaseDataLoader".format(
                                   getattr(loader_impl, "__name__",
                                           repr(loader_impl))))

        self._loader_map[loader_name] = loader_impl

    def create_data_loader(self, loader_name, *args, **kwargs) -> object:
        """Creates an instance of the loader registered under `loader_name`.

        Args:
            loader_name: The name of the loader.
            *args: Additional positional arguments, passed to the loader
                constructor.
            **kwargs: Additional keyword arguments, passed to the loader
                constructor.

        Returns:
            An instance of the loader class.

        Raises:
            KeyError: If the loader name does not exist in `_loader_map`.
        """
        # Only the lookup is guarded, so a KeyError raised inside the
        # loader's own constructor still propagates unchanged.
        try:
            loader_cls = self._loader_map[loader_name]
        except KeyError:
            raise KeyError("Unknown data loader {!r}; available loaders: "
                           "{}".format(loader_name,
                                       ", ".join(sorted(
                                           str(name) for name in
                                           self._loader_map)))) from None
        return loader_cls(*args, **kwargs)

    @property
    def loader_map(self) -> dict:
        """The loader map dictionary.

        This is the factory's internal dict itself, not a copy. Mutating it
        directly bypasses the validation performed by `add_data_loader`.
        """
        return self._loader_map