Source code for icenet.data.loaders.base

import datetime as dt
import json
import logging
import os
from abc import abstractmethod

from pprint import pformat

import numpy as np

from icenet.data.process import IceNetPreProcessor
from icenet.data.producers import Generator
"""

"""


class IceNetBaseDataLoader(Generator):
    """Base class for the IceNet data loader implementations.

    :param configuration_path: path to the loader configuration JSON
    :param identifier: name identifying the dataset produced
    :param var_lag: number of lagged time steps supplied per input variable
    :param dataset_config_path: directory in which to write the dataset
        configuration
    :param dates_override: optional dates to use in place of those configured
    :param dry: if True, perform a dry run without generating data
    :param generate_workers: number of workers to use during generation
    :param loss_weight_days: whether to apply day-based loss weighting
    :param n_forecast_days: number of days to forecast
    :param output_batch_size: batch size of the output dataset
    :param path: root path for the generated network datasets
    :param pickup: pick up a previously interrupted generation
    :param var_lag_override: per-variable overrides for var_lag
    """

    def __init__(self,
                 configuration_path: str,
                 identifier: str,
                 var_lag: int,
                 *args,
                 dataset_config_path: str = ".",
                 dates_override: object = None,
                 dry: bool = False,
                 generate_workers: int = 8,
                 loss_weight_days: bool = True,
                 n_forecast_days: int = 93,
                 output_batch_size: int = 32,
                 path: str = os.path.join(".", "network_datasets"),
                 pickup: bool = False,
                 var_lag_override: object = None,
                 **kwargs):
        super().__init__(*args, identifier=identifier, path=path, **kwargs)

        self._channels = dict()
        self._channel_files = dict()

        self._configuration_path = configuration_path
        self._dataset_config_path = dataset_config_path
        self._dates_override = dates_override
        self._config = dict()
        self._dry = dry
        self._loss_weight_days = loss_weight_days
        self._meta_channels = []
        self._missing_dates = []
        self._n_forecast_days = n_forecast_days
        self._output_batch_size = output_batch_size
        self._pickup = pickup
        self._trend_steps = dict()
        self._workers = generate_workers

        self._var_lag = var_lag
        self._var_lag_override = dict() \
            if not var_lag_override else var_lag_override

        self._load_configuration(configuration_path)
        self._construct_channels()

        self._dtype = getattr(np, self._config["dtype"])
        self._shape = tuple(self._config["shape"])

        self._missing_dates = [
            dt.datetime.strptime(s, IceNetPreProcessor.DATE_FORMAT)
            for s in self._config["missing_dates"]
        ]
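
    # Construction loads the loader configuration JSON (a sketch of its
    # structure follows _load_configuration below) and derives the channel
    # layout from it. A hypothetical call on a concrete subclass, for
    # illustration only ("MyLoader" and the file name are not real):
    #
    #   loader = MyLoader("loader.example.json", "example", var_lag=3)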

    def write_dataset_config_only(self):
        """Write the dataset configuration only, without generating data."""
        splits = ("train", "val", "test")
        counts = {el: 0 for el in splits}

        logging.info("Writing dataset configuration without data generation")

        # FIXME: cloned mechanism from generate() - do we need to treat these
        #  as sets that might have missing data for fringe cases?
        for dataset in splits:
            forecast_dates = sorted(
                set(
                    dt.datetime.strptime(
                        s, IceNetPreProcessor.DATE_FORMAT).date()
                    for identity in self._config["sources"].keys()
                    for s in self._config["sources"][identity]["dates"][dataset]
                ))

            logging.info("{} {} dates in total, NOT generating cache "
                         "data.".format(len(forecast_dates), dataset))
            counts[dataset] += len(forecast_dates)

        self._write_dataset_config(counts, network_dataset=False)
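
    # Illustrative only: with daily dates configured for each source, the
    # counts passed to _write_dataset_config might end up looking like
    # {"train": 1460, "val": 365, "test": 365} (hypothetical figures).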

    @abstractmethod
    def generate_sample(self, date: object, prediction: bool = False):
        """Generate the sample for a single forecast date.

        :param date: forecast date to generate the sample for
        :param prediction: whether the sample is for prediction rather
            than training
        :return: the generated sample
        """
        pass

    def get_sample_files(self) -> object:
        """Map each channel variable to the file that backs it.

        :return: dict of variable name to source file
        """
        # FIXME: is this not just the same as _channel_files now?
        # FIXME: still experimental code, move to multiple implementations
        # FIXME: CLEAN THIS ALL UP ONCE VERIFIED FOR local/shared STORAGE!
        var_files = dict()

        for var_name, num_channels in self._channels.items():
            var_file = self._get_var_file(var_name)

            if not var_file:
                raise RuntimeError("No file returned for {}".format(var_name))

            if var_name not in var_files:
                var_files[var_name] = var_file
            elif var_file != var_files[var_name]:
                raise RuntimeError("Differing files? {} {} vs {}".format(
                    var_name, var_file, var_files[var_name]))

        return var_files
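
    # Illustrative return value (variable names and paths are hypothetical):
    #
    #   {
    #       "siconca_abs": ".../siconca_abs.nc",
    #       "zg500_anom": ".../zg500_anom.nc",
    #       "siconca_linear_trend": ".../siconca_linear_trend.nc",
    #   }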

    def _add_channel_files(self, var_name: str, filelist: object):
        """Register the source files for a channel variable.

        :param var_name: channel variable name
        :param filelist: files to add to the channel's file list
        """
        if var_name in self._channel_files:
            logging.warning("{} already has files, but more were found; "
                            "this could be an unintentional merge of "
                            "sources".format(var_name))
        else:
            self._channel_files[var_name] = []

        logging.debug("Adding {} files to {} channel".format(
            len(filelist), var_name))
        self._channel_files[var_name] += filelist

    def _construct_channels(self):
        """Derive the channel set and ordering from the configuration."""
        # As of Python 3.7 dict guarantees the order of keys based on
        # original insertion order, which is great for this method
        lag_vars = [
            (identity, var, data_format)
            for data_format in ("abs", "anom")
            for identity in sorted(self._config["sources"].keys())
            for var in sorted(self._config["sources"][identity][data_format])
        ]
        for identity, var_name, data_format in lag_vars:
            var_prefix = "{}_{}".format(var_name, data_format)

            var_lag = (self._var_lag
                       if var_name not in self._var_lag_override
                       else self._var_lag_override[var_name])

            self._channels[var_prefix] = int(var_lag)
            self._add_channel_files(var_prefix, [
                el for el in
                self._config["sources"][identity]["var_files"][var_name]
                if var_prefix in os.path.split(el)[1]
            ])

        trend_names = [
            (identity, var,
             self._config["sources"][identity]["linear_trend_steps"])
            for identity in sorted(self._config["sources"].keys())
            for var in sorted(
                self._config["sources"][identity]["linear_trends"])
        ]
        for identity, var_name, trend_steps in trend_names:
            var_prefix = "{}_linear_trend".format(var_name)

            self._channels[var_prefix] = len(trend_steps)
            self._trend_steps[var_prefix] = trend_steps
            filelist = [
                el for el in
                self._config["sources"][identity]["var_files"][var_name]
                if "linear_trend" in os.path.split(el)[1]
            ]

            self._add_channel_files(var_prefix, filelist)

        # Metadata input variables that don't span time
        meta_names = [
            (identity, var)
            for identity in sorted(self._config["sources"].keys())
            for var in sorted(self._config["sources"][identity]["meta"])
        ]
        for identity, var_name in meta_names:
            self._meta_channels.append(var_name)
            self._channels[var_name] = 1
            self._add_channel_files(
                var_name,
                self._config["sources"][identity]["var_files"][var_name])

        logging.debug(
            "Channel quantities deduced:\n{}\n\nTotal channels: {}".format(
                pformat(self._channels), self.num_channels))

    def _get_var_file(self, var_name: str):
        """Return the single file backing a channel variable.

        :param var_name: channel variable name
        :return: the first file found for the channel, or None if there
            are none
        """
        filename = "{}.nc".format(var_name)
        files = self._channel_files[var_name]

        if len(files) > 1:
            logging.warning(
                "Multiple files found for {}, only returning {}".format(
                    filename, files[0]))
        elif not len(files):
            logging.warning("No files in channel list for {}".format(filename))
            return None
        return files[0]

    def _load_configuration(self, path: str):
        """Load the loader configuration JSON into self._config.

        :param path: path to the configuration file
        """
        if os.path.exists(path):
            logging.info("Loading configuration {}".format(path))

            with open(path, "r") as fh:
                obj = json.load(fh)
                self._config.update(obj)
        else:
            raise OSError("{} not found".format(path))
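
    # The configuration loaded above is the loader JSON referenced by
    # configuration_path. A sketch of only the keys this class reads; the
    # source name, variables and values are illustrative, not a schema:
    #
    #   {
    #       "dtype": "float32",
    #       "shape": [432, 432],
    #       "missing_dates": [],
    #       "sources": {
    #           "era5": {
    #               "dates": {"train": [...], "val": [...], "test": [...]},
    #               "abs": ["uas"],
    #               "anom": ["zg500"],
    #               "linear_trends": ["siconca"],
    #               "linear_trend_steps": [1, 2, 3],
    #               "meta": ["sin", "cos"],
    #               "var_files": {"uas": [...], "zg500": [...], ...}
    #           }
    #       }
    #   }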

    def _write_dataset_config(self,
                              counts: object,
                              network_dataset: bool = True):
        """Write the dataset configuration JSON for this loader.

        :param counts: sample counts per split
        :param network_dataset: whether the dataset was generated to disk
        """
        # TODO: move to utils, shared with process
        def _serialize(x):
            if isinstance(x, dt.date):
                return x.strftime(IceNetPreProcessor.DATE_FORMAT)
            return str(x)

        configuration = {
            "identifier": self.identifier,
            "implementation": self.__class__.__name__,
            # This is only for convenience ;)
            "channels": [
                "{}_{}".format(channel, i)
                for channel, s in self._channels.items()
                for i in range(1, s + 1)
            ],
            "counts": counts,
            "dtype": self._dtype.__name__,
            "loader_config": os.path.abspath(self._configuration_path),
            "missing_dates": [
                date.strftime(IceNetPreProcessor.DATE_FORMAT)
                for date in self._missing_dates
            ],
            "n_forecast_days": self._n_forecast_days,
            "north": self.north,
            "num_channels": self.num_channels,
            # FIXME: this naming is inconsistent, sort it out!!! ;)
            "shape": list(self._shape),
            "south": self.south,

            # For recreating this dataloader
            # "dataset_config_path": ".",
            "dataset_path": self._path if network_dataset else False,
            "generate_workers": self.workers,
            "loss_weight_days": self._loss_weight_days,
            "output_batch_size": self._output_batch_size,
            "var_lag": self._var_lag,
            "var_lag_override": self._var_lag_override,
        }

        output_path = os.path.join(
            self._dataset_config_path,
            "dataset_config.{}.json".format(self.identifier))

        logging.info("Writing configuration to {}".format(output_path))

        with open(output_path, "w") as fh:
            json.dump(configuration, fh, indent=4, default=_serialize)

    @property
    def channel_names(self):
        return [
            "{}_{}".format(nom, idx) if idx_qty > 1 else nom
            for nom, idx_qty in self._channels.items()
            for idx in range(1, idx_qty + 1)
        ]

    @property
    def config(self):
        return self._config

    @property
    def dates_override(self):
        return self._dates_override

    @property
    def num_channels(self):
        return sum(self._channels.values())

    @property
    def pickup(self):
        return self._pickup

    @property
    def workers(self):
        return self._workers
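

# A minimal usage sketch, illustrative only: "_ExampleLoader" and
# "loader.example.json" are hypothetical, and running this assumes a loader
# configuration of the shape sketched above exists at that path.
if __name__ == "__main__":

    class _ExampleLoader(IceNetBaseDataLoader):
        """Trivial concrete loader satisfying the abstract interface."""

        def generate_sample(self, date: object, prediction: bool = False):
            # A real implementation assembles the input/output arrays for
            # the given forecast date from the channel files
            raise NotImplementedError

    # Loads loader.example.json, derives the channel layout, then writes
    # dataset_config.example.json without generating any data
    loader = _ExampleLoader("loader.example.json", "example", var_lag=1)
    loader.write_dataset_config_only()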