Adding Custom Modules

Register your own classes and modules

By default, this library gives you access to various classes in pytorch-metric-learning, torch, torchvision, and pretrainedmodels.

Let's say you want to use your own loss function as well as a custom optimizer that isn't available in torch.optim. You can accomplish this by replacing the last two lines of the example script with this:

from your_own_loss import YourLossFunction
from custom_optimizer import CoolOptimizer

r = runner(**(args.__dict__))

# make the runner aware of them
r.register("loss", YourLossFunction)
r.register("optimizer", CoolOptimizer)
r.run()

Now you can access your custom classes just like any other class:

loss_funcs:
  metric_loss: 
    YourLossFunction:

optimizers:
  trunk_optimizer:
    CoolOptimizer:
      lr: 0.01
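
To make the connection between register and the config file concrete, here is a minimal sketch of a name-based registry. It is not the library's actual implementation: MiniRegistry, its build method, and the stand-in CoolOptimizer are all hypothetical, and the sketch only assumes that a config entry maps a class name to its keyword arguments.

```python
class MiniRegistry:
    """Toy name-to-class registry (hypothetical, for illustration only)."""

    def __init__(self):
        self._classes = {}  # {category: {class_name: class}}

    def register(self, category, cls):
        # Store the class under its own name so a config can refer to it.
        self._classes.setdefault(category, {})[cls.__name__] = cls

    def build(self, category, spec):
        # spec mirrors one config entry, e.g. {"CoolOptimizer": {"lr": 0.01}}.
        # An entry with no parameters parses to None, hence `params or {}`.
        name, params = next(iter(spec.items()))
        return self._classes[category][name](**(params or {}))


class CoolOptimizer:
    """Stand-in for a custom optimizer (hypothetical)."""

    def __init__(self, lr):
        self.lr = lr


registry = MiniRegistry()
registry.register("optimizer", CoolOptimizer)
optimizer = registry.build("optimizer", {"CoolOptimizer": {"lr": 0.01}})
print(type(optimizer).__name__, optimizer.lr)
```

The point is only that a registered class becomes reachable by its class name, which is why the config above can refer to YourLossFunction and CoolOptimizer directly.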

If you have a module containing multiple classes and you want to register all those classes, you can simply register the module:

import YourModuleOfLosses
r.register("loss", YourModuleOfLosses)
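
If you're curious what registering a whole module could involve, the sketch below collects every class defined in a module using the standard inspect module. This is an assumption about the general technique, not a description of what this library does internally; classes_in_module and the throwaway module are hypothetical.

```python
import inspect
import types


def classes_in_module(module):
    """Return {name: class} for the classes defined in `module` itself."""
    return {
        name: obj
        for name, obj in inspect.getmembers(module, inspect.isclass)
        # Skip classes that were merely imported into the module.
        if obj.__module__ == module.__name__
    }


# Build a throwaway module so the example is self-contained.
your_module_of_losses = types.ModuleType("your_module_of_losses")
for class_name in ("LossA", "LossB"):
    new_class = type(class_name, (), {"__module__": your_module_of_losses.__name__})
    setattr(your_module_of_losses, class_name, new_class)
your_module_of_losses.some_constant = 3  # non-class attributes are ignored

found = classes_in_module(your_module_of_losses)
print(sorted(found))
```

Filtering on `__module__` is a design choice in this sketch: it keeps classes that the module only imports from other packages out of the result.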

Registering your own trainer is a bit more involved, because you also need to create an associated API parser. The name of the API parser should be API<name of your training method>.

Here's an example where I make a trainer that extends trainers.MetricLossOnly, and takes in an additional argument foo. If foo is a simple parameter that can be specified directly in a config file, then APIYourTrainer doesn't need to do anything other than exist:

from pytorch_metric_learning import trainers
from powerful_benchmarker import api_parsers

class YourTrainer(trainers.MetricLossOnly):
    def __init__(self, foo, **kwargs):
        super().__init__(**kwargs)
        self.foo = foo
        print("foo =", self.foo)

class APIYourTrainer(api_parsers.BaseAPIParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

r = runner(**(args.__dict__))
r.register("trainer", YourTrainer)
r.register("api_parser", APIYourTrainer)
r.run()

However, if foo is more complex, e.g. it is an object that requires some logic to be created, then you'll want APIYourTrainer to handle that logic, and then add foo to the default_kwargs_trainer dictionary. Check out the code documentation for details on this.
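
As a rough illustration of that pattern, here is a self-contained sketch that uses stand-in classes instead of the real library. StandInBaseParser, Foo, the config keys, and the simplified YourTrainer are all hypothetical; the only thing taken from the text above is that the parser builds foo and adds it to default_kwargs_trainer. The real BaseAPIParser has a different constructor and different methods, so treat this as a shape to follow rather than code to copy.

```python
class StandInBaseParser:
    """Stand-in for api_parsers.BaseAPIParser (hypothetical, heavily simplified)."""

    def __init__(self, config):
        self.config = config
        # Keyword arguments that will eventually be passed to the trainer.
        self.default_kwargs_trainer = {"batch_size": config["batch_size"]}


class Foo:
    """An object that needs some logic to construct (hypothetical)."""

    def __init__(self, size, scale):
        self.size = size
        self.scale = scale


class APIYourTrainer(StandInBaseParser):
    def __init__(self, config):
        super().__init__(config)
        # Build foo from several config values, then hand it to the trainer.
        foo = Foo(size=config["batch_size"], scale=config["foo_scale"])
        self.default_kwargs_trainer["foo"] = foo


class YourTrainer:
    """Simplified trainer that only records its arguments (hypothetical)."""

    def __init__(self, foo, batch_size):
        self.foo = foo
        self.batch_size = batch_size


parser = APIYourTrainer({"batch_size": 32, "foo_scale": 0.5})
trainer = YourTrainer(**parser.default_kwargs_trainer)
print(trainer.foo.size, trainer.foo.scale)
```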