Bases: torch.nn.Module
Determines if a batch of logits has accuracy greater
than some threshold. This can be used to control
program flow.
Example:
```python
condition_fn = SufficientAccuracy(threshold=0.7)
if condition_fn(logits, labels):
    ...
```
Source code in pytorch_adapt\layers\sufficient_accuracy.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class SufficientAccuracy(torch.nn.Module):
    """
    Checks whether the accuracy of a batch of logits exceeds a
    fixed threshold, so that the boolean result can be used to
    steer program flow.

    Example:
    ```python
    condition_fn = SufficientAccuracy(threshold=0.7)
    if condition_fn(logits, labels):
        ...
    ```
    """

    def __init__(
        self,
        threshold: float,
        accuracy_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
        to_probs_func: Callable[[torch.Tensor], torch.Tensor] = None,
    ):
        """
        Arguments:
            threshold: The value that the accuracy has to exceed
                for the forward pass to return ```True```.
            accuracy_func: Called as ```accuracy_func(to_probs_func(logits), labels)```
                and expected to return the accuracy as a tensor.
                If ```None```, the library's ```accuracy``` function is used.
            to_probs_func: Applied to the logits before they are handed
                to ```accuracy_func```. If ```None```, ```torch.nn.Sigmoid``` is used.
        """
        super().__init__()
        self.threshold = threshold
        # Resolve the optional callables first, then attach them in the
        # same order as before so module registration order is unchanged.
        chosen_accuracy_func = c_f.default(accuracy_func, accuracy)
        self.accuracy_func = chosen_accuracy_func
        chosen_to_probs_func = c_f.default(to_probs_func, torch.nn.Sigmoid())
        self.to_probs_func = chosen_to_probs_func
        # "accuracy" is only assigned once forward() has run.
        recordable_names = ["accuracy", "threshold"]
        pml_cf.add_to_recordable_attributes(self, list_of_names=recordable_names)

    def forward(self, x: torch.Tensor, labels: torch.Tensor) -> bool:
        """
        Arguments:
            x: logits for which the accuracy is computed
            labels: the labels that belong to ```x```
        Returns:
            ```True``` when the computed accuracy is strictly greater
            than ```self.threshold```, otherwise ```False```
        """
        # Pure evaluation: nothing here should be tracked by autograd.
        with torch.no_grad():
            probs = self.to_probs_func(x)
            int_labels = labels.type(torch.int)
            batch_accuracy = self.accuracy_func(probs, int_labels)
            # Stored as a plain Python number on the module.
            self.accuracy = batch_accuracy.item()
        return self.accuracy > self.threshold

    def extra_repr(self):
        """"""
        # NOTE(review): the empty docstring is kept on purpose; presumably it
        # stops the inherited docstring from being rendered in the API docs.
        shown_attributes = ["threshold"]
        return c_f.extra_repr(self, shown_attributes)
|
__init__(threshold, accuracy_func=None, to_probs_func=None)
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `threshold` | `float` | The accuracy must be greater than this for the forward pass to return `True`. | required |
| `accuracy_func` | `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]` | Function that takes in `(to_probs_func(logits), labels)` and returns accuracy. If `None`, then classification accuracy is used. | `None` |
| `to_probs_func` | `Callable[[torch.Tensor], torch.Tensor]` | Function that processes the logits before they get passed to `accuracy_func`. If `None`, then `torch.nn.Sigmoid` is used. | `None` |
Source code in pytorch_adapt\layers\sufficient_accuracy.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def __init__(
    self,
    threshold: float,
    accuracy_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    to_probs_func: Callable[[torch.Tensor], torch.Tensor] = None,
):
    """
    Arguments:
        threshold: The value that the accuracy has to exceed
            for the forward pass to return ```True```.
        accuracy_func: Called as ```accuracy_func(to_probs_func(logits), labels)```
            and expected to return the accuracy as a tensor.
            If ```None```, the library's ```accuracy``` function is used.
        to_probs_func: Applied to the logits before they are handed
            to ```accuracy_func```. If ```None```, ```torch.nn.Sigmoid``` is used.
    """
    super().__init__()
    self.threshold = threshold
    # Resolve the optional callables first, then attach them in the
    # same order as before so module registration order is unchanged.
    chosen_accuracy_func = c_f.default(accuracy_func, accuracy)
    self.accuracy_func = chosen_accuracy_func
    chosen_to_probs_func = c_f.default(to_probs_func, torch.nn.Sigmoid())
    self.to_probs_func = chosen_to_probs_func
    # "accuracy" is only assigned once forward() has run.
    recordable_names = ["accuracy", "threshold"]
    pml_cf.add_to_recordable_attributes(self, list_of_names=recordable_names)
|
forward(x, labels)
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `torch.Tensor` | logits to compute accuracy for | required |
| `labels` | `torch.Tensor` | the corresponding labels | required |

Returns:

| Type | Description |
| --- | --- |
| `bool` | `True` if the accuracy is greater than `self.threshold` |
Source code in pytorch_adapt\layers\sufficient_accuracy.py
48
49
50
51
52
53
54
55
56
57
58
59
def forward(self, x: torch.Tensor, labels: torch.Tensor) -> bool:
    """
    Arguments:
        x: logits for which the accuracy is computed
        labels: the labels that belong to ```x```
    Returns:
        ```True``` when the computed accuracy is strictly greater
        than ```self.threshold```, otherwise ```False```
    """
    # Pure evaluation: nothing here should be tracked by autograd.
    with torch.no_grad():
        probs = self.to_probs_func(x)
        int_labels = labels.type(torch.int)
        batch_accuracy = self.accuracy_func(probs, int_labels)
        # Stored as a plain Python number on the module.
        self.accuracy = batch_accuracy.item()
    return self.accuracy > self.threshold
|