-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
61 lines (53 loc) · 1.48 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import numpy as np
from datasets import load_metric
from transformers import ViTForImageClassification, Trainer, ViTImageProcessor
from args import training_args
# Accuracy metric consumed by compute_metrics during training/evaluation.
# NOTE(review): datasets.load_metric is deprecated in recent `datasets`
# releases (metrics moved to the `evaluate` package) — confirm the pinned
# version still provides it before upgrading.
metric = load_metric("accuracy")
# Default checkpoint to fine-tune from.
def_model_path = "google/vit-large-patch16-224"
# Bug fix: the original called ViTImageProcessor(def_model_path), which
# passes the checkpoint name as the constructor's first positional config
# argument (do_resize) instead of loading the checkpoint's preprocessing
# configuration. from_pretrained() is the correct loader.
feature_extractor = ViTImageProcessor.from_pretrained(def_model_path)
# Binary classification task.
num_labels = 2
# for preprocess
def transform(example_batch):
    """
    On-the-fly preprocessing for a batch of dataset examples.

    Runs the module-level ViT image processor over the batch's images and
    carries the 'id' column through; the 'label' column is renamed to
    'labels', the key the collator/Trainer expects.
    """
    images = list(example_batch["image"])
    inputs = feature_extractor(images, return_tensors="pt")
    inputs["id"] = example_batch["id"]
    inputs["labels"] = example_batch["label"]
    return inputs
# for train
def collate_fn(batch):
    """
    Collate per-example dicts into batched tensors for the Trainer.

    Each element of *batch* is expected to carry a 'pixel_values' tensor
    and an integer 'labels' entry (as produced by transform()).
    """
    pixel_values = [example['pixel_values'] for example in batch]
    labels = [example['labels'] for example in batch]
    return {
        'pixel_values': torch.stack(pixel_values),
        'labels': torch.tensor(labels),
    }
def compute_metrics(p):
    """
    Accuracy metric hook called by the Trainer during training/eval.

    p.predictions holds per-class scores; argmax over axis 1 yields the
    predicted label ids, which are compared against p.label_ids.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_ids, references=p.label_ids)
def load_model(dataset, model_path=def_model_path):
    """
    Build a Trainer wired for fine-tuning a ViT image classifier.

    Parameters
    ----------
    dataset : mapping with 'train' and 'test' splits (e.g. a DatasetDict).
    model_path : checkpoint to load; defaults to def_model_path.

    Returns
    -------
    transformers.Trainer ready for .train() / .evaluate().
    """
    classifier = ViTForImageClassification.from_pretrained(
        model_path,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
    )
    return Trainer(
        model=classifier,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        tokenizer=feature_extractor,
    )