icenet.model package
Submodules
icenet.model.callbacks module
- class icenet.model.callbacks.BatchwiseModelCheckpoint(save_frequency: object, model_path: object, mode: object, monitor: object, prev_best: object = None, sample_at_zero: object = False)
Bases: Callback
- Parameters:
save_frequency
model_path
mode
monitor
prev_best
sample_at_zero
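BatchwiseModelCheckpoint's parameters are not described here. The following is a minimal pure-Python sketch of one plausible save decision, assuming that mode is "min" or "max", that the monitored metric arrives in the logs dict under the key given by monitor, and that prev_best seeds the best value seen so far. The class name BatchwiseCheckpointSketch is hypothetical, it never writes a file, and sample_at_zero is not modelled.

```python
import math


class BatchwiseCheckpointSketch:
    """Hypothetical stand-in for a batchwise checkpoint's save decision."""

    def __init__(self, save_frequency, monitor, mode="min", prev_best=None):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.save_frequency = save_frequency
        self.monitor = monitor
        self.mode = mode
        # Without a previous best, the first monitored value always improves.
        if prev_best is None:
            prev_best = math.inf if mode == "min" else -math.inf
        self.best = prev_best
        self.saved_at = []  # batch indices at which a save would happen

    def _improved(self, value):
        return value < self.best if self.mode == "min" else value > self.best

    def on_train_batch_end(self, batch, logs):
        # Only consider saving every `save_frequency` batches.
        if (batch + 1) % self.save_frequency != 0:
            return
        value = logs.get(self.monitor)
        if value is not None and self._improved(value):
            self.best = value
            self.saved_at.append(batch)  # a real callback would save the model here


ckpt = BatchwiseCheckpointSketch(save_frequency=2, monitor="val_rmse", mode="min")
for batch, rmse in enumerate([9.0, 8.0, 7.5, 8.5, 7.0, 6.0]):
    ckpt.on_train_batch_end(batch, {"val_rmse": rmse})
print(ckpt.saved_at, ckpt.best)  # [1, 5] 6.0
```

Values at batches that are not multiples of save_frequency (7.5, 7.0) are never examined, which is why the save frequency should match the frequency at which validation metrics are produced.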
- class icenet.model.callbacks.IceNetPreTrainingEvaluator(*args, validation_frequency: int, val_dataloader: object, sample_at_zero: bool = False, **kwargs)
Bases: Callback
Custom tf.keras callback that updates the logs dict used by all other callbacks with the validation set metrics. The callback is executed every validation_frequency batches.
This can be used in conjunction with the BatchwiseModelCheckpoint callback to perform a model checkpoint based on validation data every N batches; ensure the save_frequency input to BatchwiseModelCheckpoint is also set to validation_frequency.
Also ensure that the callbacks list passed to Model.fit() contains this callback before any other callbacks that need the validation metrics.
Weights and Biases is also used to log the training and validation metrics.
TODO: not sure this is really necessary/stable, review
- Parameters:
validation_frequency
val_dataloader
sample_at_zero
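The ordering requirement above follows from how callbacks share the logs dict. The sketch below uses hypothetical stand-in classes (no TensorFlow) and assumes, as the docstring implies, that callbacks are invoked in list order with one shared logs dict per batch; compute_val_metric is a placeholder for evaluating val_dataloader.

```python
class ValidationEvaluatorSketch:
    """Hypothetical stand-in: writes a validation metric into the shared logs."""

    def __init__(self, validation_frequency, compute_val_metric):
        self.validation_frequency = validation_frequency
        self.compute_val_metric = compute_val_metric  # placeholder for val_dataloader

    def on_train_batch_end(self, batch, logs):
        if (batch + 1) % self.validation_frequency == 0:
            logs["val_rmse"] = self.compute_val_metric(batch)


class MetricReaderSketch:
    """Hypothetical consumer that needs val_rmse, like a batchwise checkpoint."""

    def __init__(self):
        self.seen = []

    def on_train_batch_end(self, batch, logs):
        self.seen.append(logs.get("val_rmse"))


def run_batches(callbacks, n_batches):
    # Each callback is called in list order with the same logs dict per batch.
    for batch in range(n_batches):
        logs = {"loss": 1.0}
        for cb in callbacks:
            cb.on_train_batch_end(batch, logs)


good, bad = MetricReaderSketch(), MetricReaderSketch()
run_batches([ValidationEvaluatorSketch(2, lambda b: 10.0 - b), good], 4)
run_batches([bad, ValidationEvaluatorSketch(2, lambda b: 10.0 - b)], 4)
print(good.seen)  # [None, 9.0, None, 7.0]
print(bad.seen)   # [None, None, None, None]
```

When the reader is listed before the evaluator, it runs before the validation metric is written and never sees it.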
icenet.model.losses module
icenet.model.metrics module
- class icenet.model.metrics.ConstructLeadtimeAccuracy(*args, **kwargs)
Bases: CategoricalAccuracy
Computes the network’s accuracy over the active grid cell region either (a) for a specific lead time in months or (b) over all lead times at once.
- Parameters:
name
use_all_forecast_months
single_forecast_leadtime_idx
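To illustrate the two modes, here is a hypothetical pure-Python sketch of masked categorical accuracy. It assumes y_true and y_pred hold, for each grid cell, a list over lead times of per-class scores (one-hot for y_true), and that the active grid cell region is a 0/1 mask; the real metric works on tensors and its exact layout is not documented here.

```python
def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


def leadtime_accuracy(y_true, y_pred, mask, use_all_forecast_months=True,
                      single_forecast_leadtime_idx=0):
    """Fraction of correctly classified (cell, lead time) pairs in the mask."""
    correct = total = 0
    for cell_true, cell_pred, active in zip(y_true, y_pred, mask):
        if not active:
            continue  # cells outside the active region are ignored
        leads = (range(len(cell_true)) if use_all_forecast_months
                 else [single_forecast_leadtime_idx])
        for lead in leads:
            total += 1
            correct += int(argmax(cell_true[lead]) == argmax(cell_pred[lead]))
    return correct / total if total else 0.0


# 3 grid cells x 2 lead times x 3 classes; the third cell is inactive.
y_true = [[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 0, 1]], [[0, 1, 0], [1, 0, 0]]]
y_pred = [[[0.7, 0.2, 0.1], [0.6, 0.3, 0.1]],
          [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]],
          [[0.9, 0.05, 0.05], [0.9, 0.05, 0.05]]]
mask = [1, 1, 0]
print(leadtime_accuracy(y_true, y_pred, mask))  # 0.75
print(leadtime_accuracy(y_true, y_pred, mask, False, 1))  # 0.5
```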
- class icenet.model.metrics.WeightedBinaryAccuracy(*args, **kwargs)
Bases: BinaryAccuracy
- Parameters:
leadtime_idx
- update_state(y_true: object, y_pred: object, sample_weight: object = None)
Custom Keras loss/metric for binary accuracy in classifying SIC > 15%
- Parameters:
y_true – Ground truth outputs
y_pred – Network predictions
sample_weight – Pixelwise mask weighting for metric summation
- Returns:
Binary accuracy of the SIC > 15% classification (float)
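A minimal sketch of the metric's arithmetic, assuming SIC is expressed as a fraction in [0, 1] (so the 15% threshold is 0.15) and that sample_weight acts as a pixelwise mask. A pixel counts as correct when truth and prediction fall on the same side of the threshold, and correct pixels are summed with their weights. The function name is hypothetical.

```python
def weighted_binary_accuracy(y_true, y_pred, sample_weight, threshold=0.15):
    """Weighted fraction of pixels where SIC > threshold is classified correctly."""
    hits = 0.0
    weight_sum = 0.0
    for t, p, w in zip(y_true, y_pred, sample_weight):
        weight_sum += w
        if (t > threshold) == (p > threshold):
            hits += w  # weight-0 pixels contribute nothing
    return hits / weight_sum if weight_sum else 0.0


y_true = [0.0, 0.2, 0.9, 0.5]
y_pred = [0.1, 0.1, 0.8, 0.6]
weights = [1, 1, 1, 0]  # the last pixel is masked out
print(weighted_binary_accuracy(y_true, y_pred, weights))  # 2 of 3 weighted pixels correct
```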
- class icenet.model.metrics.WeightedMAE(*args, **kwargs)
Bases: MeanAbsoluteError
Custom Keras loss/metric for mean absolute error
- Parameters:
name
leadtime_idx
- class icenet.model.metrics.WeightedMSE(*args, **kwargs)
Bases: MeanSquaredError
Custom Keras loss/metric for mean squared error
- Parameters:
leadtime_idx
name
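WeightedMAE and WeightedMSE both reduce to a weighted mean of per-pixel errors. A pure-Python sketch, assuming flat inputs and a non-negative pixelwise weight per value (zero for masked-out pixels); the helper names are hypothetical.

```python
def _weighted_mean(errors, sample_weight):
    total = sum(sample_weight)
    if total == 0:
        raise ValueError("all pixels are masked out")
    return sum(w * e for e, w in zip(errors, sample_weight)) / total


def weighted_mae(y_true, y_pred, sample_weight):
    return _weighted_mean([abs(t - p) for t, p in zip(y_true, y_pred)], sample_weight)


def weighted_mse(y_true, y_pred, sample_weight):
    return _weighted_mean([(t - p) ** 2 for t, p in zip(y_true, y_pred)], sample_weight)


y_true = [0.0, 0.5, 1.0]
y_pred = [0.1, 0.5, 0.7]
mask = [1, 1, 0]  # the third pixel is ignored
mae = weighted_mae(y_true, y_pred, mask)  # about 0.05
mse = weighted_mse(y_true, y_pred, mask)  # about 0.005
```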
icenet.model.models module
- class icenet.model.models.TemperatureScale(*args, **kwargs)
Bases: Layer
Temperature scaling layer
Implements the temperature scaling layer for probability calibration, as introduced by Guo et al. (2017) (http://proceedings.mlr.press/v70/guo17a.html).
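In Guo et al.'s formulation, the logits z are divided by a single scalar temperature T > 0 before the softmax, so p = softmax(z / T), with T fitted on a validation set. T > 1 softens the probabilities, T < 1 sharpens them, and the predicted class never changes. The pure-Python sketch below shows only this forward computation; how IceNet's layer stores and trains T is not covered.

```python
import math


def softmax(logits):
    # Subtracting the max logit improves numerical stability without changing the result.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def temperature_scale(logits, temperature):
    """Calibrated probabilities softmax(logits / temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return softmax([z / temperature for z in logits])


logits = [2.0, 1.0, 0.0]
sharp = temperature_scale(logits, 1.0)
soft = temperature_scale(logits, 2.0)
print(max(sharp) > max(soft))  # True: T > 1 lowers the top-class confidence
```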
- icenet.model.models.linear_trend_forecast(usable_selector: object, forecast_date: object, da: object, mask: object, missing_dates: object = (), shape: object = (432, 432)) → object
- Parameters:
usable_selector
forecast_date
da
mask
missing_dates
shape
- Returns:
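linear_trend_forecast is not described here; its name suggests a linear-trend extrapolation. The sketch below shows that idea for a single grid cell under stated assumptions: an ordinary least-squares line is fitted to values observed on the same date in previous years, evaluated at the forecast year, and clipped to the valid SIC range [0, 1]. The function name and the clipping are assumptions, not the library's documented behaviour; a gridded version would repeat this for every cell where mask is set.

```python
def linear_trend_extrapolate(years, values, target_year, clip=(0.0, 1.0)):
    """Least-squares line through (year, value) pairs, evaluated at target_year."""
    n = len(years)
    if n < 2:
        raise ValueError("need at least two points to fit a trend")
    mean_x = sum(years) / n
    mean_y = sum(values) / n
    sxx = sum((x - mean_x) ** 2 for x in years)
    if sxx == 0:
        raise ValueError("years must not all be equal")
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(years, values))
    slope = sxy / sxx
    prediction = mean_y + slope * (target_year - mean_x)
    low, high = clip
    return min(max(prediction, low), high)  # keep SIC within its valid range


# SIC falling by 0.1 per year extrapolates to 0.4 in 2023.
print(linear_trend_extrapolate([2019, 2020, 2021, 2022], [0.8, 0.7, 0.6, 0.5], 2023))
```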
- icenet.model.models.unet_batchnorm(input_shape: object, loss: object, metrics: object, learning_rate: float = 0.0001, filter_size: float = 3, n_filters_factor: float = 1, n_forecast_days: int = 1, legacy_rounding: bool = False) → object
- Parameters:
input_shape
loss
metrics
learning_rate
filter_size
n_filters_factor
n_forecast_days
legacy_rounding – Ensures that filter-count calculations are converted with int() only at the end of the calculation
- Returns:
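The legacy_rounding note concerns where int() is applied when filter counts are scaled by n_filters_factor. The sketch below contrasts truncating at the end of the calculation with truncating an intermediate value first. The base width of 64 filters, the doubling per level, and the early-rounding variant are hypothetical choices for illustration, not values taken from unet_batchnorm.

```python
def channels_round_at_end(base, factor, n_levels):
    # Full float calculation first; int() only on the final value.
    return [int(base * factor * 2 ** level) for level in range(n_levels)]


def channels_round_first(base, factor, n_levels):
    # int() on the scaled base width, then exact doubling per level.
    start = int(base * factor)
    return [start * 2 ** level for level in range(n_levels)]


print(channels_round_at_end(64, 0.3, 4))  # [19, 38, 76, 153]
print(channels_round_first(64, 0.3, 4))   # [19, 38, 76, 152]
```

For factors such as 0.125, where 64 × factor is already a whole number, both orders give the same counts; they diverge only when the scaled width has a fractional part.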
icenet.model.predict module
- icenet.model.predict.predict_forecast(dataset_config: object, network_name: object, dataset_name: object = None, legacy_rounding: bool = False, model_func: callable = <function unet_batchnorm>, n_filters_factor: float = 0.125, network_folder: object = None, output_folder: object = None, save_args: bool = False, seed: int = 42, start_dates: object = (datetime.date(2025, 2, 3), ), test_set: bool = False) → object
- Parameters:
dataset_config
network_name
dataset_name
legacy_rounding
model_func
n_filters_factor
network_folder
output_folder
save_args
seed
start_dates
test_set
- Returns:
icenet.model.train module
- icenet.model.train.evaluate_model(model_path: object, dataset: object, dataset_ratio: float = 1.0, max_queue_size: int = 3, workers: int = 5, use_multiprocessing: bool = True)
- Parameters:
model_path
dataset
dataset_ratio
max_queue_size
workers
use_multiprocessing
- icenet.model.train.train_model(run_name: object, dataset: object, callback_objects: list = [], checkpoint_monitor: str = 'val_rmse', checkpoint_mode: str = 'min', dataset_ratio: float = 1.0, early_stopping_patience: int = 30, epochs: int = 2, filter_size: float = 3, learning_rate: float = 0.0001, lr_10e_decay_fac: float = 1.0, lr_decay_start: float = 10, lr_decay_end: float = 30, max_queue_size: int = 3, model_func: object = <function unet_batchnorm>, n_filters_factor: float = 2, network_folder: object = None, network_save: bool = True, pickup_weights: bool = False, pre_load_network: bool = False, pre_load_path: object = None, seed: int = 42, strategy: object = <tensorflow.python.distribute.distribute_lib._DefaultDistributionStrategy object>, training_verbosity: int = 1, workers: int = 5, use_multiprocessing: bool = True, use_tensorboard: bool = True) → object
- Parameters:
run_name
dataset
callback_objects
checkpoint_monitor
checkpoint_mode
dataset_ratio
early_stopping_patience
epochs
filter_size
learning_rate
lr_10e_decay_fac
lr_decay_start
lr_decay_end
max_queue_size
model_func
n_filters_factor
network_folder
network_save
pickup_weights
pre_load_network
pre_load_path
seed
strategy
training_verbosity
workers
use_multiprocessing
use_tensorboard
- Returns:
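The lr_10e_decay_fac, lr_decay_start and lr_decay_end parameters suggest an exponential learning-rate decay. One reading, assumed here rather than confirmed by these docs, is that between epochs lr_decay_start and lr_decay_end the rate shrinks by a factor of lr_10e_decay_fac every 10 epochs and stays constant outside that window. Under this reading, the default factor of 1.0 means no decay at all. The function name lr_at_epoch is hypothetical.

```python
def lr_at_epoch(epoch, learning_rate, lr_10e_decay_fac=1.0,
                lr_decay_start=10, lr_decay_end=30):
    """Learning rate at `epoch` under the assumed decay schedule."""
    # Number of epochs of decay completed before `epoch`, limited to the window.
    decay_epochs = min(max(epoch - lr_decay_start, 0), lr_decay_end - lr_decay_start)
    # Each decay epoch multiplies the rate by lr_10e_decay_fac ** 0.1.
    return learning_rate * lr_10e_decay_fac ** (decay_epochs / 10)


for epoch in (5, 20, 100):
    print(epoch, lr_at_epoch(epoch, 1e-4, lr_10e_decay_fac=0.5))
# 5 -> 1e-4 (before the window), 20 -> about 5e-5, 100 -> about 2.5e-5 (held after the window)
```

Keras' tf.keras.callbacks.LearningRateScheduler calls its schedule with the epoch index and the current learning rate, so a function of this shape could be adapted to it with a small wrapper.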
icenet.model.utils module
- icenet.model.utils.arr_to_ice_edge_arr(arr: object, thresh: object, land_mask: object, region_mask: object) → object
Compute a boolean mask that is True over ice-edge contour grid cells, using matplotlib.pyplot.contour with an input threshold to define the ice edge (e.g. 0.15 for the 15% SIC ice edge or 0.5 for SIP forecasts). The contour along the coastline is removed using the region mask.
- Parameters:
arr
thresh
land_mask
region_mask
- Returns:
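The documented function traces the edge with matplotlib.pyplot.contour. As a simpler stand-in (a different, neighbour-based technique), the sketch below marks a cell as ice edge when it is at or above thresh and has at least one ocean 4-neighbour below thresh. Land neighbours are ignored, so the coastline is never marked. Here land_mask alone plays the role that the region mask plays in the original, which is an assumption. The function name is hypothetical.

```python
def ice_edge_mask(arr, thresh, land_mask):
    """Neighbour-based ice-edge mask over 2-D lists (approximation, not contouring)."""
    rows, cols = len(arr), len(arr[0])
    edge = [[False] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            if land_mask[i][j] or arr[i][j] < thresh:
                continue  # only ice-covered ocean cells can be edge cells
            for di, dj in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                ni, nj = i + di, j + dj
                if 0 <= ni < rows and 0 <= nj < cols:
                    # Only open-ocean neighbours count; land does not form an edge.
                    if not land_mask[ni][nj] and arr[ni][nj] < thresh:
                        edge[i][j] = True
                        break
    return edge


arr = [[0.0, 0.9, 0.9],
       [0.9, 0.9, 0.9],
       [0.9, 0.9, 0.1]]
land = [[True, False, False],
        [False, False, False],
        [False, False, False]]
edge = ice_edge_mask(arr, 0.15, land)
# Only the two cells next to the open-water corner are marked; the land cell is not.
```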
- icenet.model.utils.arr_to_ice_edge_rgba_arr(arr: object, thresh: object, land_mask: object, region_mask: object, rgb: object) → object
- Parameters:
arr
thresh
land_mask
region_mask
rgb
- Returns:
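arr_to_ice_edge_rgba_arr is not described here. Its name and rgb parameter suggest that it colours an ice-edge mask for plotting. A hypothetical sketch of that final step, assuming rgb is an (r, g, b) tuple of floats in [0, 1] as matplotlib uses. Edge cells are opaque and all other cells fully transparent, so the layer can be drawn over another map.

```python
def edge_to_rgba(edge, rgb):
    """Map a 2-D boolean edge mask to RGBA tuples (hypothetical helper)."""
    r, g, b = rgb
    transparent = (0.0, 0.0, 0.0, 0.0)
    return [[(r, g, b, 1.0) if cell else transparent for cell in row] for row in edge]


rgba = edge_to_rgba([[True, False]], (1.0, 0.0, 0.0))
# [[(1.0, 0.0, 0.0, 1.0), (0.0, 0.0, 0.0, 0.0)]]
```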