Bases: `BaseConditionHook`

Returns `True` if the discriminator's accuracy is higher than some threshold.
Source code in `pytorch_adapt\hooks\conditions.py` (lines 9–44):
class StrongDHook(BaseConditionHook):
    """
    Returns ```True``` if the discriminator's
    accuracy is higher than some threshold.
    """

    def __init__(self, threshold: float = 0.6, **kwargs):
        """
        Arguments:
            threshold: The discriminator's accuracy must be higher
                than this threshold for the hook to return ```True```.
        """
        super().__init__(**kwargs)
        # The discriminator emits logits; Sigmoid converts them to
        # probabilities before SufficientAccuracy scores them.
        sigmoid = torch.nn.Sigmoid()
        self.accuracy_fn = SufficientAccuracy(
            threshold=threshold, to_probs_func=sigmoid
        )
        # Compute features, then discriminator logits; both sub-hooks are
        # built with detach=True.
        feature_hook = FeaturesHook(detach=True)
        dlogits_hook = DLogitsHook(detach=True)
        self.hook = FeaturesChainHook(feature_hook, dlogits_hook)

    def call(self, inputs, losses):
        """
        Run the detached features -> discriminator-logits chain on ``inputs``,
        concatenate the source and target logits together with their
        ``src_domain`` / ``target_domain`` labels, and return the result of
        ``self.accuracy_fn`` on them.
        """
        with torch.no_grad():
            chain_outputs = self.hook(inputs, losses)[0]
            # Select the detached discriminator-logit keys produced by the
            # chain, ordered as [src..., target...].
            dlogit_keys = c_f.filter(
                self.hook.out_keys, "_dlogits_detached$", ["^src", "^target"]
            )
            d_src_logits, d_target_logits = c_f.extract(
                [chain_outputs, inputs], dlogit_keys
            )
            src_domain, target_domain = c_f.extract(
                inputs, ["src_domain", "target_domain"]
            )
            # Stack source and target along the batch dimension so that a
            # single accuracy value is computed over both domains.
            all_logits = torch.cat([d_src_logits, d_target_logits], dim=0)
            all_labels = torch.cat([src_domain, target_domain], dim=0)
            return self.accuracy_fn(all_logits, all_labels)
#### `__init__(threshold=0.6, **kwargs)`

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `threshold` | `float` | The discriminator's accuracy must be higher than this threshold for the hook to return `True`. | `0.6` |
Source code in `pytorch_adapt\hooks\conditions.py` (lines 15–27):
def __init__(self, threshold: float = 0.6, **kwargs):
    """
    Arguments:
        threshold: The discriminator's accuracy must be higher
            than this threshold for the hook to return ```True```.
    """
    super().__init__(**kwargs)
    # Discriminator outputs are logits, so they pass through Sigmoid
    # before accuracy is measured against the threshold.
    sigmoid = torch.nn.Sigmoid()
    self.accuracy_fn = SufficientAccuracy(
        threshold=threshold, to_probs_func=sigmoid
    )
    # Chain a detached features hook into a detached discriminator-logits hook.
    feature_hook = FeaturesHook(detach=True)
    dlogits_hook = DLogitsHook(detach=True)
    self.hook = FeaturesChainHook(feature_hook, dlogits_hook)
|