
Luoyadan/SF-PGL


SF-PGL

This work is the official PyTorch implementation of our papers:

Source-Free Progressive Graph Learning for Open-Set Domain Adaptation
Yadan Luo, Zijian Wang, Zhuoxiao Chen, Zi Huang, Mahsa Baktashmotlagh
Transactions on Pattern Analysis and Machine Intelligence (TPAMI)

Progressive Graph Learning for Open-Set Domain Adaptation
Yadan Luo^, Zijian Wang^, Zi Huang, Mahsa Baktashmotlagh
International Conference on Machine Learning (ICML) 2020
[Paper] [Code]


Framework

To handle the more realistic and more challenging source-free setting, we propose the SF-PGL framework. It uses a balanced pseudo-labeling regime to enable uncertainty-aware progressive learning without relying on distribution matching or adversarial learning. As an extension of PGL, SF-PGL broadens open-set domain adaptation (OSDA) from the unsupervised setting to the source-free and semi-supervised settings, and from image classification to action recognition, where data interactions are more complex and the domain gap is larger. We also study a previously unexplored aspect of OSDA models: calibration. Experimental results show that SF-PGL alleviates the class imbalance introduced by the pseudo-labeled sets, which helps the OSDA model avoid both overconfidence and underconfidence.
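Calibration is usually quantified with the expected calibration error (ECE): predictions are grouped into confidence bins, and ECE is the sample-weighted average gap between each bin's accuracy and its mean confidence. A bin whose mean confidence exceeds its accuracy is overconfident; the reverse is underconfident. This README does not state which calibration metric the paper reports, so the following is a generic equal-width-bin ECE in plain Python, not code from this repository; the name `expected_calibration_error` is illustrative.

```python
from typing import Sequence


def expected_calibration_error(confidences: Sequence[float],
                               predictions: Sequence[int],
                               labels: Sequence[int],
                               n_bins: int = 10) -> float:
    """Generic ECE with equal-width bins (illustrative, not from SF-PGL)."""
    if not (len(confidences) == len(predictions) == len(labels)):
        raise ValueError("inputs must have the same length")
    n = len(confidences)
    if n == 0:
        return 0.0
    conf_sum = [0.0] * n_bins
    correct_sum = [0.0] * n_bins
    count = [0] * n_bins
    for c, p, y in zip(confidences, predictions, labels):
        # Left-closed bins [k/n_bins, (k+1)/n_bins); confidence 1.0 goes to the last bin.
        k = min(int(c * n_bins), n_bins - 1)
        conf_sum[k] += c
        correct_sum[k] += 1.0 if p == y else 0.0
        count[k] += 1
    ece = 0.0
    for k in range(n_bins):
        if count[k] == 0:
            continue  # empty bins contribute nothing
        accuracy = correct_sum[k] / count[k]
        mean_conf = conf_sum[k] / count[k]
        # Weight each bin's |accuracy - confidence| gap by its share of samples.
        ece += (count[k] / n) * abs(accuracy - mean_conf)
    return ece
```

For instance, with confidences [0.9, 0.9, 0.6, 0.6], predictions [0, 1, 2, 3] and labels [0, 1, 2, 0], the 0.9 bin is underconfident (accuracy 1.0) and the 0.6 bin is overconfident (accuracy 0.5); each contributes 0.5 × 0.1, giving an ECE of about 0.1.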


Contents

Requirements

  • Python 3.6
  • PyTorch 1.3

Dataset Preparation

  • Office-Home
  • VisDA-17
  • Syn2Real-O (VisDA-18)

Training

The general command for training is:

python3 train.py

Change arguments for different experiments:

  • dataset: "home" / "visda" / "visda18"
  • batch_size: mini-batch size
  • beta: ratio of known to unknown target samples in the pseudo-labeled set
  • EF: enlarging factor α
  • num_layers: depth (number of layers) of the GNN
  • adv_coeff: adversarial loss coefficient γ
  • node_loss: node classification loss coefficient μ

For the detailed hyper-parameter settings for each dataset, please refer to Section 5.2 and Appendix 3.
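How `EF` (α) and `beta` might interact during progressive pseudo-labeling can be illustrated with a small, hypothetical sketch. It assumes, rather than reproduces from this repository, that the pseudo-labeled set grows by α·N target samples per step (N = number of target samples), that `beta` is the known-to-unknown ratio within that set, and that samples are chosen by ranking confidence separately for known-class and unknown predictions. The functions `selection_budget` and `select_balanced` and the `unknown_class` index are illustrative names only; `train.py` and the papers define the actual selection criterion.

```python
from typing import List, Sequence, Tuple


def selection_budget(num_target: int, step: int, enlarging_factor: float,
                     beta: float) -> Tuple[int, int]:
    """Hypothetical (num_known, num_unknown) budget for a progressive step.

    Assumption: the pseudo-labeled set grows linearly by
    enlarging_factor * num_target samples per step, and beta is the
    ratio #known / #unknown inside that set.
    """
    if enlarging_factor <= 0 or beta < 0:
        raise ValueError("enlarging_factor must be > 0 and beta >= 0")
    # Size of the pseudo-labeled set at this step, capped at the target set size.
    total = min(num_target, round(step * enlarging_factor * num_target))
    # Split the budget so that num_known / num_unknown is approximately beta.
    num_unknown = round(total / (1.0 + beta))
    return total - num_unknown, num_unknown


def select_balanced(confidences: Sequence[float], predictions: Sequence[int],
                    unknown_class: int, num_known: int,
                    num_unknown: int) -> List[int]:
    """Return sorted indices of target samples to pseudo-label.

    Samples predicted as a known class and samples predicted as unknown are
    ranked separately by confidence, so each group fills its own quota.
    """
    known = [i for i, p in enumerate(predictions) if p != unknown_class]
    unknown = [i for i, p in enumerate(predictions) if p == unknown_class]
    # Most confident predictions first within each group.
    known.sort(key=lambda i: confidences[i], reverse=True)
    unknown.sort(key=lambda i: confidences[i], reverse=True)
    return sorted(known[:num_known] + unknown[:num_unknown])
```

For example, `selection_budget(10, 2, 0.2, 3.0)` returns `(3, 1)`: step 2 admits 4 of 10 target samples, 3 predicted as known and 1 predicted as unknown. Ranking each group separately keeps one group from crowding out the other, which is the kind of pseudo-label imbalance the balanced regime is meant to avoid.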

Remember to set dataset_root to the location of your datasets.

The training loss and validation accuracy are automatically saved to './logs/' and can be visualized with TensorBoard. The model weights are saved to './checkpoints'.
