"""
Defines the Python-based sea ice forecasting models, such as the IceNet architecture
and the linear trend extrapolation model.
"""

import numpy as np
import tensorflow as tf

from tensorflow.keras.layers import (BatchNormalization, Conv2D, Input,
                                     MaxPooling2D, UpSampling2D, concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


@tf.keras.utils.register_keras_serializable()
class TemperatureScale(tf.keras.layers.Layer):
    """Temperature scaling layer.

    Implements the temperature scaling layer for probability calibration,
    as introduced in Guo 2017 (http://proceedings.mlr.press/v70/guo17a.html).
    """

    def __init__(self, **kwargs):
        super(TemperatureScale, self).__init__(**kwargs)
        self.temp = tf.Variable(initial_value=1.0,
                                trainable=False,
                                dtype=tf.float32,
                                name='temp')

    def call(self, inputs: object, **kwargs):
        """Divide the input logits by the temperature value.

        :param inputs: Logits to calibrate.
        :return: The input logits divided by ``self.temp``.
        """
        return tf.divide(inputs, self.temp)

    def get_config(self):
        """For saving and loading networks with this custom layer.

        :return: Layer configuration, including the temperature value.
        """
        return {'temp': self.temp.numpy()}
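
# A minimal usage sketch for TemperatureScale (illustrative values only; the
# temperature is normally fitted on held-out validation logits rather than
# set by hand):
#
#     scaler = TemperatureScale()
#     scaler.temp.assign(1.5)                    # fitted temperature T
#     logits = tf.constant([[2.0, -1.0, 0.5]])
#     calibrated = tf.sigmoid(scaler(logits))    # sigmoid(logits / T)
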
### Network architectures:
# --------------------------------------------------------------------
def unet_batchnorm(input_shape: object,
                   loss: object,
                   metrics: object,
                   learning_rate: float = 1e-4,
                   filter_size: int = 3,
                   n_filters_factor: float = 1,
                   n_forecast_days: int = 1,
                   legacy_rounding: bool = False) -> object:
    """Build and compile the IceNet U-Net with batch normalisation.

    :param input_shape: Shape of the input tensor, e.g. (432, 432, channels).
    :param loss: Loss function to compile the model with.
    :param metrics: Metrics to compile the model with (applied as weighted metrics).
    :param learning_rate: Learning rate for the Adam optimizer.
    :param filter_size: Kernel size of the convolutional layers.
    :param n_filters_factor: Scaling factor applied to the number of filters per layer.
    :param n_forecast_days: Number of forecast days (output channels).
    :param legacy_rounding: If True, apply int() to the filter counts at the end of
        the calculation (legacy behaviour) rather than truncating the base channel
        count up front.
    :return: The compiled Keras model.
    """
    inputs = Input(shape=input_shape)

    start_out_channels = 64
    reduced_channels = start_out_channels * n_filters_factor

    if not legacy_rounding:
        # Strip off any partial channels rather than rounding
        reduced_channels = int(reduced_channels)

    channels = {
        start_out_channels * 2 ** power:
            reduced_channels * 2 ** power if not legacy_rounding else int(reduced_channels * 2 ** power)
        for power in range(4)
    }
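    # With n_filters_factor=0.5, for example, this maps the nominal widths to
    # {64: 32, 128: 64, 256: 128, 512: 256}: the keys below are the reference
    # architecture's filter counts, the values the scaled counts actually used.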
    # Encoder: each level applies two convolutions, batch normalisation and
    # 2x2 max pooling
    conv1 = Conv2D(channels[64],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(inputs)
    conv1 = Conv2D(channels[64],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(conv1)
    bn1 = BatchNormalization(axis=-1)(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(bn1)

    conv2 = Conv2D(channels[128],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(pool1)
    conv2 = Conv2D(channels[128],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(conv2)
    bn2 = BatchNormalization(axis=-1)(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(bn2)

    conv3 = Conv2D(channels[256],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(pool2)
    conv3 = Conv2D(channels[256],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(conv3)
    bn3 = BatchNormalization(axis=-1)(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(bn3)

    conv4 = Conv2D(channels[256],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(pool3)
    conv4 = Conv2D(channels[256],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(conv4)
    bn4 = BatchNormalization(axis=-1)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(bn4)

    # Bottleneck
    conv5 = Conv2D(channels[512],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(pool4)
    conv5 = Conv2D(channels[512],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(conv5)
    bn5 = BatchNormalization(axis=-1)(conv5)
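
    # Shape note (assuming a 432x432 input grid, the default used elsewhere in
    # this module): the four pooling steps reduce the spatial resolution to
    # 216, 108, 54 and 27 here at the bottleneck, and the four upsampling
    # steps below restore it to 432. Input spatial dimensions therefore need
    # to be divisible by 2**4 for the skip connections to align.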
    # Decoder: nearest-neighbour upsample with a 2x2 convolution, concatenate
    # with the encoder skip connection, then apply two further convolutions
    up6 = Conv2D(channels[256],
                 2,
                 activation='relu',
                 padding='same',
                 kernel_initializer='he_normal')(UpSampling2D(
                     size=(2, 2), interpolation='nearest')(bn5))
    merge6 = concatenate([bn4, up6], axis=3)
    conv6 = Conv2D(channels[256],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(merge6)
    conv6 = Conv2D(channels[256],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(conv6)
    bn6 = BatchNormalization(axis=-1)(conv6)

    up7 = Conv2D(channels[256],
                 2,
                 activation='relu',
                 padding='same',
                 kernel_initializer='he_normal')(UpSampling2D(
                     size=(2, 2), interpolation='nearest')(bn6))
    merge7 = concatenate([bn3, up7], axis=3)
    conv7 = Conv2D(channels[256],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(merge7)
    conv7 = Conv2D(channels[256],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(conv7)
    bn7 = BatchNormalization(axis=-1)(conv7)

    up8 = Conv2D(channels[128],
                 2,
                 activation='relu',
                 padding='same',
                 kernel_initializer='he_normal')(UpSampling2D(
                     size=(2, 2), interpolation='nearest')(bn7))
    merge8 = concatenate([bn2, up8], axis=3)
    conv8 = Conv2D(channels[128],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(merge8)
    conv8 = Conv2D(channels[128],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(conv8)
    bn8 = BatchNormalization(axis=-1)(conv8)

    up9 = Conv2D(channels[64],
                 2,
                 activation='relu',
                 padding='same',
                 kernel_initializer='he_normal')(UpSampling2D(
                     size=(2, 2), interpolation='nearest')(bn8))
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = Conv2D(channels[64],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(merge9)
    conv9 = Conv2D(channels[64],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(conv9)
    conv9 = Conv2D(channels[64],
                   filter_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(conv9)
    # One sigmoid output channel per forecast day
    final_layer = Conv2D(n_forecast_days,
                         kernel_size=1,
                         activation='sigmoid')(conv9)

    # Keras graph mode needs y_pred and y_true to have the same shape, so we
    # must pad an extra dimension onto the model output to train with an
    # extra sample weight dimension in y_true.
    # final_layer = tf.expand_dims(final_layer, axis=-1)

    model = Model(inputs, final_layer)

    model.compile(optimizer=Adam(learning_rate=learning_rate),
                  loss=loss,
                  weighted_metrics=metrics)

    return model
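
# A minimal construction sketch for unet_batchnorm (the input shape, channel
# count, loss and metrics below are illustrative assumptions, not requirements
# of this module):
#
#     model = unet_batchnorm(input_shape=(432, 432, 9),
#                            loss='binary_crossentropy',
#                            metrics=['accuracy'],
#                            n_filters_factor=0.5,
#                            n_forecast_days=7)
#     model.summary()
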
def linear_trend_forecast(
        usable_selector: object,
        forecast_date: object,
        da: object,
        mask: object,
        missing_dates: object = (),
        shape: object = (432, 432)
) -> object:
    """Forecast by extrapolating a per-pixel linear trend one step ahead.

    :param usable_selector: Callable selecting the usable slice of ``da`` for
        the given forecast date and missing dates.
    :param forecast_date: Date to produce the forecast for.
    :param da: Historical fields (e.g. an xarray DataArray) with a time dimension.
    :param mask: Boolean mask of cells to zero out (e.g. land).
    :param missing_dates: Dates to exclude from the fit.
    :param shape: Spatial shape of the output map.
    :return: The forecast map, clipped to [0, 1].
    """
    usable_data = usable_selector(da, forecast_date, missing_dates)

    if len(usable_data.time) < 1:
        return np.full(shape, np.nan)

    # Fit y = slope * x + intercept independently for every grid cell, where
    # x indexes the usable time steps
    x = np.arange(len(usable_data.time))
    y = usable_data.data.reshape(len(usable_data.time), -1)
    src = np.c_[x, np.ones_like(x)]
    r = np.linalg.lstsq(src, y, rcond=None)[0]

    # Evaluate the fitted line one step beyond the last usable time step
    # (x = len(time)) to obtain the forecast
    output_map = np.matmul(np.array([len(usable_data.time), 1]),
                           r).reshape(*shape)

    # Zero out masked cells and clip to valid concentrations
    output_map[mask] = 0.
    output_map[output_map < 0] = 0.
    output_map[output_map > 1] = 1.

    return output_map
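
# A minimal usage sketch for linear_trend_forecast (the selector and data are
# illustrative assumptions; any callable and array-like exposing these
# interfaces will do):
#
#     def take_preceding(da, forecast_date, missing_dates):
#         # keep only time steps strictly before the forecast date
#         keep = (da.time < np.datetime64(forecast_date)).values
#         return da.isel(time=keep)
#
#     forecast = linear_trend_forecast(take_preceding,
#                                      "2020-09-01",
#                                      sic_da,     # (time, y, x) DataArray
#                                      land_mask)  # boolean (y, x) array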