import glob
import logging
import os
import numpy as np
import tensorflow as tf
def get_decoder(shape: object,
                channels: object,
                forecasts: object,
                num_vars: int = 1,
                dtype: str = "float32") -> object:
    """Build a decoder for parsing examples from a tfrecord protocol buffer.

    Args:
        shape: The spatial shape of the input data.
        channels: The number of channels in the input data.
        forecasts: The number of days to forecast in prediction.
        num_vars (optional): The number of variables in the input data.
            Defaults to 1.
        dtype (optional): Name of the TensorFlow dtype of the input data
            (resolved via ``getattr(tf, dtype)``). Defaults to "float32".

    Returns:
        A function taking a serialised protocol buffer (tfrecord) and
        returning the parsed ``(x, y, sample_weights)`` tensors.
    """
    tf_dtype = getattr(tf, dtype)
    # y and sample_weights share the same fixed-length layout.
    target_shape = [*shape, forecasts, num_vars]

    feature_spec = {
        "x": tf.io.FixedLenFeature([*shape, channels], tf_dtype),
        "y": tf.io.FixedLenFeature(target_shape, tf_dtype),
        "sample_weights": tf.io.FixedLenFeature(target_shape, tf_dtype),
    }

    @tf.function
    def decode_item(proto):
        parsed = tf.io.parse_example(proto, feature_spec)
        return parsed["x"], parsed["y"], parsed["sample_weights"]

    return decode_item
# TODO: define a decent interface and sort the inheritance architecture out, as
# this will facilitate the new datasets in #35
class SplittingMixin:
    """Read train, val, test datasets from tfrecord protocol buffer files.

    Split and shuffle data if specified as well.

    Example:
        This mixin is not to be used directly, but to give an idea of its use:

        # Initialise SplittingMixin
        >>> split_dataset = SplittingMixin()

        # Add file paths to the train, validation, and test datasets
        >>> split_dataset.add_records(base_path="./network_datasets/notebook_data/", hemi="south")
    """

    # Attributes expected to be supplied by the inheriting dataset class.
    _batch_size: int
    _dtype: object          # a type object (e.g. np.float32); `.__name__` is used below
    _num_channels: int
    _n_forecast_days: int
    _shape: object          # spatial dimensions of each record — presumably a sequence; TODO confirm
    _shuffling: bool

    # NOTE(review): class-level mutable lists are shared by all instances.
    # add_records rebinds them on the instance (instead of mutating in
    # place) so records never leak between instances.
    train_fns = []
    test_fns = []
    val_fns = []

    def add_records(self, base_path: str, hemi: str) -> None:
        """Add list of paths to train, val, test *.tfrecord(s) to relevant instance attributes.

        Add sorted list of file paths to train, validation, and test datasets
        in SplittingMixin.

        Args:
            base_path (str): The base path where the datasets are located.
            hemi (str): The hemisphere the datasets correspond to.

        Returns:
            None. Updates `self.train_fns`, `self.val_fns`, `self.test_fns`
            with list of *.tfrecord files.
        """
        train_path = os.path.join(base_path, hemi, "train")
        val_path = os.path.join(base_path, hemi, "val")
        test_path = os.path.join(base_path, hemi, "test")

        logging.info("Training dataset path: %s", train_path)
        # FIX: use rebinding (`a = a + b`) rather than `+=`, which would
        # mutate the shared class-level list and leak paths across instances.
        self.train_fns = self.train_fns + sorted(
            glob.glob("{}/*.tfrecord".format(train_path)))
        logging.info("Validation dataset path: %s", val_path)
        self.val_fns = self.val_fns + sorted(
            glob.glob("{}/*.tfrecord".format(val_path)))
        logging.info("Test dataset path: %s", test_path)
        self.test_fns = self.test_fns + sorted(
            glob.glob("{}/*.tfrecord".format(test_path)))

    def get_split_datasets(self, ratio: object = None):
        """Retrieves train, val, and test datasets from corresponding attributes of SplittingMixin.

        Retrieves the train, validation, and test datasets from the file paths
        stored in the `train_fns`, `val_fns`, and `test_fns` attributes of
        SplittingMixin.

        Args:
            ratio (optional): A float representing the truncated list of
                datasets to be used. If not specified, all datasets will be
                used. Defaults to None.

        Returns:
            tuple: A tuple containing the train, validation, and test datasets.

        Raises:
            RuntimeError: If no files have been found in the train, validation,
                and test datasets.
            RuntimeError: If the ratio is greater than 1.
        """
        if not (len(self.train_fns) + len(self.val_fns) + len(self.test_fns)):
            raise RuntimeError("No files have been found, abandoning. This is "
                               "likely because you're trying to use a config "
                               "only mode dataset in a situation that demands "
                               "tfrecords to be generated (like training...)")

        logging.info("Datasets: %d train, %d val and %d test filenames",
                     len(self.train_fns), len(self.val_fns),
                     len(self.test_fns))

        # If ratio is specified, truncate file paths for train, val, test
        # using the ratio.
        if ratio:
            if ratio > 1.0:
                raise RuntimeError("Ratio cannot be more than 1")

            logging.info("Reducing datasets to %s of total files", ratio)
            train_idx = int(len(self.train_fns) * ratio)
            val_idx = int(len(self.val_fns) * ratio)
            test_idx = int(len(self.test_fns) * ratio)

            # Only truncate when the reduced count is non-zero, so a tiny
            # ratio never empties a split entirely.
            if train_idx > 0:
                self.train_fns = self.train_fns[:train_idx]
            if val_idx > 0:
                self.val_fns = self.val_fns[:val_idx]
            if test_idx > 0:
                self.test_fns = self.test_fns[:test_idx]

            logging.info("Reduced: %d train, %d val and %d test filenames",
                         len(self.train_fns), len(self.val_fns),
                         len(self.test_fns))

        # Loads from files as bytes exactly as written. Must parse and
        # decode it below.
        train_ds = tf.data.TFRecordDataset(self.train_fns,
                                           num_parallel_reads=self.batch_size)
        val_ds = tf.data.TFRecordDataset(self.val_fns,
                                         num_parallel_reads=self.batch_size)
        test_ds = tf.data.TFRecordDataset(self.test_fns,
                                          num_parallel_reads=self.batch_size)

        # TODO: Comparison/profiling runs
        # TODO: parallel for batch size while that's small
        # TODO: obj.decode_item might not work here - figure out runtime
        #  implementation based on wrapped function call that can be serialised
        decoder = get_decoder(self.shape,
                              self.num_channels,
                              self.n_forecast_days,
                              dtype=self.dtype.__name__)

        if self.shuffling:
            logging.info("Training dataset(s) marked to be shuffled")
            # FIXME: this is not a good calculation, but we don't have access
            #  in the mixin to the configuration that generated the dataset #57
            train_ds = train_ds.shuffle(
                min(int(len(self.train_fns) * self.batch_size), 366))

        # Since TFRecordDataset does not parse or decode the dataset from
        # bytes, use custom decoder function with map to do so.
        train_ds = train_ds.\
            map(decoder, num_parallel_calls=self.batch_size).\
            batch(self.batch_size)
        val_ds = val_ds.\
            map(decoder, num_parallel_calls=self.batch_size).\
            batch(self.batch_size)
        test_ds = test_ds.\
            map(decoder, num_parallel_calls=self.batch_size).\
            batch(self.batch_size)

        return train_ds.prefetch(tf.data.AUTOTUNE), \
            val_ds.prefetch(tf.data.AUTOTUNE), \
            test_ds.prefetch(tf.data.AUTOTUNE)

    def check_dataset(self, split: str = "train") -> None:
        """Check the dataset for NaN, log debugging info regarding dataset shape and bounds.

        Also logs a warning if any NaN are found.

        Args:
            split: The split of the dataset to check. Default is "train".
        """
        logging.debug("Checking dataset {}".format(split))

        decoder = get_decoder(self.shape,
                              self.num_channels,
                              self.n_forecast_days,
                              dtype=self.dtype.__name__)

        for df in getattr(self, "{}_fns".format(split)):
            logging.debug("Getting records from {}".format(df))

            try:
                raw_dataset = tf.data.TFRecordDataset([df])
                raw_dataset = raw_dataset.map(decoder)

                for i, (x, y, sw) in enumerate(raw_dataset):
                    x = x.numpy()
                    y = y.numpy()
                    sw = sw.numpy()

                    logging.debug(
                        "Got record {}:{} with x {} y {} sw {}".format(
                            df, i, x.shape, y.shape, sw.shape))

                    input_nans = np.isnan(x).sum()
                    # Only positions with positive sample weight count as
                    # unexplained output NaNs.
                    output_nans = np.isnan(y[sw > 0.]).sum()

                    input_min = np.min(x)
                    input_max = np.max(x)
                    # FIX: bounds were previously all computed from x
                    # (copy-paste bug); use y and sw respectively.
                    output_min = np.min(y)
                    output_max = np.max(y)
                    sw_min = np.min(sw)
                    sw_max = np.max(sw)

                    logging.debug(
                        "Bounds: Input {}:{} Output {}:{} SW {}:{}".format(
                            input_min, input_max, output_min, output_max,
                            sw_min, sw_max))

                    if input_nans > 0:
                        logging.warning("Input NaNs detected in {}:{}".format(
                            df, i))

                    if output_nans > 0:
                        logging.warning(
                            "Output NaNs detected in {}:{}, not "
                            "accounted for by sample weighting".format(df, i))
            except tf.errors.DataLossError as e:
                logging.warning("{}: data loss error {}".format(df, e.message))
            except tf.errors.OpError as e:
                logging.warning("{}: tensorflow error {}".format(df, e.message))
            # We don't except any non-tensorflow errors to prevent progression

    @property
    def batch_size(self) -> int:
        """The dataset's batch size."""
        return self._batch_size

    @property
    def dtype(self) -> object:
        """The dataset's data type (a type object whose __name__ is passed to the decoder)."""
        return self._dtype

    @property
    def n_forecast_days(self) -> int:
        """The number of days to forecast in prediction."""
        return self._n_forecast_days

    @property
    def num_channels(self) -> int:
        """The number of channels in dataset."""
        return self._num_channels

    @property
    def shape(self) -> object:
        """The shape of dataset."""
        return self._shape

    @property
    def shuffling(self) -> bool:
        """A flag for whether training dataset(s) are marked to be shuffled."""
        return self._shuffling