This repository focuses on image classification, a core computer vision task, using PyTorch and transfer learning. Starting from a pretrained network saves time and compute, and often yields better results than building and training a neural network from scratch. The repository provides image classification with the following architectures from the PyTorch ecosystem:
- EfficientNet
- ResNet
- VGG
- GoogLeNet
Before training, update the configuration file:
- Loss Function: Choose between CrossEntropyLoss and NLLLoss. CrossEntropyLoss is recommended for binary and multi-class classification.
- Optimization Function: Options are Adam, RAdam, SGD, Adadelta, Adagrad, AdamW, Adamax, ASGD, NAdam, and Rprop, with Adam being recommended.
- MODEL_NAME: Options are efficientnetB0 to efficientnetB7 for EfficientNet, resnet18 to resnet152 for ResNet, vgg11 to vgg19bn for VGG, and googlenet.
- SAVE_WEIGHT_PATH: Directory to save model weights.
- DATA_DIR: Directory containing the dataset.
- CHECKPOINT: Directory for pretrained models.
- NUMCLASS: Number of classes.
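The options above could be captured and sanity-checked along these lines. This is a minimal sketch, not the repository's actual config file: the `LOSS` and `OPTIMIZER` key names, the example paths and values, the `validate_config` helper, and the spelling of the intermediate model names (for example `vgg13bn`) are all assumptions made for illustration.

```python
from typing import List

# Hypothetical config; LOSS/OPTIMIZER key names and all values are assumptions.
CONFIG = {
    "LOSS": "CrossEntropyLoss",
    "OPTIMIZER": "Adam",
    "MODEL_NAME": "resnet18",
    "SAVE_WEIGHT_PATH": "../weights",   # placeholder path
    "DATA_DIR": "../data",              # placeholder path
    "CHECKPOINT": "../checkpoints",     # placeholder path
    "NUMCLASS": 2,
    "EPOCHS": 10,
    "BATCHSIZE": 32,
    "LEARNING_RATE": 1e-3,
}

LOSSES = {"CrossEntropyLoss", "NLLLoss"}
OPTIMIZERS = {"Adam", "RAdam", "SGD", "Adadelta", "Adagrad",
              "AdamW", "Adamax", "ASGD", "NAdam", "Rprop"}
# Intermediate names are assumed to follow the pattern of the documented endpoints.
MODEL_NAMES = (
    {f"efficientnetB{i}" for i in range(8)}
    | {f"resnet{depth}" for depth in (18, 34, 50, 101, 152)}
    | {f"vgg{depth}{bn}" for depth in (11, 13, 16, 19) for bn in ("", "bn")}
    | {"googlenet"}
)


def validate_config(cfg: dict) -> List[str]:
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    if cfg.get("LOSS") not in LOSSES:
        problems.append(f"unknown loss: {cfg.get('LOSS')!r}")
    if cfg.get("OPTIMIZER") not in OPTIMIZERS:
        problems.append(f"unknown optimizer: {cfg.get('OPTIMIZER')!r}")
    if cfg.get("MODEL_NAME") not in MODEL_NAMES:
        problems.append(f"unknown model: {cfg.get('MODEL_NAME')!r}")
    num_classes = cfg.get("NUMCLASS")
    if not isinstance(num_classes, int) or num_classes < 2:
        problems.append("NUMCLASS must be an integer >= 2")
    return problems


if __name__ == "__main__":
    print(validate_config(CONFIG))  # []
```

Checking the names up front gives a clearer error than a failure deep inside model construction.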
Adjust other hyperparameters like EPOCHS, BATCHSIZE, and LEARNING_RATE as needed. To train:
cd ./src && python3 train.py
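Inside a training script, the configured loss and optimizer names could be turned into PyTorch objects roughly as follows. The helpers `build_loss` and `build_optimizer` are hypothetical and may not match how train.py does it. The demo also shows a point worth keeping in mind when choosing a loss: `NLLLoss` expects log-probabilities, so it matches `CrossEntropyLoss` only when the model output passes through `log_softmax` first.

```python
import torch
from torch import nn

OPTIMIZERS = {"Adam", "RAdam", "SGD", "Adadelta", "Adagrad",
              "AdamW", "Adamax", "ASGD", "NAdam", "Rprop"}


def build_loss(name: str) -> nn.Module:
    """Map a configured loss name to a torch.nn loss instance."""
    losses = {"CrossEntropyLoss": nn.CrossEntropyLoss, "NLLLoss": nn.NLLLoss}
    return losses[name]()


def build_optimizer(name: str, params, lr: float) -> torch.optim.Optimizer:
    """Map a configured optimizer name to a torch.optim instance."""
    if name not in OPTIMIZERS:
        raise ValueError(f"unsupported optimizer: {name}")
    # Each listed name is a class in torch.optim that accepts `lr`.
    return getattr(torch.optim, name)(params, lr=lr)


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(4, 3)            # raw scores for 4 samples, 3 classes
    targets = torch.tensor([0, 2, 1, 2])

    ce = build_loss("CrossEntropyLoss")(logits, targets)
    nll = build_loss("NLLLoss")(torch.log_softmax(logits, dim=1), targets)
    print(torch.allclose(ce, nll))  # True

    layer = nn.Linear(3, 2)
    optimizer = build_optimizer("Adam", layer.parameters(), lr=1e-3)
    print(type(optimizer).__name__)  # Adam
```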
To predict, ensure that the model name, checkpoint, and number of classes in the config file match those used during training, then run:
cd ./src && python predict.py \
--test_path ../test_img \
--batch_predict 16
- --test_path: Path to public test images (file or directory).
- --batch_predict: Batch size for prediction.
Results will be available in predict.csv.
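To make the --test_path and --batch_predict options more concrete, here is a small sketch of how a path that is either a single file or a directory might be expanded into batches of image paths. It is not taken from predict.py: the helper names and the set of accepted file extensions are assumptions.

```python
from pathlib import Path
from typing import Iterator, List

# Assumed set of image extensions; the real script may accept others.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def collect_images(test_path: str) -> List[Path]:
    """Return one path for a file, or all image files in a directory."""
    path = Path(test_path)
    if path.is_file():
        return [path]
    return sorted(
        p for p in path.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


def batched(items: List[Path], batch_size: int) -> Iterator[List[Path]]:
    """Yield consecutive chunks of at most `batch_size` items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


if __name__ == "__main__":
    images = collect_images("../test_img")
    for batch in batched(images, batch_size=16):
        print(len(batch), "images in this batch")
```

A larger batch size usually speeds up prediction at the cost of more memory.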