-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
142 lines (113 loc) · 4.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import numpy as np
from tensorflow import keras
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from PIL import Image
from io import BytesIO
import faiss
import json
import copy
import time
import aiohttp
from aiohttp.client import ClientSession
import asyncio
import nest_asyncio
# Allow nested event loops: Sorter.faiss_index() drives its own loop with
# run_until_complete() while being called from the async /sort handler, i.e.
# while the server's loop is already running. nest_asyncio patches asyncio so
# that this re-entrant call is permitted instead of raising RuntimeError.
nest_asyncio.apply()
class Sorter:
    """Keras-Faiss section: sort incoming images by their MobileNetV3Large
    similarity to a target image.

    Every image is turned into a feature vector by MobileNetV3Large (ImageNet
    weights, truncated just before the final softmax layer). Candidates are
    ranked by L2 distance to the target's vector using an exact Faiss
    ``IndexFlatL2``.

    The instance is stateful: ``faiss_index`` stores the candidates and the
    index on ``self``, and ``faiss_search`` consumes and then clears them. The
    two calls for one request must therefore not be interleaved with another
    request's calls on the same instance.
    """

    # Width of the feature vector produced by the truncated model (the layer
    # right before the 1000-way ImageNet softmax).
    FEATURE_DIM = 1000
    # Spatial size every image is resized to before embedding.
    INPUT_SIZE = (224, 224)

    def __init__(self) -> None:
        self.model = None
        self.index = None  # faiss.IndexFlatL2 built by faiss_index()
        self.candidates = None  # deep copy of the candidates being ranked
        self.length = None  # number of candidates passed to faiss_index()
        self.indexed_ids = []  # faiss row number -> position in self.candidates
        self.failed_ids = []  # positions whose download/embedding failed
        # Private loop used to run the concurrent downloads in faiss_index().
        self.loop = asyncio.new_event_loop()

    def init_model(self) -> None:
        """Load MobileNetV3Large and keep it up to (not including) the softmax."""
        base = keras.applications.MobileNetV3Large(
            include_top=True,
            weights="imagenet",
        )
        # Drop the last (softmax) layer: the output of the layer before it is
        # used as the image's feature vector.
        model = keras.Model(inputs=base.input, outputs=base.layers[-2].output)
        # Inference only: freeze every layer.
        for layer in model.layers:
            layer.trainable = False
        # summary() prints itself and returns None; wrapping it in print()
        # only emitted an extra "None" line.
        model.summary()
        self.model = model

    @classmethod
    def _preprocess(cls, raw: bytes) -> np.ndarray:
        """Decode image bytes into a one-image batch ready for MobileNetV3."""
        # convert("RGB") so grayscale, palette and RGBA files all produce the
        # 3-channel input the ImageNet model was built for.
        img = Image.open(BytesIO(raw)).convert("RGB").resize(cls.INPUT_SIZE)
        x = keras.preprocessing.image.img_to_array(img)
        x = np.expand_dims(x, axis=0)
        # The same preprocessing is used for candidates and target, so their
        # vectors are comparable.
        return keras.applications.mobilenet_v3.preprocess_input(x)

    def _embed(self, raw: bytes) -> np.ndarray:
        """Return the flat float32 feature vector of an encoded image."""
        features = self.model.predict(self._preprocess(raw))
        return np.asarray(features, dtype=np.float32).reshape(-1)

    def faiss_index(self, candidates) -> None:
        """Download and embed every candidate image, then build the Faiss index.

        Args:
            candidates: indexable sequence of dicts, each with an ``"image"``
                URL that is fetched with an HTTP GET. A deep copy is kept so
                that ``faiss_search`` can return the candidate objects.

        Downloads run concurrently on ``self.loop``. A candidate whose download
        or embedding fails is reported and left out of the index, rather than
        being ranked on a meaningless vector. ``faiss_search`` appends such
        candidates after the ranked ones.
        """
        self.candidates = copy.deepcopy(candidates)
        self.length = len(candidates)
        vectors = np.zeros((self.length, self.FEATURE_DIM), dtype=np.float32)

        async def vectorize_remote_image(i: int, session: ClientSession) -> None:
            async with session.get(url=candidates[i]["image"]) as response:
                # Surface HTTP errors directly instead of as an image-decode error.
                response.raise_for_status()
                raw = await response.read()
            vectors[i] = self._embed(raw)

        async def batch_requests() -> list:
            async with aiohttp.ClientSession() as session:
                tasks = [vectorize_remote_image(i, session) for i in range(self.length)]
                # The gather must be awaited while the session is still open.
                return await asyncio.gather(*tasks, return_exceptions=True)

        outcomes = self.loop.run_until_complete(batch_requests())

        self.indexed_ids = []
        self.failed_ids = []
        for i, outcome in enumerate(outcomes):
            if isinstance(outcome, BaseException):
                print(f"Failed to vectorize candidate {i}: {outcome!r}")
                self.failed_ids.append(i)
            else:
                self.indexed_ids.append(i)

        index = faiss.IndexFlatL2(self.FEATURE_DIM)
        if self.indexed_ids:
            # Faiss expects a C-contiguous float32 matrix.
            index.add(np.ascontiguousarray(vectors[self.indexed_ids]))
        self.index = index

    def faiss_search(self, target: bytes) -> list:
        """Rank the indexed candidates by similarity to *target*.

        Args:
            target: raw bytes of the target image.

        Returns:
            The candidates given to the preceding ``faiss_index`` call. The
            nearest (smallest L2 distance) comes first. Candidates that could
            not be embedded follow at the end, in their original order.

        The per-request state (index, candidates) is always cleared, even if
        embedding the target raises.
        """
        try:
            query = self._embed(target).reshape(1, -1)
            ranked = []
            if self.index.ntotal > 0:
                _, neighbours = self.index.search(query, self.index.ntotal)
                ranked = [
                    self.candidates[self.indexed_ids[int(row)]] for row in neighbours[0]
                ]
            ranked.extend(self.candidates[i] for i in self.failed_ids)
            return ranked
        finally:
            # Drop the faiss index and candidates so the next request starts clean.
            self.index = None
            self.candidates = None
            self.length = None
            self.indexed_ids = []
            self.failed_ids = []
# One shared Sorter for the whole process. Its model is loaded once, here at
# import time, and reused by every /sort request.
sorter = Sorter()
sorter.init_model()

# --- FastAPI section: provide the HTTP service -------------------------------
app = FastAPI()
@app.post("/sort")
async def sort(
    candidates: str = Form(),
    target: UploadFile = File()
):
    """Rank candidate images by similarity to an uploaded target image.

    Form fields:
        candidates: JSON text of an object whose ``"candidates"`` key holds the
            list passed to ``Sorter.faiss_index`` (items carrying an ``"image"``
            URL).
        target: the image file to compare against.

    Returns:
        ``{"result": [...]}`` with the candidates reordered by similarity.
        Zero or one candidate is returned unchanged without touching the model.

    Raises:
        HTTPException: 400 if ``candidates`` is not valid JSON or lacks a
            ``"candidates"`` list; 500 for any failure while ranking.
    """
    start_time = time.perf_counter()
    try:
        items = json.loads(candidates)["candidates"]
        count = len(items)
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        # Malformed client input: report it as a client error, not a server one.
        print(f"Error: \n{e}")
        raise HTTPException(status_code=400, detail=f"Error: \n{e}") from e

    if count <= 1:
        return {
            "result": items
        }

    try:
        # Read the upload BEFORE touching the shared `sorter`. With no `await`
        # between faiss_index() and faiss_search(), this coroutine cannot yield
        # to another request that would overwrite the sorter's index/candidates.
        target_content = await target.read()
        sorter.faiss_index(items)
        result = sorter.faiss_search(target_content)
    except Exception as e:
        print(f"Error: \n{e}")
        raise HTTPException(status_code=500, detail=f"Error: \n{e}") from e

    print(f"time cost: {time.perf_counter() - start_time}s")
    return {
        "result": result,
    }