-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.py
169 lines (127 loc) · 7.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import math, numpy as np, requests, os, time, warnings, json
from PIL import Image
from importlib import import_module
import torch
class AstroSleuth():
def __init__(self, tile_size=256, tile_pad=16, wrk_dir="models/", model_name="astrosleuthv2", force_cpu=False, on_download=None, off_download=None):
# Device selection
self.device = "cpu" if force_cpu else ("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
# Check if model name is known
model_src:dict = json.load(open("models.json"))["data"]
assert model_name in model_src, f"Model {model_name} not found! Available models: {list(model_src.keys())}"
# Load model module
module_path = model_src[model_name]["src"]["module"]
self.model_module:torch.nn.Module = getattr(
import_module(module_path.split("/")[0]),
module_path.split("/")[1]
)
# Download model if not available
self.model_pth = os.path.join(wrk_dir, f"{model_name}/model.pth")
self.download(model_src[model_name]["src"]["url"], self.model_pth, model_name, on_download, off_download)
self.wrk_dir = wrk_dir
self.progress = None
# Set tile processing parameters
self.scale = model_src[model_name]["scale"]
self.tile_size = tile_size
self.tile_pad = tile_pad
def download(self, src, dst, model_name, on_download=None, off_download=None):
if not os.path.exists(dst):
assert not src is None, "That model is not available for downloading - Are you on experimental?"
os.makedirs(os.path.dirname(dst), exist_ok=True)
if on_download is not None:
on_download(model_name)
with open(dst, 'wb') as f:
f.write(requests.get(src, allow_redirects=True, headers={"User-Agent":""}).content)
if off_download is not None:
off_download()
def model_inference(self, x: np.ndarray, args:dict={}):
x = torch.from_numpy(x).to(self.device)
if not args is None:
return self.model(x=x, **args).cpu().detach().numpy()
else:
return self.model(x=x).cpu().detach().numpy()
def tile_generator(self, data: np.ndarray, yield_extra_details=False, args={}):
"""
Process data [height, width, channel] into tiles of size [tile_size, tile_size, channel],
feed them one by one into the model, then yield the resulting output tiles.
"""
# [height, width, channel] -> [1, channel, height, width]
data = np.rollaxis(data, 2, 0)
data = np.expand_dims(data, axis=0)
data = np.clip(data, 0, 255)
batch, channel, height, width = data.shape
tiles_x = width // self.tile_size
tiles_y = height // self.tile_size
for i in range(tiles_y * tiles_x):
x = i % tiles_y
y = math.floor(i/tiles_y)
input_start_x = y * self.tile_size
input_start_y = x * self.tile_size
input_end_x = min(input_start_x + self.tile_size, width)
input_end_y = min(input_start_y + self.tile_size, height)
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
input_end_x_pad = min(input_end_x + self.tile_pad, width)
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
input_end_y_pad = min(input_end_y + self.tile_pad, height)
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
input_tile = data[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad].astype(np.float32) / 255
output_tile = self.model_inference(input_tile, args)
self.progress = (i+1) / (tiles_y * tiles_x)
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
output_tile = output_tile[:, :, output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile]
output_tile = (np.rollaxis(output_tile, 1, 4).squeeze(0).clip(0,1) * 255).astype(np.uint8)
if yield_extra_details:
yield (output_tile, input_start_x, input_start_y, input_tile_width, input_tile_height, self.progress)
else:
yield output_tile
yield None
def load_model(self):
model:torch.nn.Module = self.model_module().to(self.device)
model.load_state_dict(torch.load(self.model_pth, map_location=torch.device(self.device)), strict=False)
model.eval()
return model
def enhance_with_progress(self, image:Image, args:dict={}):
"""
Take a PIL image and enhance it with the model, yielding stats about the
final image and then the final image itself.
"""
# Load model only now because when using streamlit, multiple users spawn multiple instances of this class, so
# we only load the model when needed. The App() class is responsible for queuing requests to this class
self.model = self.load_model()
original_width, original_height = image.size
# Because tiles may not fit perfectly, we resize to the closest multiple of tile_size
image = image.resize((max(original_width//self.tile_size * self.tile_size, self.tile_size), max(original_height//self.tile_size * self.tile_size, self.tile_size)), resample=Image.Resampling.BICUBIC)
image = np.array(image)
# Initiate a pillow image to save the tiles
result = Image.new("RGB", (image.shape[1]*self.scale, image.shape[0]*self.scale))
for i, tile in enumerate(self.tile_generator(image, yield_extra_details=True, args=args)):
if tile is None:
break
tile_data, x, y, w, h, p = tile
result.paste(Image.fromarray(tile_data), (x*self.scale, y*self.scale))
yield p
# Resize back to the expected size
yield result.resize((original_width * self.scale, original_height * self.scale), resample=Image.Resampling.BICUBIC)
def enhance(self, image:Image) -> Image:
"""
Skips the progress reporting and just returns the final image.
"""
return list(self.enhance_with_progress(image))[-1]
if __name__ == '__main__':
import sys
# User ran with only "main.py"
if not len(sys.argv) == 4:
print("Use main.py with a source, destination file, and model, eg: 'python3 main.py img.png upscaled.png astrosleuthv2'")
print("You might also be interested in using the streamlit interface with: 'streamlit run app.py'")
quit()
src = sys.argv[1]
dst = sys.argv[2]
model_name = sys.argv[3]
a = AstroSleuth(model_name=model_name)
img = Image.open(src)
r = a.enhance(img)
r.save(dst)