Skip to content

inference

adabn_fn(x, domain, models, **kwargs)

AdaBN features and logits.

Parameters:

Name Type Description Default
x

The input to the model

required
domain

0 for source domain, 1 for target domain.

required
Source code in pytorch_adapt\inference\inference.py
18
19
20
21
22
23
24
25
26
27
28
def adabn_fn(x, domain, models, **kwargs) -> Dict[str, torch.Tensor]:
    """
    Computes [AdaBN][pytorch_adapt.adapters.AdaBN] features and logits.

    Both ```G``` and ```C``` receive the checked domain as their
    second positional argument. Extra keyword arguments are accepted
    and ignored.

    Arguments:
        x: The input to the model
        domain: 0 for source domain, 1 for target domain.
        models: Dictionary of models with keys ```["G", "C"]```.
    Returns:
        A dictionary with the keys ```{"features", "logits"}```.
    """
    # The domain is normalized once and the same value is handed to both models.
    checked_domain = check_domain(domain, keep_len=True)
    output = {"features": models["G"](x, checked_domain)}
    output["logits"] = models["C"](output["features"], checked_domain)
    return output

adda_fn(x, domain, models, get_all=False, **kwargs)

ADDA features and logits.

Parameters:

Name Type Description Default
x

The input to the model

required
domain

If 0, then features = G(x). Otherwise features = T(x).

required
models

Dictionary of models with keys ["G", "C", "T"].

required
get_all bool

If True, then return features and logits using both G and T as the feature extractor.

False

Returns:

Type Description
Dict[str, torch.Tensor]

A dictionary of features and logits.

  • If get_all is False, then the keys are {"features", "logits"}.

  • If get_all is True, then the keys will be {"features", "logits", "other_features", "other_logits"}, where the other_ prefix represents the features and logits obtained using G if domain == 1 and T if domain == 0.

Source code in pytorch_adapt\inference\inference.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def adda_fn(
    x, domain, models, get_all: bool = False, **kwargs
) -> Dict[str, torch.Tensor]:
    """
    Computes [ADDA][pytorch_adapt.adapters.ADDA] features and logits.

    The feature extractor is chosen by ```domain```; the classifier
    ```C``` is always applied to the resulting features. Extra keyword
    arguments are accepted and ignored.

    Arguments:
        x: The input to the model
        domain: If 0, then ```features = G(x)```.
            Otherwise ```features = T(x)```.
        models: Dictionary of models with keys
            ```["G", "C", "T"]```.
        get_all: If ```True```, then also return the features
            and logits obtained with the feature extractor
            that was not selected by ```domain```.
    Returns:
        A dictionary of features and logits.

            - If ```get_all``` is ```False```, then the keys are ```{"features", "logits"}```.

            - If ```get_all``` is ```True```,
            then the keys will be ```{"features", "logits", "other_features", "other_logits"}```,
            where the ```other_``` prefix represents the features and logits obtained
            using ```G``` if ```domain == 1``` and ```T``` if ```domain == 0```.
    """
    domain = check_domain(domain)
    # Decide up front which extractor is the "main" one and which is the "other".
    if domain == 0:
        primary, secondary = "G", "T"
    else:
        primary, secondary = "T", "G"

    main_features = models[primary](x)
    result = {"features": main_features, "logits": models["C"](main_features)}
    if not get_all:
        return result

    # Same classifier, but fed by the extractor that was not selected above.
    alt_features = models[secondary](x)
    alt_logits = models["C"](alt_features)
    result["other_features"] = alt_features
    result["other_logits"] = alt_logits
    return result

adda_full_fn(x, **kwargs)

ADDA features, logits, discriminator logits, other features, other logits, other discriminator logits. See adda_fn for the input arguments.

Returns:

Type Description
Dict[str, torch.Tensor]

discriminator logits ("d_logits"), "other" discriminator logits ("other_d_logits") in addition to everything returned by adda_fn with get_all = True.

Source code in pytorch_adapt\inference\inference.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def adda_full_fn(x, **kwargs) -> Dict[str, torch.Tensor]:
    """
    [ADDA][pytorch_adapt.adapters.ADDA] features, logits, and discriminator logits
    for both feature extractors.
    See [adda_fn][pytorch_adapt.inference.adda_fn] for the input arguments.

    The optional ```layer``` keyword argument (default ```"features"```)
    selects which ```"other_"``` output is passed to the discriminator
    for the second set of discriminator logits.

    Returns:
        discriminator logits (```"d_logits"```), "other" discriminator logits (```"other_d_logits"```)
            in addition to everything returned by [adda_fn][pytorch_adapt.inference.adda_fn]
            with ```get_all = True```.
    """
    result = with_d(x=x, fn=adda_fn, get_all=True, **kwargs)
    layer = kwargs.get("layer", "features")
    # Run the discriminator a second time, on the "other" extractor's output.
    other_d_output = d_fn(x=result[f"other_{layer}"], **kwargs)
    result["other_d_logits"] = other_d_output["d_logits"]
    return result

adda_with_d(**kwargs)

ADDA features, logits, and discriminator logits. See adda_fn for the input arguments.

Returns:

Type Description
Dict[str, torch.Tensor]

discriminator logits as "d_logits", in addition to everything returned by adda_fn.

Source code in pytorch_adapt\inference\inference.py
68
69
70
71
72
73
74
75
76
def adda_with_d(**kwargs) -> Dict[str, torch.Tensor]:
    """
    Runs [adda_fn][pytorch_adapt.inference.adda_fn] through ```with_d``` to get
    [ADDA][pytorch_adapt.adapters.ADDA] features, logits, and discriminator logits.
    All keyword arguments are forwarded unchanged; see
    [adda_fn][pytorch_adapt.inference.adda_fn] for the input arguments.
    Returns:
        discriminator logits as ```"d_logits"```, in addition to
            everything returned by [adda_fn][pytorch_adapt.inference.adda_fn].
    """
    output = with_d(fn=adda_fn, **kwargs)
    return output

default_fn(x, models, **kwargs)

The default inference function for BaseAdapter.

Source code in pytorch_adapt\inference\inference.py
 8
 9
10
11
12
13
14
def default_fn(x, models, **kwargs) -> Dict[str, torch.Tensor]:
    """
    The default inference function for [BaseAdapter][pytorch_adapt.adapters.BaseAdapter].

    Computes ```features = G(x)``` and ```logits = C(features)```.
    Extra keyword arguments are accepted and ignored.

    Arguments:
        x: The input to the model
        models: Dictionary of models with keys ```["G", "C"]```.
    Returns:
        A dictionary with the keys ```{"features", "logits"}```.
    """
    output = {"features": models["G"](x)}
    output["logits"] = models["C"](output["features"])
    return output

mcd_fn(x, models, get_all=False, **kwargs)

Returns:

Type Description

Features and logits, where logits = sum(C(features)).

Source code in pytorch_adapt\inference\inference.py
140
141
142
143
144
145
146
147
148
149
150
151
152
def mcd_fn(x, models, get_all=False, **kwargs):
    """
    Computes features with ```G``` and sums the sequence of logits
    returned by ```C```. Extra keyword arguments are accepted and ignored.

    Arguments:
        x: The input to the model
        models: Dictionary of models with keys ```["G", "C"]```.
        get_all: If ```True```, then each element of ```C(features)```
            is also returned, under the keys ```"logits0"```,
            ```"logits1"```, and so on.
    Returns:
        Features and logits, where ```logits = sum(C(features))```.
    """
    features = models["G"](x)
    per_classifier = models["C"](features)
    output = {"features": features, "logits": sum(per_classifier)}
    if get_all:
        # One extra entry per individual classifier output, keyed by its index.
        output.update(
            {f"logits{idx}": item for idx, item in enumerate(per_classifier)}
        )
    return output

rtn_fn(x, domain, models, get_all=False, **kwargs)

RTN features and logits.

Parameters:

Name Type Description Default
x

The input to the model

required
domain

If 0, logits = residual_model(C(G(x))). Otherwise, logits = C(G(x)).

required
models

Dictionary of models with keys ["G", "C", "residual_model"].

required
get_all

If True, then in addition to the regular outputs, it will return the residual_model logits when domain == 1 and the C logits when domain == 0.

False

Returns:

Type Description
Dict[str, torch.Tensor]

A dictionary of features and logits.

  • If get_all is False, then the keys are {"features", "logits"}.

  • If get_all is True, then the keys will be {"features", "logits", "other_logits"}.

Source code in pytorch_adapt\inference\inference.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def rtn_fn(x, domain, models, get_all=False, **kwargs) -> Dict[str, torch.Tensor]:
    """
    Computes [RTN][pytorch_adapt.adapters.RTN] features and logits.

    ```residual_model``` is only evaluated when its output is needed,
    i.e. when ```domain == 0``` or ```get_all``` is ```True```.

    Arguments:
        x: The input to the model
        domain: If 0, ```logits = residual_model(C(G(x)))```.
            Otherwise, ```logits = C(G(x))```.
        models: Dictionary of models with keys
            ```["G", "C", "residual_model"]```.
        get_all: If ```True```, then in addition to the regular outputs,
            it will return the ```residual_model``` logits when
            ```domain == 1``` and the ```C``` logits when ```domain == 0```.
    Returns:
        A dictionary of features and logits.

            - If ```get_all``` is ```False```, then the keys are ```{"features", "logits"}```.

            - If ```get_all``` is ```True```,
            then the keys will be ```{"features", "logits", "other_logits"}```.
    """
    domain = check_domain(domain)
    output = default_fn(x=x, models=models)
    classifier_logits = output["logits"]
    is_source = domain == 0

    # Target domain without get_all: the plain C(G(x)) output is already final.
    if not (get_all or is_source):
        return output

    residual_logits = models["residual_model"](classifier_logits)
    if is_source:
        output["logits"] = residual_logits
        if get_all:
            output["other_logits"] = classifier_logits
    elif domain == 1:
        # Only reachable with get_all=True, because of the early return above.
        output["other_logits"] = residual_logits
    return output

symnets_fn(x, domain, models, get_all=False, **kwargs)

Parameters:

Name Type Description Default
x

The input to the model

required
domain

0 for the source domain, 1 for the target domain.

required

Returns:

Type Description

Features and logits, where logits = C(features)[domain].

Source code in pytorch_adapt\inference\inference.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def symnets_fn(x, domain, models, get_all=False, **kwargs) -> Dict[str, torch.Tensor]:
    """
    SymNets features and logits.

    ```C``` is evaluated exactly once; its output is then indexed
    to obtain the logits for the requested domain (and, optionally,
    for the other domain). Extra keyword arguments are accepted and ignored.

    Arguments:
        x: The input to the model
        domain: 0 for the source domain, 1 for the target domain.
        models: Dictionary of models with keys ```["G", "C"]```.
        get_all: If ```True```, then also return
            ```C(features)[int(not domain)]``` as ```"other_logits"```.
    Returns:
        Features and logits, where ```logits = C(features)[domain]```.

            - If ```get_all``` is ```False```, then the keys are ```{"features", "logits"}```.

            - If ```get_all``` is ```True```,
            then the keys will be ```{"features", "logits", "other_logits"}```.
    """
    domain = check_domain(domain)
    features = models["G"](x)
    # Run the classifier a single time and reuse the result for both heads.
    # Calling it again for "other_logits" would cost a second forward pass and,
    # if C is stochastic or stateful, would give logits from a different pass.
    all_logits = models["C"](features)
    output = {"features": features, "logits": all_logits[domain]}
    if get_all:
        output["other_logits"] = all_logits[int(not domain)]
    return output