base_weighter
BaseWeighter
Multiplies losses by scalar values, and then reduces them to a single value.
Source code in pytorch_adapt\weighters\base_weighter.py
__call__(loss_dict)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_dict | Dict[str, torch.Tensor] | A mapping from loss names to loss values. | required |
Returns:
Type | Description |
---|---|
Tuple[torch.Tensor, Dict[str, float]] | A tuple consisting of |
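
The multiply-then-reduce behaviour described for this class can be sketched in plain Python. This is an illustration under stated assumptions, not the code from `base_weighter.py`: the helper name `weigh_and_reduce` is hypothetical, plain floats stand in for `torch.Tensor` values, a loss with no entry in `weights` is assumed to keep a weight of 1, and the second element of the returned tuple is assumed to hold each weighted loss as a float.

```python
from typing import Callable, Dict, List, Optional, Tuple


def weigh_and_reduce(
    loss_dict: Dict[str, float],
    reduction: Callable[[List[float]], float],
    weights: Optional[Dict[str, float]] = None,
    scale: float = 1,
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical stand-in: weight each loss, then reduce to one value."""
    weights = weights or {}
    weighted = {}
    for name, value in loss_dict.items():
        # Assumption: a loss without an explicit weight is weighted by 1.
        weighted[name] = value * weights.get(name, 1) * scale
    # Return the reduced value together with the per-loss components.
    return reduction(list(weighted.values())), weighted


total, components = weigh_and_reduce(
    {"src_loss": 2.0, "target_loss": 4.0},
    reduction=sum,
    weights={"target_loss": 0.5},
)
print(total)       # 4.0
print(components)  # {'src_loss': 2.0, 'target_loss': 2.0}
```

Returning the per-loss values alongside the reduced value makes it possible to log each component separately while backpropagating through only the single reduced loss.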
__init__(reduction, weights=None, scale=1)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
reduction | Callable[[List[torch.Tensor]], torch.Tensor] | A function that takes in a list of losses and returns a single loss value. | required |
weights | Dict[str, float] | A mapping from loss names to weight values. If | None |
scale | float | A scalar that every loss gets multiplied by. | 1 |
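
To show how the three constructor parameters interact, here is a minimal sketch using a hypothetical `SketchWeighter` class. It is not the library's `BaseWeighter`: it works on plain floats rather than tensors, it assumes that a loss missing from `weights` is weighted by 1, and for brevity its `__call__` returns only the reduced value instead of the documented tuple.

```python
from typing import Callable, Dict, List, Optional


def mean(values: List[float]) -> float:
    """Average a non-empty list of loss values."""
    return sum(values) / len(values)


class SketchWeighter:
    """Hypothetical stand-in that mirrors the three constructor parameters."""

    def __init__(
        self,
        reduction: Callable[[List[float]], float],
        weights: Optional[Dict[str, float]] = None,
        scale: float = 1,
    ):
        self.reduction = reduction
        # Assumption: no weights given means every loss is weighted by 1.
        self.weights = weights if weights is not None else {}
        self.scale = scale

    def __call__(self, loss_dict: Dict[str, float]) -> float:
        scaled = [
            value * self.weights.get(name, 1) * self.scale
            for name, value in loss_dict.items()
        ]
        return self.reduction(scaled)


losses = {"a": 1.0, "b": 3.0}
sum_result = SketchWeighter(reduction=sum)(losses)
mean_result = SketchWeighter(reduction=mean, scale=10)(losses)
print(sum_result)   # 4.0
print(mean_result)  # 20.0
```

Passing `reduction` as a callable keeps the weighting logic independent of how the losses are combined, so a summing and an averaging weighter differ only in the function handed to the constructor.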