reducers

BaseReducer

Bases: BaseHook, ABC

Converts an unreduced loss tensor into a single number. In other words, if the loss tensor has shape (N,), the reducer converts it to shape (1,).
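
A minimal sketch of the idea (illustrative only, not library code): reduction collapses an elementwise loss into one value.

```python
import torch

# An unreduced loss has one element per sample.
unreduced = torch.tensor([0.25, 0.5, 0.75])  # shape (3,)

# A reducer collapses it into a single number, e.g. via the mean.
reduced = torch.mean(unreduced)  # tensor(0.5000)
```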

Source code in pytorch_adapt/hooks/reducers.py

```python
class BaseReducer(BaseHook, ABC):
    """
    Converts an unreduced loss tensor into a single number.
    In other words, if the loss tensor has shape ```(N,)```,
    the reducer converts it to shape ```(1,)```.
    """

    def __init__(
        self,
        apply_to: List[str] = None,
        default_reducer: "BaseReducer" = None,
        **kwargs,
    ):
        """
        Arguments:
            apply_to: list of loss names to apply reduction to
            default_reducer: a reducer to use for losses that
                are not already reduced and are also not
                specified in ```apply_to```. If ```None```,
                then no action is taken.
        """
        super().__init__(**kwargs)
        self.apply_to = apply_to
        self.default_reducer = default_reducer
        self.curr_loss_keys = []

    def call(self, inputs, losses):
        """"""
        self.curr_loss_keys = list(losses.keys())
        apply_to = self.get_keys_to_apply_to(losses)
        outputs, losses = self.call_reducer(inputs, losses, apply_to)
        if self.default_reducer:
            combined = c_f.assert_dicts_are_disjoint(inputs, outputs)
            new_outputs, losses = self.default_reducer(combined, losses)
            outputs.update(new_outputs)
        if losses.keys() != set(self.curr_loss_keys):
            raise ValueError(
                "Loss dict returned by reducer should have same keys as input loss dict"
            )
        return outputs, losses

    @abstractmethod
    def call_reducer(self, inputs, losses, apply_to):
        pass

    def _loss_keys(self):
        """"""
        return self.curr_loss_keys

    def get_keys_to_apply_to(self, losses):
        apply_to = self.apply_to
        if apply_to is None:
            apply_to = [k for k, v in losses.items() if not c_f.len_one_tensor(v)]
        elif len(set(apply_to) - set(self.curr_loss_keys)) > 0:
            raise ValueError(
                f"self.apply_to ({self.apply_to}) must be a subset of losses.keys() ({losses.keys()})"
            )
        return apply_to

    def extra_repr(self):
        return c_f.extra_repr(self, ["apply_to"])
```

__init__(apply_to=None, default_reducer=None, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `apply_to` | `List[str]` | list of loss names to apply reduction to | `None` |
| `default_reducer` | `BaseReducer` | a reducer to use for losses that are not already reduced and are also not specified in `apply_to`. If `None`, then no action is taken. | `None` |
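
A hedged usage sketch, assuming the concrete `MeanReducer` documented below is exposed via `pytorch_adapt.hooks` and that hooks are invoked with `(inputs, losses)` as in the `call` method: `apply_to` restricts reduction to the named losses, while `default_reducer` catches any remaining unreduced ones.

```python
import torch

from pytorch_adapt.hooks import MeanReducer  # assumed import path

losses = {
    "src_loss": torch.rand(32),     # unreduced, shape (N,)
    "target_loss": torch.rand(32),  # unreduced, shape (N,)
}

# Reduce "src_loss" explicitly; "target_loss" is not in apply_to,
# so it falls through to the default reducer.
reducer = MeanReducer(apply_to=["src_loss"], default_reducer=MeanReducer())
outputs, losses = reducer({}, losses)
assert all(v.ndim == 0 for v in losses.values())
```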

Source code in pytorch_adapt/hooks/reducers.py

```python
def __init__(
    self,
    apply_to: List[str] = None,
    default_reducer: "BaseReducer" = None,
    **kwargs,
):
    """
    Arguments:
        apply_to: list of loss names to apply reduction to
        default_reducer: a reducer to use for losses that
            are not already reduced and are also not
            specified in ```apply_to```. If ```None```,
            then no action is taken.
    """
    super().__init__(**kwargs)
    self.apply_to = apply_to
    self.default_reducer = default_reducer
    self.curr_loss_keys = []
```

EntropyReducer

Bases: BaseReducer

Implementation of "entropy conditioning" from Conditional Adversarial Domain Adaptation. It weights loss elements using EntropyWeights. The entropy weights are derived from classifier logits.
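
The weighting follows the formulation in the CDAN paper, sketched below as `w = 1 + e^{-H(p)}` so that confident (low-entropy) predictions get larger weights; the library's `EntropyWeights` layer may normalize the weights differently, so treat this as illustrative.

```python
import torch
import torch.nn.functional as F

def entropy_weights(logits: torch.Tensor) -> torch.Tensor:
    # Entropy of the softmax prediction for each sample.
    probs = F.softmax(logits, dim=1)
    entropy = -torch.sum(probs * torch.log(probs.clamp_min(1e-12)), dim=1)
    # CDAN-style conditioning: low entropy (confident) -> larger weight.
    return 1 + torch.exp(-entropy)

logits = torch.randn(32, 10)       # hypothetical classifier logits
elementwise_loss = torch.rand(32)  # unreduced loss, shape (N,)
weighted = torch.mean(entropy_weights(logits) * elementwise_loss)
```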

Source code in pytorch_adapt/hooks/reducers.py

```python
class EntropyReducer(BaseReducer):
    """
    Implementation of "entropy conditioning" from
    [Conditional Adversarial Domain Adaptation](https://arxiv.org/abs/1705.10667).
    It weights loss elements using
    [```EntropyWeights```][pytorch_adapt.layers.EntropyWeights].
    The entropy weights are derived from classifier logits.
    """

    def __init__(
        self,
        f_hook: BaseHook = None,
        domains: List[str] = None,
        entropy_weights_fn: Callable[[torch.Tensor], torch.Tensor] = None,
        detach_weights: bool = True,
        **kwargs,
    ):
        """
        Arguments:
            f_hook: the hook for computing logits from
                which entropy weights are derived
            domains: the domains that ```f_hook``` should compute for
            entropy_weights_fn: the function for computing the weights
                that will be multiplied with the unreduced losses.
            detach_weights: If ```True```, the entropy weights are
                detached from the autograd graph
        """
        super().__init__(**kwargs)
        src_regex = "^{0}_|_{0}$|_{0}_|^{0}$".format("src")
        target_regex = "^{0}_|_{0}$|_{0}_|^{0}$".format("target")
        self.src_regex = re.compile(src_regex)
        self.target_regex = re.compile(target_regex)
        self.entropy_weights_fn = c_f.default(entropy_weights_fn, EntropyWeights, {})
        self.f_hook = c_f.default(
            f_hook,
            FeaturesAndLogitsHook,
            {
                "detach_features": detach_weights,
                "detach_logits": detach_weights,
                "domains": domains,
            },
        )
        self.context = torch.no_grad() if detach_weights else nullcontext()

    def call_reducer(self, inputs, losses, apply_to):
        outputs = self.f_hook(inputs, losses)[0]
        for k in apply_to:
            if self.src_regex.search(k):
                domain = "src"
            elif self.target_regex.search(k):
                domain = "target"
            else:
                raise ValueError
            with self.context:
                search_str = c_f.filter(self.f_hook.out_keys, "_logits", [f"^{domain}"])
                [logits] = c_f.extract([outputs, inputs], search_str)
                weights = self.entropy_weights_fn(logits)
            losses[k] = torch.mean(weights * losses[k])

        return outputs, losses

    def _out_keys(self):
        """"""
        return self.f_hook.out_keys
```

__init__(f_hook=None, domains=None, entropy_weights_fn=None, detach_weights=True, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `f_hook` | `BaseHook` | the hook for computing logits from which entropy weights are derived | `None` |
| `domains` | `List[str]` | the domains that `f_hook` should compute for | `None` |
| `entropy_weights_fn` | `Callable[[torch.Tensor], torch.Tensor]` | the function for computing the weights that will be multiplied with the unreduced losses | `None` |
| `detach_weights` | `bool` | If `True`, the entropy weights are detached from the autograd graph | `True` |
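
A hedged construction sketch; the loss names are hypothetical and just need to contain a `src` or `target` token so the reducer can route each loss to the matching domain's logits. Invoking the reducer requires inputs that the default `FeaturesAndLogitsHook` can consume, so only construction is shown.

```python
from pytorch_adapt.hooks import EntropyReducer, MeanReducer  # assumed import path

# Entropy-condition the adversarial losses; anything else left
# unreduced falls back to a plain mean.
reducer = EntropyReducer(
    apply_to=["src_domain_loss", "target_domain_loss"],
    default_reducer=MeanReducer(),
)
```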

Source code in pytorch_adapt/hooks/reducers.py

```python
def __init__(
    self,
    f_hook: BaseHook = None,
    domains: List[str] = None,
    entropy_weights_fn: Callable[[torch.Tensor], torch.Tensor] = None,
    detach_weights: bool = True,
    **kwargs,
):
    """
    Arguments:
        f_hook: the hook for computing logits from
            which entropy weights are derived
        domains: the domains that ```f_hook``` should compute for
        entropy_weights_fn: the function for computing the weights
            that will be multiplied with the unreduced losses.
        detach_weights: If ```True```, the entropy weights are
            detached from the autograd graph
    """
    super().__init__(**kwargs)
    src_regex = "^{0}_|_{0}$|_{0}_|^{0}$".format("src")
    target_regex = "^{0}_|_{0}$|_{0}_|^{0}$".format("target")
    self.src_regex = re.compile(src_regex)
    self.target_regex = re.compile(target_regex)
    self.entropy_weights_fn = c_f.default(entropy_weights_fn, EntropyWeights, {})
    self.f_hook = c_f.default(
        f_hook,
        FeaturesAndLogitsHook,
        {
            "detach_features": detach_weights,
            "detach_logits": detach_weights,
            "domains": domains,
        },
    )
    self.context = torch.no_grad() if detach_weights else nullcontext()
```

MeanReducer

Bases: BaseReducer

Reduces loss elements by taking the mean.

Source code in pytorch_adapt/hooks/reducers.py

```python
class MeanReducer(BaseReducer):
    """
    Reduces loss elements by taking the mean.
    """

    def call_reducer(self, inputs, losses, apply_to):
        for k in apply_to:
            losses[k] = torch.mean(losses[k])
        return {}, losses

    def _out_keys(self):
        """"""
        return []
```
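
A minimal usage sketch, assuming the `(inputs, losses)` calling convention shown in `BaseReducer.call`: with `apply_to` left as `None`, every loss that is not already a single element gets averaged.

```python
import torch

from pytorch_adapt.hooks import MeanReducer  # assumed import path

losses = {"g_loss": torch.rand(8), "d_loss": torch.rand(8)}
outputs, losses = MeanReducer()({}, losses)
assert all(v.ndim == 0 for v in losses.values())
```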