Distributed¶
Wrap a tuple loss or miner with these classes when training with PyTorch's DistributedDataParallel (i.e. multi-process training).
DistributedLossWrapper¶
utils.distributed.DistributedLossWrapper(loss, efficient=False)
Parameters:
- loss: The loss function to wrap
- efficient:
    - True: each process uses its own embeddings for anchors, and the gathered embeddings for positives/negatives. Gradients will not be equal to those in non-distributed code, but the benefit is reduced memory and faster training.
    - False: each process uses gathered embeddings for both anchors and positives/negatives. Gradients will be equal to those in non-distributed code, but at the cost of doing unnecessary operations (i.e. computations where both anchors and positives/negatives have no gradient).
Example usage:
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist
loss_func = losses.ContrastiveLoss()
loss_func = pml_dist.DistributedLossWrapper(loss_func)
# in each process during training
loss = loss_func(embeddings, labels)
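For context, here is a minimal sketch of where the wrapper might sit inside a standard DistributedDataParallel training process. The function name train_one_process, the model, the dataloader, and the NCCL/optimizer setup are illustrative assumptions, not part of this library:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

def train_one_process(rank, world_size, model, dataloader):
    # illustrative setup; assumes MASTER_ADDR/MASTER_PORT are already set
    # and that each process controls one GPU
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    model = DDP(model.to(rank), device_ids=[rank])
    loss_func = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss())
    optimizer = torch.optim.Adam(model.parameters())
    for data, labels in dataloader:
        optimizer.zero_grad()
        embeddings = model(data.to(rank))
        # each process passes only its local embeddings and labels;
        # the wrapper gathers them across processes internally
        loss = loss_func(embeddings, labels.to(rank))
        loss.backward()
        optimizer.step()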
DistributedMinerWrapper¶
utils.distributed.DistributedMinerWrapper(miner, efficient=False)
Parameters:
- miner: The miner to wrap
- efficient: If your distributed loss function has efficient=True, then you must also set the distributed miner's efficient flag to True.
Example usage:
from pytorch_metric_learning import miners
from pytorch_metric_learning.utils import distributed as pml_dist
miner = miners.MultiSimilarityMiner()
miner = pml_dist.DistributedMinerWrapper(miner)
# in each process
indices_tuple = miner(embeddings, labels)
# pass the mined tuples into a DistributedLossWrapper
loss = loss_func(embeddings, labels, indices_tuple)
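As a short sketch of the efficient=True case described above, both wrappers are constructed with matching flags. The specific loss and miner are just examples, and embeddings and labels are assumed to come from each process's local forward pass:

from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.utils import distributed as pml_dist

# when the loss wrapper uses efficient=True, the miner wrapper must too
loss_func = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss(), efficient=True)
miner = pml_dist.DistributedMinerWrapper(miners.MultiSimilarityMiner(), efficient=True)

# in each process during training
indices_tuple = miner(embeddings, labels)
loss = loss_func(embeddings, labels, indices_tuple)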