atdoc
ATDOCHook
Bases: BaseHook
Creates pseudo labels for the target domain using k-nearest neighbors, then computes a classification loss based on these pseudo labels.
Implementation of Domain Adaptation with Auxiliary Target Domain-Oriented Classifier.
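The two steps described above (k-nearest-neighbor pseudo labeling, then a classification loss on those labels) can be sketched as follows. This is a minimal, dependency-free illustration under stated assumptions, not the pytorch_adapt implementation. It assumes a memory bank holding one stored feature vector and one class-probability vector per target sample. Each query is compared to the stored features by cosine similarity, the stored probabilities of its k most similar samples are averaged, and the argmax becomes the pseudo label. Weighting each sample's cross-entropy by the averaged probability of its chosen class is also an assumption. The function names `normalize`, `knn_pseudo_labels` and `weighted_cross_entropy` are hypothetical.

```python
import math
from typing import List, Optional, Sequence, Tuple


def normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit L2 norm (a zero vector is returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def knn_pseudo_labels(
    query_feats: Sequence[Sequence[float]],
    bank_feats: Sequence[Sequence[float]],
    bank_probs: Sequence[Sequence[float]],
    k: int = 5,
    exclude_idx: Optional[Sequence[int]] = None,
) -> Tuple[List[int], List[float]]:
    """Return (pseudo_labels, confidences), one entry per query feature.

    Assumption: bank_feats[j] and bank_probs[j] are the stored feature and
    class probabilities of target sample j (a hypothetical memory bank).
    """
    bank = [normalize(f) for f in bank_feats]
    num_classes = len(bank_probs[0])
    labels: List[int] = []
    weights: List[float] = []
    for i, query in enumerate(query_feats):
        q = normalize(query)
        # Cosine similarity between the query and every stored feature.
        sims = [(sum(a * b for a, b in zip(q, f)), j) for j, f in enumerate(bank)]
        # Optionally drop the query's own memory slot so it is not its own neighbor.
        if exclude_idx is not None:
            sims = [(s, j) for s, j in sims if j != exclude_idx[i]]
        # Indices of the k most similar stored samples.
        neighbors = [j for _, j in sorted(sims, reverse=True)[:k]]
        # Average the neighbors' stored class probabilities.
        avg = [
            sum(bank_probs[j][c] for j in neighbors) / len(neighbors)
            for c in range(num_classes)
        ]
        label = avg.index(max(avg))
        labels.append(label)
        # Assumed confidence: the averaged probability of the chosen class.
        weights.append(avg[label])
    return labels, weights


def weighted_cross_entropy(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    weights: Sequence[float],
) -> float:
    """Batch mean of per-sample cross-entropy, each term scaled by its weight."""
    total = 0.0
    for row, y, w in zip(logits, labels, weights):
        # Numerically stable log-sum-exp; cross-entropy is log_sum_exp - row[y].
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += w * (log_sum_exp - row[y])
    return total / len(labels)
```

For example, with a bank containing two well-separated clusters, a query near the first cluster receives pseudo label 0 and a query near the second receives pseudo label 1. In training, such a bank would presumably be refreshed with each batch's features and predictions so that the pseudo labels track the model.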
Source code in pytorch_adapt\hooks\atdoc.py
__init__(dataset_size, feature_dim, num_classes, k=5, loss_fn=None, **kwargs)
Parameters:
Name | Description | Default |
---|---|---|
dataset_size | The number of samples in the target dataset. | required |
feature_dim | The feature dimensionality, i.e. at each iteration the features should have size `(N, feature_dim)`, where N is the batch size. | required |
num_classes | The number of class labels in the target dataset. | required |
k | The number of nearest neighbors used to determine each sample's pseudo label. | 5 |
loss_fn | The classification loss function. If `None`, … | None |
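One way to read these parameters is as the dimensions of a per-sample memory: `dataset_size` rows, each holding a `feature_dim`-length feature and a `num_classes`-length probability vector, from which `k` neighbors are drawn. The hypothetical `FeatureMemory` class below is a sketch under that assumption and is not part of pytorch_adapt. It checks that every batch of features has shape `(N, feature_dim)` and overwrites the rows of the samples seen in that batch. The requirement that `k` lie between 1 and `dataset_size` is also an assumption of this sketch.

```python
from typing import List, Sequence


class FeatureMemory:
    """Hypothetical per-sample store sized by dataset_size, feature_dim and num_classes."""

    def __init__(self, dataset_size: int, feature_dim: int, num_classes: int, k: int = 5):
        # Assumed constraint: at least one neighbor, and no more neighbors than samples.
        if not 1 <= k <= dataset_size:
            raise ValueError("k must be between 1 and dataset_size")
        self.feature_dim = feature_dim
        self.num_classes = num_classes
        self.k = k
        # One feature row and one class-probability row per target sample;
        # probabilities start uniform because nothing has been observed yet.
        self.feats: List[List[float]] = [[0.0] * feature_dim for _ in range(dataset_size)]
        self.probs: List[List[float]] = [
            [1.0 / num_classes] * num_classes for _ in range(dataset_size)
        ]

    def update(
        self,
        idx: Sequence[int],
        features: Sequence[Sequence[float]],
        probs: Sequence[Sequence[float]],
    ) -> None:
        """Overwrite the rows at dataset indices idx with the current batch."""
        if not (len(idx) == len(features) == len(probs)):
            raise ValueError("idx, features and probs must share the batch size N")
        # Validate the whole batch first so a bad row leaves the memory untouched.
        for f, p in zip(features, probs):
            if len(f) != self.feature_dim:
                raise ValueError(f"expected feature length {self.feature_dim}, got {len(f)}")
            if len(p) != self.num_classes:
                raise ValueError(f"expected {self.num_classes} class probabilities, got {len(p)}")
        for i, f, p in zip(idx, features, probs):
            self.feats[i] = list(f)
            self.probs[i] = list(p)
```

Validating the batch before writing keeps the memory consistent: a batch whose features are not `(N, feature_dim)` is rejected as a whole rather than partially applied.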