- Linux
- Python 3.9
- PyTorch 1.10.0+cu111
Clone this repo.
git clone https://github.com/THUDM/tot-prediction.git
cd tot-prediction
Install the dependencies:
pip install -r requirements.txt
Download the dataset from BaiduPan (password: f62u) or Aliyun, and put the data folder in the project directory.
python process.py
python citation_only.py # Use citation number only for prediction
python regressor.py # Random Forest (RF) and GBRT
python pagerank.py # PageRank
python gnn.py # GraphSAGE
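pagerank.py ranks papers with PageRank. Below is a minimal power-iteration sketch in plain Python. It assumes edges point from the citing paper to the cited one and that papers with no outgoing citations (dangling nodes) spread their score uniformly. The function name, the damping factor of 0.85 and the stopping rule are illustrative choices, not taken from the repo's script.

```python
def pagerank(edges, num_nodes, damping=0.85, tol=1e-12, max_iter=200):
    """Power-iteration PageRank on a directed citation graph.

    edges: list of (citing, cited) pairs with node ids in [0, num_nodes).
    Returns a list of scores that sums to 1.
    """
    out_deg = [0] * num_nodes
    for src, _ in edges:
        out_deg[src] += 1

    scores = [1.0 / num_nodes] * num_nodes
    for _ in range(max_iter):
        # Score held by dangling nodes (no outgoing citations) is spread uniformly.
        dangling = sum(s for s, d in zip(scores, out_deg) if d == 0)
        base = (1 - damping) / num_nodes + damping * dangling / num_nodes
        new = [base] * num_nodes
        # Each citation passes a share of the citing paper's score to the cited paper.
        for src, dst in edges:
            new[dst] += damping * scores[src] / out_deg[src]
        delta = sum(abs(a - b) for a, b in zip(new, scores))
        scores = new
        if delta < tol:
            break
    return scores


# Toy graph: papers 0 and 1 both cite paper 2, and paper 2 cites paper 0.
scores = pagerank([(0, 2), (1, 2), (2, 0)], num_nodes=3)
print([round(s, 4) for s in scores])
```

In a ranking setting, each paper's score could then serve as its predicted importance; how pagerank.py turns scores into predictions is not described here.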
Evaluation metric: average MAP.
| Method | MAP |
|---|---|
| Citation | 0.6413 |
| RF | 0.5409 |
| GBRT | 0.5725 |
| PageRank | 0.6504 |
| GraphSAGE | 0.0811 |
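Average precision (AP) rewards rankings that place the true items near the top, and MAP averages AP over queries. The sketch below assumes that each query is a list of candidate papers with predicted scores and binary ground-truth labels. This README does not describe how the evaluation groups candidates into queries, breaks ties, or averages results, so treat the sketch as an illustration rather than the repo's evaluation code. The usage lines rank candidates by citation count, mirroring the idea of the citation-only baseline. The numbers are made up.

```python
from typing import List, Sequence, Tuple


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AP of one ranked list: sort by score (descending) and average
    precision@k over the ranks k where a relevant item appears."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits = 0
    precisions = []
    for rank, i in enumerate(order, start=1):
        if labels[i]:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


def mean_average_precision(queries: List[Tuple[Sequence[float], Sequence[int]]]) -> float:
    """Mean of per-query AP; each query is a (scores, labels) pair."""
    aps = [average_precision(s, l) for s, l in queries]
    return sum(aps) / len(aps)


# Hypothetical queries scored by citation count; label 1 marks the true paper.
queries = [
    ([120, 5, 40], [0, 1, 0]),  # true paper ranked 3rd -> AP = 1/3
    ([3, 50], [0, 1]),          # true paper ranked 1st -> AP = 1.0
]
print(round(mean_average_precision(queries), 4))  # → 0.6667
```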
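To run the GENI and RGTN models, switch to the RGTN-NIE directory, which has the following requirements:
cd RGTN-NIE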
- Python 3.10
- PyTorch 2.1
- dgl 2.1.0+cu118
Modify `save-path` in `train_geni.sh` and `train_two.sh` to set where the trained model is saved.
- Run `sh train_geni.sh` to train GENI on tot (full-batch training).
- Run `sh train_two.sh` to train RGTN on tot (full-batch training).
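Full-batch training processes the whole graph in each step rather than sampling mini-batches of nodes. The NumPy sketch below illustrates that pattern with a single linear mean-aggregation layer in the style of GraphSAGE, the model behind gnn.py. It is not GENI or RGTN, which use their own attention-based aggregation. The function names and the squared-error objective are assumptions for illustration.

```python
import numpy as np


def mean_aggregate(adj: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Concatenate each node's features with the mean of its neighbors'
    features (GraphSAGE-mean style), computed for all nodes at once."""
    deg = adj.sum(axis=1, keepdims=True)
    neigh = adj @ x / np.maximum(deg, 1)  # isolated nodes get a zero neighbor mean
    return np.concatenate([x, neigh], axis=1)


def train_full_batch(adj: np.ndarray, x: np.ndarray, y: np.ndarray,
                     train_mask: np.ndarray, lr: float = 0.1,
                     epochs: int = 500) -> np.ndarray:
    """Full-batch training: every gradient step uses all training nodes and
    features aggregated over the whole graph. With a learnable aggregation
    layer, the propagation would be recomputed every epoch; here it is fixed,
    so it is computed once."""
    h = mean_aggregate(adj, x)
    w = np.zeros(h.shape[1])
    ht, yt = h[train_mask], y[train_mask]
    for _ in range(epochs):
        # Gradient of the mean squared error over all training nodes.
        grad = 2 * ht.T @ (ht @ w - yt) / len(yt)
        w -= lr * grad
    return h @ w  # predicted importance score for every node
```

Because every step touches the entire graph, memory grows with graph size. Mini-batch training with neighbor sampling is the usual alternative when a graph does not fit in memory.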
Modify `model_path` in `inference.sh` and `inference_two.sh` to load the trained model.
Modify `output_dir` in `inference.sh` and `inference_two.sh` to set where the prediction results are saved.
- Run `sh inference.sh` for GENI on tot (full-batch inference).
- Run `sh inference_two.sh` for RGTN on tot (full-batch inference).
python pagerank_nie.py