Skip to content

sufficient_accuracy

SufficientAccuracy

Bases: torch.nn.Module

Determines whether the accuracy computed from a batch of logits exceeds a given threshold. The boolean result can be used to control program flow.

Example:

```python
condition_fn = SufficientAccuracy(threshold=0.7)
if condition_fn(logits, labels):
    ...
```

Source code in `pytorch_adapt/layers/sufficient_accuracy.py` (lines 10–64):
class SufficientAccuracy(torch.nn.Module):
    """
    Checks whether the accuracy of a batch of logits is strictly
    above a fixed threshold. Because the result is a plain bool,
    the module is convenient for gating control flow.

    Example:
    ```python
    condition_fn = SufficientAccuracy(threshold=0.7)
    if condition_fn(logits, labels):
        ...
    ```
    """

    def __init__(
        self,
        threshold: float,
        accuracy_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
        to_probs_func: Callable[[torch.Tensor], torch.Tensor] = None,
    ):
        """
        Arguments:
            threshold: The forward pass returns True only when the
                computed accuracy is strictly greater than this value.
            accuracy_func: Called as ```accuracy_func(to_probs_func(logits), labels)```;
                it must return the accuracy as a tensor, since ```.item()``` is
                called on the result. If ```None```, then classification accuracy is used.
            to_probs_func: Applied to the logits before they are handed to
                ```accuracy_func```. If ```None```, a ```torch.nn.Sigmoid``` instance is used.
        """
        super().__init__()
        self.threshold = threshold
        # Resolve the optional callables to their defaults before storing them.
        chosen_accuracy = c_f.default(accuracy_func, accuracy)
        chosen_to_probs = c_f.default(to_probs_func, torch.nn.Sigmoid())
        self.accuracy_func = chosen_accuracy
        self.to_probs_func = chosen_to_probs
        # Mark these attribute names as recordable via the
        # pytorch-metric-learning helper (presumably for logging; confirm).
        recordable = ["accuracy", "threshold"]
        pml_cf.add_to_recordable_attributes(self, list_of_names=recordable)

    def forward(self, x: torch.Tensor, labels: torch.Tensor) -> bool:
        """
        Arguments:
            x: raw logits for the batch
            labels: ground-truth labels that go with ```x```
        Returns:
            ```True``` when the computed accuracy is strictly greater than
            ```self.threshold```. The accuracy value itself is kept on
            ```self.accuracy``` as a Python scalar.
        """
        with torch.no_grad():
            probs = self.to_probs_func(x)
            int_labels = labels.to(torch.int)
            score = self.accuracy_func(probs, int_labels)
            self.accuracy = score.item()
        return self.accuracy > self.threshold

    def extra_repr(self):
        """"""
        # Only the threshold is shown in the module's repr.
        shown = ["threshold"]
        return c_f.extra_repr(self, shown)

__init__(threshold, accuracy_func=None, to_probs_func=None)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `threshold` | `float` | The accuracy must be greater than this for the forward pass to return True. | *required* |
| `accuracy_func` | `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]` | Function that takes in `(to_probs_func(logits), labels)` and returns accuracy. If `None`, then classification accuracy is used. | `None` |
| `to_probs_func` | `Callable[[torch.Tensor], torch.Tensor]` | Function that processes the logits before they get passed to `accuracy_func`. If `None`, then `torch.nn.Sigmoid` is used. | `None` |
Source code in `pytorch_adapt/layers/sufficient_accuracy.py` (lines 24–46):
def __init__(
    self,
    threshold: float,
    accuracy_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    to_probs_func: Callable[[torch.Tensor], torch.Tensor] = None,
):
    """
    Arguments:
        threshold: The forward pass returns True only when the
            computed accuracy is strictly greater than this value.
        accuracy_func: Called as ```accuracy_func(to_probs_func(logits), labels)```;
            it must return the accuracy as a tensor, since ```.item()``` is
            called on the result. If ```None```, then classification accuracy is used.
        to_probs_func: Applied to the logits before they are handed to
            ```accuracy_func```. If ```None```, a ```torch.nn.Sigmoid``` instance is used.
    """
    super().__init__()
    self.threshold = threshold
    # Resolve the optional callables to their defaults before storing them.
    chosen_accuracy = c_f.default(accuracy_func, accuracy)
    chosen_to_probs = c_f.default(to_probs_func, torch.nn.Sigmoid())
    self.accuracy_func = chosen_accuracy
    self.to_probs_func = chosen_to_probs
    # Mark these attribute names as recordable via the
    # pytorch-metric-learning helper (presumably for logging; confirm).
    recordable = ["accuracy", "threshold"]
    pml_cf.add_to_recordable_attributes(self, list_of_names=recordable)

forward(x, labels)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `torch.Tensor` | Logits to compute accuracy for. | *required* |
| `labels` | `torch.Tensor` | The corresponding labels. | *required* |

Returns:

| Type | Description |
|---|---|
| `bool` | `True` if the accuracy is greater than `self.threshold`. |

Source code in `pytorch_adapt/layers/sufficient_accuracy.py` (lines 48–60):
def forward(self, x: torch.Tensor, labels: torch.Tensor) -> bool:
    """
    Arguments:
        x: raw logits for the batch
        labels: ground-truth labels that go with ```x```
    Returns:
        ```True``` when the computed accuracy is strictly greater than
        ```self.threshold```. The accuracy value itself is kept on
        ```self.accuracy``` as a Python scalar.
    """
    with torch.no_grad():
        probs = self.to_probs_func(x)
        int_labels = labels.to(torch.int)
        score = self.accuracy_func(probs, int_labels)
        self.accuracy = score.item()
    return self.accuracy > self.threshold