import argparse
import json
import logging
import os
import dask
import numpy as np
import pandas as pd
from icenet.data.datasets.utils import SplittingMixin
from icenet.data.loader import IceNetDataLoaderFactory
from icenet.data.producers import DataCollection
from icenet.utils import setup_logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
pytorch_available = False
try:
    from torch.utils.data import Dataset
    pytorch_available = True
except ModuleNotFoundError:
    print("PyTorch not found - not required if not using PyTorch")
except ImportError:
    print("PyTorch import failed - not required if not using PyTorch")
"""
https://stackoverflow.com/questions/55852831/
tf-data-vs-keras-utils-sequence-performance
"""
class IceNetDataSet(SplittingMixin, DataCollection):
"""Initialises and configures a dataset.
    It loads a JSON configuration file, updates the `_config` attribute with the
    result, creates a data loader, and provides methods to access the dataset.
Attributes:
_config: A dict used to store configuration loaded from JSON file.
_configuration_path: The path to the JSON configuration file.
_batch_size: The batch size for the data loader.
_counts: A dict with number of elements in train, val, test.
        _dtype: The numpy data type of the dataset.
_loader_config: The path to the data loader configuration file.
        _generate_workers: An integer representing the number of workers for parallel processing with Dask.
        _n_forecast_days: An integer representing the number of days to forecast.
        _num_channels: An integer representing the number of channels (input variables) in the dataset.
_shape: The shape of the dataset.
_shuffling: A flag indicating whether to shuffle the data or not.
"""
def __init__(self,
configuration_path: str,
*args,
batch_size: int = 4,
path: str = os.path.join(".", "network_datasets"),
shuffling: bool = False,
**kwargs) -> None:
"""Initialises an instance of the IceNetDataSet class.
Args:
configuration_path: The path to the JSON configuration file.
*args: Additional positional arguments.
batch_size (optional): How many samples to load per batch. Defaults to 4.
path (optional): The path to the directory where the processed tfrecord
protocol buffer files will be stored. Defaults to './network_datasets'.
shuffling (optional): Flag indicating whether to shuffle the data.
Defaults to False.
            **kwargs: Additional keyword arguments.
"""
self._config = dict()
self._configuration_path = configuration_path
self._load_configuration(configuration_path)
super().__init__(*args,
identifier=self._config["identifier"],
north=bool(self._config["north"]),
path=path,
south=bool(self._config["south"]),
**kwargs)
self._batch_size = batch_size
self._counts = self._config["counts"]
self._dtype = getattr(np, self._config["dtype"])
self._loader_config = self._config["loader_config"]
self._generate_workers = self._config["generate_workers"]
self._n_forecast_days = self._config["n_forecast_days"]
self._num_channels = self._config["num_channels"]
self._shape = tuple(self._config["shape"])
self._shuffling = shuffling
if "loader_path" in self._config:
logging.warning("Configuration uses old \"loader_path\" attribute, "
"this should change to \"dataset_path\"")
path_attr = "loader_path"
else:
path_attr = "dataset_path"
# Check JSON config has attribute for path to tfrecord datasets, and
# that the path exists.
if self._config[path_attr] and \
os.path.exists(self._config[path_attr]):
hemi = self.hemisphere_str[0]
self.add_records(self.base_path, hemi)
else:
logging.warning("Running in configuration only mode, tfrecords "
"were not generated for this dataset")
def _load_configuration(self, path: str) -> None:
"""Load the JSON configuration file and update the `_config` attribute of `IceNetDataSet` class.
Args:
path: The path to the JSON configuration file.
Raises:
OSError: If the specified configuration file is not found.
"""
if os.path.exists(path):
logging.info("Loading configuration {}".format(path))
with open(path, "r") as fh:
obj = json.load(fh)
self._config.update(obj)
else:
raise OSError("{} not found".format(path))
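    # A minimal sketch of the configuration keys read by this class (see __init__,
    # get_data_loader and the properties below). The values are illustrative
    # placeholders only, not library defaults, and the filename is hypothetical:
    #
    #     import json
    #
    #     example_config = {
    #         "identifier": "example_dataset",
    #         "north": False,
    #         "south": True,
    #         "counts": {"train": 0, "val": 0, "test": 0},
    #         "dtype": "float32",
    #         "loader_config": "loader.example_dataset.json",
    #         "generate_workers": 4,
    #         "n_forecast_days": 7,
    #         "num_channels": 0,
    #         "shape": [432, 432],
    #         "dataset_path": "./network_datasets/example_dataset",
    #         "channels": [],
    #         "var_lag": 1,
    #         "loss_weight_days": True,
    #         "output_batch_size": 4,
    #         "var_lag_override": {},
    #     }
    #     with open("dataset_config.example_dataset.json", "w") as fh:
    #         json.dump(example_config, fh)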
def get_data_loader(self,
n_forecast_days: object = None,
generate_workers: object = None) -> object:
"""Create an instance of the IceNetDataLoader class.
Args:
n_forecast_days (optional): The number of forecast days to be used by the data loader.
If not provided, defaults to the value specified in the configuration file.
generate_workers (optional): An integer representing number of workers to use for
parallel processing with Dask. If not provided, defaults to the value specified in
the configuration file.
Returns:
An instance of the DaskMultiWorkerLoader class configured with the specified parameters.
"""
if n_forecast_days is None:
n_forecast_days = self._config["n_forecast_days"]
if generate_workers is None:
generate_workers = self._config["generate_workers"]
loader = IceNetDataLoaderFactory().create_data_loader(
"dask", # This will load the `DaskMultiWorkerLoader` class.
self.loader_config,
self.identifier,
self._config["var_lag"],
n_forecast_days=n_forecast_days,
generate_workers=generate_workers,
dataset_config_path=os.path.dirname(self._configuration_path),
loss_weight_days=self._config["loss_weight_days"],
north=self.north,
output_batch_size=self._config["output_batch_size"],
south=self.south,
var_lag_override=self._config["var_lag_override"],
)
return loader
@property
def loader_config(self) -> str:
"""The path to the JSON loader configuration file stored in the dataset config file."""
# E.g. `/path/to/loader.{identifier}.json`
return self._loader_config
@property
def channels(self) -> list:
"""The list of channels (variable names) specified in the dataset config file."""
return self._config["channels"]
@property
def counts(self) -> dict:
"""A dict with number of elements in train, val, test in the config file."""
return self._config["counts"]
class MergedIceNetDataSet(SplittingMixin, DataCollection):
"""
:param identifier:
:param configuration_paths: List of configurations to load
:param batch_size:
:param path:
"""
def __init__(self,
configuration_paths: object,
*args,
batch_size: int = 4,
path: str = os.path.join(".", "network_datasets"),
shuffling: bool = False,
**kwargs):
self._config = dict()
        self._configuration_paths = [configuration_paths] \
            if not isinstance(configuration_paths, list) else configuration_paths
self._load_configurations(configuration_paths)
identifier = ".".join(
[loader.identifier for loader in self._config["loaders"]])
super().__init__(*args,
identifier=identifier,
north=bool(self._config["north"]),
path=path,
south=bool(self._config["south"]),
**kwargs)
self._base_path = path
self._batch_size = batch_size
self._dtype = getattr(np, self._config["dtype"])
self._num_channels = self._config["num_channels"]
self._n_forecast_days = self._config["n_forecast_days"]
self._shape = self._config["shape"]
self._shuffling = shuffling
self._init_records()
    def _init_records(self):
        """Add the tfrecord files from each merged configuration to this dataset."""
for idx, loader_path in enumerate(self._config["loader_paths"]):
hemi = self._config["loaders"][idx].hemisphere_str[0]
base_path = os.path.join(self._base_path,
self._config["loaders"][idx].identifier)
self.add_records(base_path, hemi)
    def _load_configurations(self, paths: object):
        """Load each JSON configuration file and merge it into `_config`.

        :param paths: List of paths to dataset configuration files.
        """
self._config = dict(loader_paths=[],
loaders=[],
north=False,
south=False)
for path in paths:
if os.path.exists(path):
logging.info("Loading configuration {}".format(path))
with open(path, "r") as fh:
obj = json.load(fh)
self._merge_configurations(path, obj)
else:
raise OSError("{} not found".format(path))
    def _merge_configurations(self, path: str, other: object):
        """Merge a single loaded configuration into the combined `_config` dict.

        :param path: The path to the configuration file being merged.
        :param other: The configuration dict loaded from that file.
        """
loader = IceNetDataLoaderFactory().create_data_loader(
"dask",
other["loader_config"],
other["identifier"],
other["var_lag"],
dataset_config_path=os.path.dirname(path),
loss_weight_days=other["loss_weight_days"],
north=other["north"],
output_batch_size=other["output_batch_size"],
south=other["south"],
var_lag_override=other["var_lag_override"])
self._config["loaders"].append(loader)
if "loader_path" in other:
logging.warning("Configuration uses old \"loader_path\" attribute, "
"this should change to \"dataset_path\"")
self._config["loader_paths"].append(other["loader_path"])
else:
self._config["loader_paths"].append(other["dataset_path"])
if "counts" not in self._config:
self._config["counts"] = other["counts"].copy()
else:
for dataset, count in other["counts"].items():
logging.info("Merging {} samples from {}".format(
count, dataset))
self._config["counts"][dataset] += count
general_attrs = [
"channels", "dtype", "n_forecast_days", "num_channels",
"output_batch_size", "shape"
]
for attr in general_attrs:
if attr not in self._config:
self._config[attr] = other[attr]
else:
assert self._config[attr] == other[attr], \
"{} is not the same across configurations".format(attr)
self._config["north"] = True if loader.north else self._config["north"]
self._config["south"] = True if loader.south else self._config["south"]
    def get_data_loader(self):
        """Return the data loader, valid only when a single configuration was merged.

        :return: The data loader created from the sole configuration.
        """
        assert len(self._configuration_paths) == 1, \
            "Configuration mode is only for single loader " \
            "datasets: {}".format(self._configuration_paths)
        return self._config["loaders"][0]
def check_dataset(self, split: str = "train"):
"""
:param split:
"""
raise NotImplementedError("Checking not implemented for merged sets, "
"consider doing them individually")
@property
def channels(self):
return self._config['channels']
@property
def counts(self):
return self._config["counts"]
if pytorch_available:
class IceNetDataSetPyTorch(IceNetDataSet, Dataset):
"""Initialises and configures a PyTorch dataset.
"""
def __init__(
self,
configuration_path: str,
mode: str,
batch_size: int = 1,
shuffling: bool = False,
):
"""Initialises an instance of the IceNetDataSetPyTorch class.
Args:
configuration_path: The path to the JSON configuration file.
mode: The dataset type, i.e. `train`, `val` or `test`.
batch_size (optional): How many samples to load per batch. Defaults to 1.
shuffling (optional): Flag indicating whether to shuffle the data.
Defaults to False.
"""
super().__init__(configuration_path=configuration_path,
batch_size=batch_size,
shuffling=shuffling)
self._dl = self.get_data_loader()
# check mode option
            if mode not in ["train", "val", "test"]:
                raise ValueError("mode must be one of 'train', 'val' or 'test'")
self._mode = mode
self._dates = self._dl._config["sources"]["osisaf"]["dates"][self._mode]
def __len__(self):
return self._counts[self._mode]
def __getitem__(self, idx):
"""Return a sample from the dataloader for given index.
"""
with dask.config.set(scheduler="synchronous"):
sample = self._dl.generate_sample(
date=pd.Timestamp(self._dates[idx].replace('_', '-')),
parallel=False,
)
return sample
@property
def dates(self):
return self._dates
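    # Usage sketch for IceNetDataSetPyTorch (illustrative only; the configuration
    # filename is a hypothetical placeholder). Samples are generated per date by the
    # underlying Dask loader, so batching and shuffling can be delegated to a
    # standard torch.utils.data.DataLoader:
    #
    #     from torch.utils.data import DataLoader
    #
    #     ds = IceNetDataSetPyTorch("dataset_config.example_dataset.json", mode="train")
    #     dl = DataLoader(ds, batch_size=4, shuffle=True)
    #     for batch in dl:
    #         ...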
@setup_logging
def get_args() -> object:
"""Parse command line arguments using the argparse module.
Returns:
An object containing the parsed command line arguments.
    Example:
        Assuming CLI arguments are provided::

            args = get_args()
            print(args.dataset)
            print(args.split)
            print(args.verbose)
"""
ap = argparse.ArgumentParser()
ap.add_argument("dataset")
ap.add_argument("-s",
"--split",
choices=["train", "val", "test"],
default="train")
ap.add_argument("-v", "--verbose", action="store_true", default=False)
args = ap.parse_args()
return args
def check_dataset() -> None:
"""Check the dataset for a specific split."""
args = get_args()
ds = IceNetDataSet(args.dataset)
ds.check_dataset(args.split)
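# Usage sketch for the command-line helper above (illustrative only). The console
# script name and configuration filename below are assumptions, not confirmed
# entry-point names; check_dataset() can equally be invoked directly from Python:
#
#     icenet_dataset_check dataset_config.example_dataset.json --split val --verbose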