Accuracy Calculation

The AccuracyCalculator class computes several accuracy metrics given query and reference embeddings. It can easily be extended to create custom accuracy metrics.

from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
AccuracyCalculator(include=(),
                    exclude=(),
                    avg_of_avgs=False,
                    return_per_class=False,
                    k=None,
                    label_comparison_fn=None,
                    device=None,
                    knn_func=None,
                    kmeans_func=None)

Parameters

  • include: Optional. A list or tuple of strings, which are the names of metrics you want to calculate. If left empty, all default metrics will be calculated.
  • exclude: Optional. A list or tuple of strings, which are the names of metrics you do not want to calculate.
  • avg_of_avgs: If True, the average accuracy per class is computed, and then the average of those averages is returned. This can be useful if your dataset has unbalanced classes. If False, the global average will be returned.
  • return_per_class: If True, the average accuracy per class is computed and returned.
  • k: The number of nearest neighbors that will be retrieved for metrics that require k-nearest neighbors. The allowed values are:
    • None. This means k will be set to the total number of reference embeddings.
    • An integer greater than 0. This means k will be set to the input integer.
    • "max_bin_count". This means k will be set to max(bincount(reference_labels)) - self_count where self_count == 1 if the query and reference embeddings come from the same source.
  • label_comparison_fn: A function that compares two torch arrays of labels and returns a boolean array. The default is torch.eq. If a custom function is used, then you must exclude clustering based metrics ("NMI" and "AMI"). The example below shows a custom function for two-dimensional labels. It returns True if the 0th column matches, and the 1st column does not match.
  • device: The device that input tensors will be moved to. If None, a GPU will be used if available; otherwise the CPU.
  • knn_func: A callable that takes in 4 arguments (query, k, reference, ref_includes_query) and returns distances, indices. Default is pytorch_metric_learning.utils.inference.FaissKNN.
  • kmeans_func: A callable that takes in 2 arguments (x, nmb_clusters) and returns a 1-d tensor of cluster assignments. Default is pytorch_metric_learning.utils.inference.FaissKMeans.
    from pytorch_metric_learning.distances import SNRDistance
    from pytorch_metric_learning.utils.inference import CustomKNN
    
    def example_label_comparison_fn(x, y):
        return (x[:, 0] == y[:, 0]) & (x[:, 1] != y[:, 1])
    
    knn_func = CustomKNN(SNRDistance())
    AccuracyCalculator(exclude=("NMI", "AMI"), 
                        label_comparison_fn=example_label_comparison_fn,
                        knn_func=knn_func)
    

Getting accuracy

Call the get_accuracy method to obtain a dictionary of accuracies.

def get_accuracy(self, 
    query,
    query_labels,       
    reference=None,
    reference_labels=None, 
    ref_includes_query=False, 
    include=(),
    exclude=()
):
# returns a dictionary mapping from metric names to accuracy values
# The default metrics are:
# "NMI" (Normalized Mutual Information)
# "AMI" (Adjusted Mutual Information)
# "precision_at_1"
# "r_precision"
# "mean_average_precision_at_r"

  • query: A 2D torch or numpy array of size (Nq, D), where Nq is the number of query samples. For each query sample, nearest neighbors are retrieved and accuracy is computed.
  • query_labels: A 1D torch or numpy array of size (Nq). Each element should be an integer representing the sample's label.
  • reference: A 2D torch or numpy array of size (Nr, D), where Nr is the number of reference samples. This is where nearest neighbors are retrieved from.
  • reference_labels: A 1D torch or numpy array of size (Nr). Each element should be an integer representing the sample's label.
  • ref_includes_query: Set to True if query is a subset of reference or if query is reference. Set to False otherwise.
  • include: Optional. A list or tuple of strings, which are the names of metrics you want to calculate. If left empty, all metrics specified during initialization will be calculated.
  • exclude: Optional. A list or tuple of strings, which are the names of metrics you do not want to calculate.

Note that labels can be 2D if a custom label comparison function is used.
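
For example, here is a minimal sketch of a call to get_accuracy. The embeddings and labels below are randomly generated placeholders, just to illustrate the shapes and the call:

import torch
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

query = torch.randn(100, 128)                     # (Nq, D) query embeddings
query_labels = torch.randint(0, 10, (100,))       # (Nq,) integer labels
reference = torch.randn(1000, 128)                # (Nr, D) reference embeddings
reference_labels = torch.randint(0, 10, (1000,))  # (Nr,) integer labels

calculator = AccuracyCalculator()
accuracies = calculator.get_accuracy(query, query_labels, reference, reference_labels)
# accuracies is a dictionary, e.g. {"NMI": ..., "AMI": ..., "precision_at_1": ..., ...}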

Lone query labels

If some query labels don't appear in the reference set, then it's impossible for those labels to have non-zero k-nn accuracy. Zero accuracy for these labels doesn't indicate anything about the quality of the embedding space. So these lone query labels are excluded from k-nn based accuracy calculations.

For example, if the input query_labels is [0,0,1,1] and reference_labels is [1,1,1,2,2], then 0 is considered a lone query label.
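
As a rough sketch of the idea (not the library's internal implementation), the lone label in this example could be found like this:

import torch

query_labels = torch.tensor([0, 0, 1, 1])
reference_labels = torch.tensor([1, 1, 1, 2, 2])

# labels that appear in the query but never in the reference set are "lone"
lone = torch.unique(query_labels[~torch.isin(query_labels, reference_labels)])
# lone == tensor([0]); the corresponding queries are excluded from k-nn based metrics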

CPU/GPU usage

  • If you installed faiss-cpu then the CPU will always be used.
  • If you installed faiss-gpu, then the GPU will be used if k <= 1024 for CUDA < 9.5, and k <= 2048 for CUDA >= 9.5. If this condition is not met, then the CPU will be used.

If your dataset is large, you might find that the k-nn search is very slow. This is because the default behavior is to set k = len(reference_embeddings). To avoid this, you can set k to a smaller number, like k = 1000, or try k = "max_bin_count".
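
For example (a sketch; the right value of k depends on your dataset):

from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

calculator = AccuracyCalculator(k=1000)             # cap the k-nn search at 1000 neighbors
# or
calculator = AccuracyCalculator(k="max_bin_count")  # k = max(bincount(reference_labels)) - self_count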

Explanations of the default accuracy metrics

Important note

AccuracyCalculator's mean_average_precision_at_r and r_precision are correct only if k = None, k = "max_bin_count", or k >= max(bincount(reference_labels)).

Adding custom accuracy metrics

Let's say you want to use the existing metrics, but also compute precision@2 and a fancy mutual-info metric. You can extend the existing class and write methods whose names start with calculate_:

from pytorch_metric_learning.utils import accuracy_calculator

class YourCalculator(accuracy_calculator.AccuracyCalculator):
    def calculate_precision_at_2(self, knn_labels, query_labels, **kwargs):
        return accuracy_calculator.precision_at_k(knn_labels, query_labels[:, None], 2)

    def calculate_fancy_mutual_info(self, query_labels, cluster_labels, **kwargs):
        return fancy_computations  # placeholder for your own computation

    def requires_clustering(self):
        return super().requires_clustering() + ["fancy_mutual_info"] 

    def requires_knn(self):
        return super().requires_knn() + ["precision_at_2"] 

Any method that starts with "calculate_" will be passed the following kwargs:

kwargs = {"query": query,                    # query embeddings
    "reference": reference,                  # reference embeddings
    "query_labels": query_labels,        
    "reference_labels": reference_labels,
    "ref_includes_query": e}  # True if query is reference, or if query is a subset of reference.

If your method requires a k-nearest neighbors search, then append your method's name to the requires_knn list, as shown in the above example. If any of your accuracy methods require k-nearest neighbors, they will also receive the following kwargs:

    {"label_counts": label_counts,           # A dictionary mapping from reference labels to the number of times they occur
    "knn_labels": knn_labels,                # A 2d array where each row is the labels of the nearest neighbors of each query. The neighbors are retrieved from the reference set
    "knn_distances": knn_distances           # The euclidean distance corresponding to each k-nearest neighbor in knn_labels
    "lone_query_labels": lone_query_labels   # The set of labels (in the form of a torch array) that have only 1 occurrence in reference_labels
    "not_lone_query_mask": not_lone_query_mask} # A boolean mask, where True means that a query element has at least 1 possible neighbor in reference.           

If your method requires cluster labels, then append your method's name to the requires_clustering list, as shown in the above example. Then, if any of your methods need cluster labels, self.get_cluster_labels() will be called, and the kwargs will include:

    {"cluster_labels": cluster_labels} # A 1D array with a cluster label for each element in the query embeddings.

Now when get_accuracy is called, the returned dictionary will contain precision_at_2 and fancy_mutual_info:

calculator = YourCalculator()
acc_dict = calculator.get_accuracy(query_embeddings,
    query_labels,
    reference_embeddings,
    reference_labels,
    ref_includes_query=True
)
# Now acc_dict contains the metrics "precision_at_2" and "fancy_mutual_info"
# in addition to the original metrics from AccuracyCalculator

You can use your custom calculator with the tester classes as well, by passing it in as an init argument. (By default, the testers use AccuracyCalculator.)

from pytorch_metric_learning import testers
t = testers.GlobalEmbeddingSpaceTester(..., accuracy_calculator=YourCalculator())

Using a custom label comparison function

If you define your own label_comparison_fn, then query_labels and reference_labels can be 1D or 2D, and consist of integers or floating point numbers, as long as your label_comparison_fn can process them.

Example of 2D labels:

def label_comparison_fn(x, y):
    return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1])

# these are valid labels
labels = torch.tensor([
    (1, 3),
    (7, 4),
    (1, 4),
    (1, 5),
    (1, 6),
])

Example of floating point labels:

def label_comparison_fn(x, y):
    return torch.abs(x - y) < 1

# these are valid labels
labels = torch.tensor([
    10.0,
    0.03,
    0.04,
    0.05,
])
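
To use such a function, pass it to AccuracyCalculator and exclude the clustering-based metrics, as described earlier. A sketch, where query_embeddings, query_labels, reference_embeddings, and reference_labels are placeholders for your own data:

calculator = AccuracyCalculator(exclude=("NMI", "AMI"),
                                label_comparison_fn=label_comparison_fn)
accuracies = calculator.get_accuracy(query_embeddings, query_labels, reference_embeddings, reference_labels)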