Source code for icenet.model.losses

import tensorflow as tf


[docs] class WeightedMSE(tf.keras.losses.MeanSquaredError): """Custom keras loss/metric for mean squared error :param name: """ def __init__(self, name: str = 'mse', **kwargs): super().__init__(name=name, **kwargs) def __call__(self, y_true: object, y_pred: object, sample_weight: object = None): """ :param y_true: Ground truth outputs :param y_pred: Network predictions :param sample_weight: Pixelwise mask weighting for metric summation :return: Mean squared error of SIC (%) (float) """ # TF automatically reduces along final dimension - include dummy axis y_true = tf.expand_dims(y_true, axis=-1) y_pred = tf.expand_dims(y_pred, axis=-1) # if sample_weight is not None: # sample_weight = tf.expand_dims(sample_weight, axis=-1) return super().__call__(100 * y_true, 100 * y_pred, sample_weight=sample_weight)