# How to write custom loss functions

## The simplest possible loss function

```python
from pytorch_metric_learning.losses import BaseMetricLossFunction
import torch


class BarebonesLoss(BaseMetricLossFunction):
    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        # perform some calculation #
        some_loss = torch.mean(embeddings)

        # put into dictionary #
        return {
            "loss": {
                "losses": some_loss,
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }
```
## Compatibility with distances and reducers

You can make your loss function a lot more powerful by adding support for distance metrics and reducers:

```python
from pytorch_metric_learning.losses import BaseMetricLossFunction
from pytorch_metric_learning.reducers import AvgNonZeroReducer
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
import torch


class FullFeaturedLoss(BaseMetricLossFunction):
    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        indices_tuple = lmu.convert_to_triplets(indices_tuple, labels)
        anchors, positives, negatives = indices_tuple
        if len(anchors) == 0:
            return self.zero_losses()

        mat = self.distance(embeddings)
        ap_dists = mat[anchors, positives]
        an_dists = mat[anchors, negatives]

        # perform some calculations #
        losses1 = ap_dists - an_dists
        losses2 = ap_dists * 5
        losses3 = torch.mean(embeddings)

        # put into dictionary #
        return {
            "loss1": {
                "losses": losses1,
                "indices": indices_tuple,
                "reduction_type": "triplet",
            },
            "loss2": {
                "losses": losses2,
                "indices": (anchors, positives),
                "reduction_type": "pos_pair",
            },
            "loss3": {
                "losses": losses3,
                "indices": None,
                "reduction_type": "already_reduced",
            },
        }

    def get_default_reducer(self):
        return AvgNonZeroReducer()

    def get_default_distance(self):
        return CosineSimilarity()

    def _sub_loss_names(self):
        return ["loss1", "loss2", "loss3"]
```
Here are a few details about this loss function:

- It operates on triplets, so `convert_to_triplets` is used to convert `indices_tuple` to triplet form.
- `self.distance` returns a pairwise distance matrix.
- The output of the loss function is a dictionary that contains multiple sub losses. This is why it overrides the `_sub_loss_names` function.
- `get_default_reducer` is overridden to use `AvgNonZeroReducer` by default, rather than `MeanReducer`.
- `get_default_distance` is overridden to use `CosineSimilarity` by default, rather than `LpDistance(p=2)`.
## More on distances

To make your loss compatible with inverted distances (like cosine similarity), you can check `self.distance.is_inverted`, and write whatever logic is necessary to make your loss make sense in that context.

There are also a few functions in `self.distance` that provide some of this logic, specifically `self.distance.smallest_dist`, `self.distance.largest_dist`, and `self.distance.margin`. The function definitions are pretty straightforward, and you can find them in the distances source code.
## Using indices_tuple

This is an optional argument passed in from the outside. (See the overview for an example.) It currently has 3 possible forms:

- `None`
- A tuple of size 4, representing the indices of mined pairs (anchors, positives, anchors, negatives)
- A tuple of size 3, representing the indices of mined triplets (anchors, positives, negatives)

To use `indices_tuple`, use the appropriate conversion function. You don't need to know what type will be passed in, as the conversion function takes care of that:
```python
import torch
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

# For a pair based loss
# After conversion, indices_tuple will be a tuple of size 4
indices_tuple = lmu.convert_to_pairs(indices_tuple, labels)

# For a triplet based loss
# After conversion, indices_tuple will be a tuple of size 3
indices_tuple = lmu.convert_to_triplets(indices_tuple, labels)

# For a classification based loss
# miner_weights.shape == labels.shape
# You can use these to weight your loss
miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=torch.float32)
```
## Reduction type

The purpose of reduction types is to provide extra information to the reducer, if it needs it. For example, you could write a reducer that behaves differently depending on what kind of loss it receives. Here's a summary of each reduction type:

| Reduction type | Meaning | Shape of "indices" |
|---|---|---|
| "triplet" | Each entry in "losses" represents a triplet. | A tuple of 3 tensors (anchors, positives, negatives), each of size (N,). |
| "pos_pair" | Each entry in "losses" represents a positive pair. | A tuple of 2 tensors (anchors, positives), each of size (N,). |
| "neg_pair" | Each entry in "losses" represents a negative pair. | A tuple of 2 tensors (anchors, negatives), each of size (N,). |
| "element" | Each entry in "losses" represents something other than a tuple, e.g. an element in a batch. | A tensor of size (N,). |
| "already_reduced" | "losses" is a single number, i.e. the loss has already been reduced. | Should be None. |
## Some useful examples to look at

Here are some existing loss functions that might be useful for reference: