2023-04-01 00:36:31 +02:00
|
|
|
import tensorflow as tf
|
|
|
|
from keras import backend
|
2023-04-17 06:19:03 +02:00
|
|
|
from keras.utils import losses_utils, tf_utils
|
|
|
|
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
def inv_kl_divergence(y_true, y_pred):
    """Reverse Kullback-Leibler divergence between `y_pred` and `y_true`.

    Unlike the standard Keras KL divergence, which computes
    ``sum(y_true * log(y_true / y_pred))``, this swaps the roles of the two
    distributions and computes ``sum(y_pred * log(y_pred / y_true))`` along
    the last axis.

    Args:
        y_true: Ground-truth distribution; cast to `y_pred`'s dtype.
        y_pred: Predicted distribution; converted to a tensor.

    Returns:
        A tensor with the last axis reduced away, holding the per-sample
        reverse KL divergence.
    """
    predictions = tf.convert_to_tensor(y_pred)
    targets = tf.cast(y_true, predictions.dtype)
    # Clip both operands into [epsilon, 1] so the ratio and log stay finite.
    eps = backend.epsilon()
    targets = backend.clip(targets, eps, 1)
    predictions = backend.clip(predictions, eps, 1)
    log_ratio = tf.math.log(predictions / targets)
    return tf.reduce_sum(predictions * log_ratio, axis=-1)
|
|
|
|
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
def masked_bce(y_true, y_pred):
    """Binary cross-entropy over only the labelled positions.

    Entries of `y_true` equal to -1 are treated as "no label" and are
    excluded from the loss. `tf.boolean_mask` flattens the surviving
    entries into a 1-D tensor before the cross-entropy is computed.

    Args:
        y_true: Labels; -1 marks positions to ignore. Cast to float32.
        y_pred: Predicted probabilities, same shape as `y_true`.

    Returns:
        The binary cross-entropy computed over the unmasked entries.
    """
    labels = tf.cast(y_true, dtype=tf.float32)
    # -1 is the sentinel for "no ground truth available here".
    keep = labels != -1
    kept_labels = tf.boolean_mask(labels, keep)
    kept_preds = tf.boolean_mask(y_pred, keep)
    return tf.keras.metrics.binary_crossentropy(kept_labels, kept_preds)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
|
|
|
|
class LossFunctionWrapper(tf.keras.losses.Loss):
    """Adapt a plain ``fn(y_true, y_pred, **kwargs)`` into a Keras `Loss`.

    Mirrors the internal Keras wrapper: stores the callable plus any extra
    keyword arguments, aligns tensor ranks in `call`, and serializes the
    extra kwargs in `get_config`.
    """

    def __init__(
        self, fn, reduction=losses_utils.ReductionV2.AUTO, name=None, **kwargs
    ):
        """Store `fn` and its extra kwargs.

        Args:
            fn: Callable with signature ``fn(y_true, y_pred, **kwargs)``.
            reduction: Reduction to apply to the per-sample losses.
            name: Optional name for the loss instance.
            **kwargs: Extra keyword arguments forwarded to `fn` on each call.
        """
        super().__init__(reduction=reduction, name=name)
        self.fn = fn
        self._fn_kwargs = kwargs

    def call(self, y_true, y_pred):
        """Invoke the wrapped function on rank-aligned inputs."""
        if tf.is_tensor(y_pred) and tf.is_tensor(y_true):
            # Reconcile trailing singleton dimensions between the two tensors.
            y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
                y_pred, y_true
            )
        # Convert `fn` under the current autograph context so it traces
        # correctly inside tf.function graphs.
        converted = tf.__internal__.autograph.tf_convert(
            self.fn, tf.__internal__.autograph.control_status_ctx()
        )
        return converted(y_true, y_pred, **self._fn_kwargs)

    def get_config(self):
        """Serialize the extra kwargs, evaluating tensor/variable values."""
        own = {
            key: backend.eval(val) if tf_utils.is_tensor_or_variable(val) else val
            for key, val in self._fn_kwargs.items()
        }
        return {**super().get_config(), **own}
|
|
|
|
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
class InvKLD(LossFunctionWrapper):
    """Keras `Loss` wrapping `inv_kl_divergence` (reverse KL)."""

    def __init__(
        self, reduction=losses_utils.ReductionV2.AUTO, name="inv_kl_divergence"
    ):
        """Create the loss with the given reduction and name."""
        super().__init__(inv_kl_divergence, reduction=reduction, name=name)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
|
|
|
|
class MaskedBCE(LossFunctionWrapper):
    """Keras `Loss` wrapping `masked_bce` (BCE that skips -1 labels)."""

    def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name="masked_bce"):
        """Create the loss with the given reduction and name."""
        super().__init__(masked_bce, reduction=reduction, name=name)
|