Skip to content

Distances

Distance classes compute pairwise distances/similarities between input embeddings.

Consider the TripletMarginLoss in its default form:

from pytorch_metric_learning.losses import TripletMarginLoss
loss_func = TripletMarginLoss(margin=0.2)
This loss function attempts to minimize [dap - dan + margin]+.

Typically, dap and dan represent Euclidean or L2 distances. But what if we want to use a squared L2 distance, or an unnormalized L1 distance, or a completely different distance measure like signal-to-noise ratio? With the distances module, you can try out these ideas easily:

### TripletMarginLoss with squared L2 distance ###
from pytorch_metric_learning.distances import LpDistance
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(power=2))

### TripletMarginLoss with unnormalized L1 distance ###
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(normalize_embeddings=False, p=1))

### TripletMarginLoss with signal-to-noise ratio###
from pytorch_metric_learning.distances import SNRDistance
loss_func = TripletMarginLoss(margin=0.2, distance=SNRDistance())

You can also use similarity measures rather than distances, and the loss function will make the necessary adjustments:

### TripletMarginLoss with cosine similarity##
from pytorch_metric_learning.distances import CosineSimilarity
loss_func = TripletMarginLoss(margin=0.2, distance=CosineSimilarity())
With a similarity measure, the TripletMarginLoss internally swaps the anchor-positive and anchor-negative terms: [san - sap + margin]+. In other words, it will try to make the anchor-negative similarities smaller than the anchor-positive similarities.

All losses, miners, and regularizers accept a distance argument. So you can try out the MultiSimilarityMiner using SNRDistance, or the NTXentLoss using LpDistance(p=1) and so on. Note that some losses/miners/regularizers have restrictions on the type of distances they can accept. For example, some classification losses only allow CosineSimilarity or DotProductSimilarity as their distance measure between embeddings and weights. To view restrictions for specific loss functions, see the losses page

BaseDistance

All distances extend this class and therefore inherit its __init__ parameters.

distances.BaseDistance(collect_stats = False,
                        normalize_embeddings=True, 
                        p=2, 
                        power=1, 
                        is_inverted=False)

Parameters:

  • collect_stats: If True, will collect various statistics that may be useful to analyze during experiments. If False, these computations will be skipped. Want to make True the default? Set the global COLLECT_STATS flag.
  • normalize_embeddings: If True, embeddings will be normalized to have an Lp norm of 1, before the distance/similarity matrix is computed.
  • p: The distance norm.
  • power: If not 1, each element of the distance/similarity matrix will be raised to this power.
  • is_inverted: Should be set by child classes. If False, then small values represent embeddings that are close together. If True, then large values represent embeddings that are similar to each other.

Required Implementations:

# Must return a matrix where mat[j,k] represents 
# the distance/similarity between query_emb[j] and ref_emb[k]
def compute_mat(self, query_emb, ref_emb):
    raise NotImplementedError

# Must return a tensor where output[j] represents 
# the distance/similarity between query_emb[j] and ref_emb[j]
def pairwise_distance(self, query_emb, ref_emb):
    raise NotImplementedError

BatchedDistance

Computes distance matrices iteratively, passing each matrix into iter_fn.

distances.BatchedDistance(distance, iter_fn=None, batch_size=32)

Parameters:

  • distance: The wrapped distance function.
  • iter_fn: This function will be called at every iteration. It receives (mat, s, e) as input, where mat is the current distance matrix, and s, e is the range of query embeddings used to construct mat.
  • batch_size: Each distance matrix will be size (batch_size, len(ref_emb)).

Example usage:

from pytorch_metric_learning.distances import BatchedDistance, CosineSimilarity

def fn(mat, s, e):
    print(f"At query indices {s}:{e}")

distance = BatchedDistance(CosineSimilarity(), fn)

# Works like a regular distance function, except nothing is returned.
# So any persistent changes need to be done in the supplied iter_fn.
# query vs query
distance(embeddings)
# query vs ref
distance(embeddings, ref_emb)

CosineSimilarity

distances.CosineSimilarity(**kwargs)

The returned mat[i,j] is the cosine similarity between query_emb[i] and ref_emb[j]. This class is equivalent to DotProductSimilarity(normalize_embeddings=True).

DotProductSimilarity

distances.DotProductSimilarity(**kwargs)
The returned mat[i,j] is equal to torch.sum(query_emb[i] * ref_emb[j])

LpDistance

distances.LpDistance(**kwargs)
The returned mat[i,j] is the Lp distance between query_emb[i] and ref_emb[j]. With default parameters, this is the Euclidean distance.

SNRDistance

Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning

distances.SNRDistance(**kwargs)
The returned mat[i,j] is equal to:

torch.var(query_emb[i] - ref_emb[j]) / torch.var(query_emb[i])