neighborhood_aggregation
NeighborhoodAggregation
Bases: torch.nn.Module
Implementation of the pseudo labeling step in Domain Adaptation with Auxiliary Target Domain-Oriented Classifier.
Source code in pytorch_adapt\layers\neighborhood_aggregation.py
__init__(dataset_size, feature_dim, num_classes, k=5, T=0.5)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_size | int | The number of samples in the target dataset. | required |
feature_dim | int | The feature dimensionality, i.e. the size of each feature vector passed in at each iteration. | required |
num_classes | int | The number of class labels in the target dataset. | required |
k | int | The number of nearest neighbors used to determine each sample's pseudolabel. | 5 |
T | float | The softmax temperature used when storing predictions in memory. | 0.5 |
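To illustrate what the temperature T controls, here is a minimal plain-Python sketch of a temperature-scaled softmax. It is not taken from the library: the function name `softmax_with_temperature` is hypothetical, and whether the layer applies T in exactly this form is an assumption. The general effect holds either way: a temperature below 1 sharpens the stored prediction toward the most likely class.

```python
import math


def softmax_with_temperature(logits, T=0.5):
    """Return softmax(logits / T) for one row of class logits (illustrative helper)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# T=1.0 is the ordinary softmax; T=0.5 concentrates mass on the top class.
plain = softmax_with_temperature([2.0, 1.0, 0.0], T=1.0)
sharp = softmax_with_temperature([2.0, 1.0, 0.0], T=0.5)
```

With the default T=0.5, the largest probability in `sharp` is higher than in `plain`, so the predictions kept in memory would be more confident than the raw softmax output.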
forward(features, logits=None, update=False, idx=None)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
features | torch.Tensor | The features to compute pseudolabels for. | required |
logits | torch.Tensor | The logits from which predictions will be computed and stored in memory. Required if update is True. | None |
update | bool | If True, the current batch of predictions is added to the memory bank. | False |
idx | torch.Tensor | A tensor containing the dataset indices that produced each row of features. | None |
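The roles of features, idx, update, and k can be sketched in plain Python as a nearest-neighbor lookup over a memory bank. This is a conceptual sketch only, not the library's code: the helper names (`normalize`, `knn_pseudo_label`, `update_memory`), the use of cosine similarity, and the choice to skip a sample's own memory slot are all assumptions made for illustration.

```python
import math


def normalize(v):
    """Scale a vector to unit length (zero vectors are left unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def knn_pseudo_label(feature, feat_memory, pred_memory, k=5, exclude=None):
    """Average the stored predictions of the k most similar memory rows."""
    q = normalize(feature)
    sims = []
    for i, row in enumerate(feat_memory):
        if i == exclude:  # assumed: skip the sample's own memory slot
            continue
        sims.append((sum(a * b for a, b in zip(q, normalize(row))), i))
    top = sorted(sims, reverse=True)[:k]
    num_classes = len(pred_memory[0])
    mean_pred = [
        sum(pred_memory[i][c] for _, i in top) / len(top)
        for c in range(num_classes)
    ]
    label = max(range(num_classes), key=lambda c: mean_pred[c])
    return label, mean_pred


def update_memory(feat_memory, pred_memory, idx, features, preds):
    """Overwrite the memory slots at dataset indices idx with the current batch."""
    for i, f, p in zip(idx, features, preds):
        feat_memory[i] = normalize(f)
        pred_memory[i] = list(p)


# Toy memory bank: 4 target samples, 2-dimensional features, 2 classes.
feat_memory = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
pred_memory = [[0.9, 0.1], [0.8, 0.2], [0.1, 0.9], [0.2, 0.8]]

label, mean_pred = knn_pseudo_label([1.0, 0.05], feat_memory, pred_memory, k=2)

# Analogous to update=True: store the batch's predictions at their dataset indices.
update_memory(feat_memory, pred_memory, [2], [[1.0, 0.0]], [[0.7, 0.3]])
```

In this toy example the query lies closest to memory rows 0 and 1, so its pseudolabel is class 0. The update step shows why idx is needed: it tells the memory bank which dataset samples the current batch rows correspond to.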