Inference Models¶
utils.inference contains classes that make it convenient to find matching pairs within a batch, or from a set of pairs. Take a look at this notebook to see example usage.
InferenceModel¶
from pytorch_metric_learning.utils.inference import InferenceModel
InferenceModel(
    trunk,
    embedder=None,
    match_finder=None,
    indexer=None,
    normalize_embeddings=True,
    batch_size=64,
)
Parameters:
- trunk: Your trained model for computing embeddings.
- embedder: Optional. Use this if your model is split into two components (trunk and embedder). If None, the embedder simply returns the trunk's output.
- match_finder: A MatchFinder object. If None, it will be set to MatchFinder(distance=CosineSimilarity(), threshold=0.9).
- indexer: The object used for computing k-nearest-neighbors. If None, it will be set to FaissIndexer().
- normalize_embeddings: If True, embeddings will be normalized to have Euclidean norm of 1.
- batch_size: The batch size used to compute embeddings, when training the indexer for k-nearest-neighbor retrieval.
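To make the normalize_embeddings and batch_size parameters concrete, here is a minimal sketch in plain Python of what "compute embeddings in batches, then scale each to Euclidean norm 1" means. It is not the library's implementation: l2_normalize, embed_in_batches, and toy_embed_batch are hypothetical names, and the toy function merely stands in for a trained trunk and embedder.

```python
import math

def l2_normalize(vector, eps=1e-12):
    """Scale a vector so its Euclidean norm is 1 (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / max(norm, eps) for x in vector]

def embed_in_batches(inputs, embed_batch, batch_size=64, normalize=True):
    """Run embed_batch on successive slices of inputs and collect the results."""
    embeddings = []
    for start in range(0, len(inputs), batch_size):
        batch = inputs[start:start + batch_size]
        for vector in embed_batch(batch):
            embeddings.append(l2_normalize(vector) if normalize else list(vector))
    return embeddings

# Hypothetical stand-in for trunk + embedder: maps each number x to [x, 2x].
def toy_embed_batch(batch):
    return [[float(x), 2.0 * x] for x in batch]

embeddings = embed_in_batches([1, 2, 3, 4, 5], toy_embed_batch, batch_size=2)
norms = [math.sqrt(sum(x * x for x in e)) for e in embeddings]
```

After normalization every entry of norms is 1 up to floating-point error, so cosine similarity between two embeddings reduces to their dot product.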
MatchFinder¶
from pytorch_metric_learning.utils.inference import MatchFinder
MatchFinder(distance=None, threshold=None)
Parameters:
- distance: A distance object.
- threshold: Optional. Pairs will be a match if they fall under this threshold for non-inverted distances, or over this value for inverted distances. If not provided, then a threshold must be provided during function calls.
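The threshold rule can be illustrated with a small, library-independent sketch. The helper names below (is_match, find_matches) and the is_inverted flag passed as a plain argument are assumptions made for illustration, and treating a value exactly equal to the threshold as a match is also an assumption, since the text above does not specify it.

```python
import math
from itertools import combinations

def cosine_similarity(a, b):
    """Inverted measure: larger values mean the vectors are more alike."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def euclidean_distance(a, b):
    """Non-inverted measure: smaller values mean the vectors are more alike."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def is_match(value, threshold, is_inverted):
    # Assumption: the boundary value itself counts as a match.
    return value >= threshold if is_inverted else value <= threshold

def find_matches(vectors, measure, threshold, is_inverted):
    """Return index pairs (i, j), i < j, whose measure passes the threshold."""
    return [
        (i, j)
        for i, j in combinations(range(len(vectors)), 2)
        if is_match(measure(vectors[i], vectors[j]), threshold, is_inverted)
    ]

vectors = [[1.0, 0.0], [0.99, 0.1], [0.0, 1.0]]
by_similarity = find_matches(vectors, cosine_similarity, 0.9, is_inverted=True)
by_distance = find_matches(vectors, euclidean_distance, 0.5, is_inverted=False)
```

Both calls select only the pair (0, 1): its cosine similarity is above 0.9 and its Euclidean distance is below 0.5, while the other pairs fail both tests.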
FaissIndexer¶
This will create a faiss index, specifically IndexFlatL2, which can be used for nearest neighbor retrieval.
from pytorch_metric_learning.utils.inference import FaissIndexer
FaissIndexer()
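A flat L2 index performs exact, exhaustive search: every stored vector is compared against the query using squared Euclidean distance. The sketch below shows that idea in plain Python for intuition only. It does not use faiss, and flat_l2_search is a hypothetical helper name.

```python
def flat_l2_search(index_vectors, query, k):
    """Return the k nearest stored vectors as (index, squared_l2_distance) pairs."""
    scored = sorted(
        (sum((a - b) ** 2 for a, b in zip(vector, query)), i)
        for i, vector in enumerate(index_vectors)
    )
    return [(i, dist) for dist, i in scored[:k]]

stored = [[0.0, 0.0], [1.0, 0.0], [3.0, 4.0]]
neighbors = flat_l2_search(stored, query=[0.9, 0.0], k=2)
nearest_ids = [i for i, _ in neighbors]
```

Here nearest_ids is [1, 0], because the query lies closest to the second stored vector and next closest to the first. The cost grows linearly with the number of stored vectors, which is the trade-off of an exact index.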
LogitGetter¶
This class makes it easier to extract logits from classifier loss functions. Although "metric learning" usually means that you use embeddings during inference, there might be cases where you want to use the class logits instead of the embeddings.
from pytorch_metric_learning.utils.inference import LogitGetter
LogitGetter(
    classifier,
    layer_name=None,
    transpose=None,
    distance=None,
    copy_weights=True,
)
Parameters:
- classifier: The classifier loss function that you want to extract logits from.
- layer_name: Optional. The attribute name of the weight matrix inside the classifier. If not specified, each of the following will be tried: ["fc", "proxies", "W"].
- transpose: Optional. Whether or not to transpose the weight matrix. If the weight matrix is of size (embedding_size, num_classes), then it should be transposed. If not specified, then transposing will be done automatically during the forward call if necessary, based on the shapes of the input embeddings and the weight matrix. (Note that it is only guaranteed to make the correct transposing decision if num_classes != embedding_size.)
- distance: Optional. A distance object which will compute the distance or similarity matrix, i.e. the logits. If not specified, then classifier.distance will be used.
- copy_weights: If True, then LogitGetter will contain a copy of (instead of a reference to) the classifier weights, so that if you update the classifier weights, the LogitGetter remains unchanged.
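The shape-based transposing decision described for the transpose parameter can be sketched as follows. This is an illustration under stated assumptions, not the library's code: maybe_transpose and dot_product_logits are hypothetical names, the rule "transpose when the first dimension equals the embedding size" is an assumed reading of the description, and a plain dot product stands in for the distance object.

```python
def maybe_transpose(weights, embedding_size, transpose=None):
    """Return weights oriented as (num_classes, embedding_size)."""
    if transpose is None:
        # Assumed rule: a first dimension equal to embedding_size suggests
        # the matrix is stored as (embedding_size, num_classes).
        transpose = len(weights) == embedding_size
    if transpose:
        return [list(column) for column in zip(*weights)]
    return weights

def dot_product_logits(embeddings, weights, transpose=None):
    """Compute one row of class scores per embedding (dot product as a stand-in)."""
    oriented = maybe_transpose(weights, len(embeddings[0]), transpose)
    return [
        [sum(e * w for e, w in zip(embedding, class_vector)) for class_vector in oriented]
        for embedding in embeddings
    ]

batch = [[2.0, 3.0]]                            # one embedding, embedding_size = 2
w_emb_by_class = [[1.0, 0.0, 1.0],
                  [0.0, 1.0, 1.0]]              # shape (embedding_size, num_classes)
w_class_by_emb = [[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]]                   # shape (num_classes, embedding_size)

logits_a = dot_product_logits(batch, w_emb_by_class)
logits_b = dot_product_logits(batch, w_class_by_emb)
```

Both layouts yield the same logits here because the two dimensions differ. With a square weight matrix the shapes carry no information about orientation, which is why the automatic decision is only guaranteed when num_classes != embedding_size and an explicit transpose value is safer in that case.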
Example usage:
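import torch
from pytorch_metric_learning.losses import ArcFaceLoss
from pytorch_metric_learning.utils.inference import LogitGetter

loss_fn = ArcFaceLoss(num_classes=100, embedding_size=512)
LG = LogitGetter(loss_fn)
# Placeholder batch of 32 embeddings; in practice, use your trained model's output
embeddings = torch.randn(32, 512)
logits = LG(embeddings)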
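The difference between holding a copy of the classifier weights and holding a reference to them, as controlled by copy_weights, can be shown with a small plain-Python analogy. WeightHolder and the dictionary-based classifier below are hypothetical stand-ins, not library objects.

```python
import copy

class WeightHolder:
    """Keeps either a snapshot of, or a live reference to, a classifier's weights."""

    def __init__(self, classifier, copy_weights=True):
        weights = classifier["W"]
        self.weights = copy.deepcopy(weights) if copy_weights else weights

classifier = {"W": [[1.0, 2.0], [3.0, 4.0]]}    # hypothetical classifier weights
snapshot = WeightHolder(classifier, copy_weights=True)
live = WeightHolder(classifier, copy_weights=False)

# Simulate a training step that updates the classifier weights in place.
classifier["W"][0][0] = 9.0

snapshot_value = snapshot.weights[0][0]    # unchanged by the update
live_value = live.weights[0][0]            # follows the update
```

The snapshot keeps the value 1.0 while the live holder sees 9.0. A copy is the safer choice when logits must stay stable after further training. A reference suits cases where the logits should track the classifier as it is updated in place.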