2023-03-31 20:05:14 +02:00
|
|
|
"""Loss functions -- including multi task ones."""
|
|
|
|
|
|
|
|
import typing
|
|
|
|
|
|
|
|
from tml.core.loss_type import LossType
|
|
|
|
from tml.ml_logging.torch_logging import logging
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
def _maybe_warn(reduction: str):
    """
    Emit a warning if the reduction method is different from 'mean'.

    Only a mean reduction guarantees that, for the same global batch size, the
    DDP gradient equals the gradient computed without DDP.

    Args:
        reduction (str): The reduction method being used.
    """
    if reduction != "mean":
        # `warn` is a deprecated alias; `warning` is the supported spelling.
        # Each literal ends with a space so the concatenated message reads correctly.
        logging.warning(
            "For the same global_batch_size, the gradient in DDP is guaranteed to be equal "
            "to the gradient without DDP only for mean reduction. If you need this property for "
            f"the provided reduction {reduction}, it needs to be implemented."
        )
|
|
|
|
|
|
|
|
|
|
|
|
def build_loss(
    loss_type: LossType,
    reduction="mean",
):
    """
    Build a single-task loss function for ``loss_type``.

    The underlying loss implementation is looked up eagerly, so an unregistered
    ``loss_type`` fails at build time rather than on the first call.

    Args:
        loss_type (LossType): The type of loss to build.
        reduction (str): The reduction method for the loss (default: 'mean').
            Anything other than 'mean' triggers a warning.

    Returns:
        Callable: ``loss_fn(logits, labels)``. ``labels`` is converted with
        ``type_as(logits)`` before the loss is computed.
    """
    _maybe_warn(reduction)
    compute_loss = _LOSS_TYPE_TO_FUNCTION[loss_type]

    def loss_fn(logits, labels):
        targets = labels.type_as(logits)
        return compute_loss(logits, targets, reduction=reduction)

    return loss_fn
|
|
|
|
|
|
|
|
|
|
|
|
def get_global_loss_detached(local_loss, reduction="mean"):
    """
    Perform all_reduce to obtain the global loss using the provided reduction.

    The collective runs on a detached *copy* of ``local_loss``, so the caller's
    tensor (and the autograd graph it belongs to) is left unmodified.

    Args:
        local_loss (torch.Tensor): The local loss of the current rank.
        reduction (str): The reduction to use for all_reduce. Should match the
            reduction used by DDP. Only "mean" and "sum" are supported.

    Returns:
        torch.Tensor: The reduced and detached global loss.

    Raises:
        ValueError: If ``reduction`` is neither "mean" nor "sum".
    """
    # Validate first so unsupported reductions don't emit a misleading warning.
    if reduction not in ("mean", "sum"):
        raise ValueError(f"Reduction {reduction} is currently unsupported.")

    if reduction != "mean":
        logging.warning(
            "The reduction used in this function should be the same as the one used by "
            "the DDP model. By default DDP uses mean, So ensure that DDP is appropriately "
            f"modified for reduction {reduction}."
        )

    # detach() alone shares storage (and the autograd version counter) with
    # local_loss; the in-place div_/all_reduce below would then overwrite the
    # caller's tensor. clone() gives the collective its own buffer.
    global_loss = local_loss.detach().clone()

    if reduction == "mean":
        # all_reduce sums across ranks by default; pre-dividing by the world
        # size turns that sum into a mean.
        global_loss.div_(torch.distributed.get_world_size())

    torch.distributed.all_reduce(global_loss)
    return global_loss
|
|
|
|
|
|
|
|
|
|
|
|
def build_multi_task_loss(
    loss_type: LossType,
    tasks: typing.List[str],
    task_loss_reduction="mean",
    global_reduction="mean",
    pos_weights=None,
):
    """
    Build a multi-task loss function based on the specified loss type and configurations.

    Args:
        loss_type (LossType): The type of loss to build.
        tasks (typing.List[str]): List of task names. Task ``i`` reads index ``i``
            of the second dimension of ``logits``, ``labels`` and ``weights``.
        task_loss_reduction (str): Reduction method for task-specific losses (default: 'mean').
        global_reduction (str): Reduction applied across the stacked per-task
            losses; one of 'mean', 'sum', 'min', 'max', 'median' (default: 'mean').
        pos_weights (Optional): Per-task positive class weights, in the same order
            as ``tasks``. When None, every task uses a weight of 1.0 (default: None).

    Returns:
        Callable: ``loss_fn(logits, labels, weights)`` returning a dict with one
        ``"loss/<task>"`` entry per task plus ``"loss"``, the ``global_reduction``
        of the per-task losses.
    """
    _maybe_warn(global_reduction)
    _maybe_warn(task_loss_reduction)
    f = _LOSS_TYPE_TO_FUNCTION[loss_type]

    loss_reduction_fns = {
        "mean": torch.mean,
        "sum": torch.sum,
        "min": torch.min,
        "max": torch.max,
        "median": torch.median,
    }

    # The positive-class weights are batch-invariant, so build the tensor once
    # here instead of re-allocating it on every loss_fn call.
    if pos_weights is None:
        torch_weights = torch.ones([len(tasks)])
    else:
        torch_weights = torch.tensor(pos_weights)

    def loss_fn(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor):
        losses = {}
        for task_idx, task in enumerate(tasks):
            task_logits = logits[:, task_idx]
            # Labels are cast to the logits' type, as the loss function expects.
            label = labels[:, task_idx].type_as(task_logits)

            loss = f(
                task_logits,
                label,
                reduction=task_loss_reduction,
                pos_weight=torch_weights[task_idx],
                weight=weights[:, task_idx],
            )
            losses[f"loss/{task}"] = loss

        # At this point `losses` holds only the per-task entries.
        losses["loss"] = loss_reduction_fns[global_reduction](torch.stack(list(losses.values())))
        return losses

    return loss_fn
|
|
|
|
|
|
|
|
|
|
|
|
# Registry mapping each supported LossType to the torch functional that computes it.
# Registered callables are invoked as f(input, target, reduction=...) by build_loss
# and additionally with pos_weight=... and weight=... by build_multi_task_loss.
_LOSS_TYPE_TO_FUNCTION = {
    LossType.BCE_WITH_LOGITS: torch.nn.functional.binary_cross_entropy_with_logits,
}
|