-
Notifications
You must be signed in to change notification settings - Fork 9
/
seemore.py
360 lines (308 loc) · 14.7 KB
/
seemore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
import base64
import io
import pandas as pd
from PIL import Image
import torchvision.transforms as transforms
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
# ---- Training / evaluation settings ----
batch_size = 16  # how many independent sequences will we process in parallel?
block_size = 32  # what is the maximum context length for predictions?
max_iters = 100  # optimisation steps run per epoch by train_model
eval_interval = 10  # the training loss is printed every `eval_interval` steps
learning_rate = 1e-3
epochs = 1
eval_iters = 40  # number of batches averaged by estimate_loss

# Ensure every computation happens on the GPU when available
# (this used to be assigned twice with the same expression; once is enough).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---- Text decoder settings ----
head_size = 16  # NOTE(review): not read by any code shown in this file
n_embd = 128
n_head = 8  # main() passes this as num_heads to both the ViT and the decoder
n_layer = 8
dropout = 0.1  # NOTE(review): the modules below use their own `dropout` defaults

# ---- Vision encoder settings ----
num_blks = 3  # number of transformer blocks in the ViT
img_size = 96
patch_size = 16
image_embed_dim = 512
emb_dropout = blk_dropout = 0.1
# The encode/decode functions are built from the tinyshakespeare dataset. For the sake of
# brevity the decoder model is not pretrained on that text here; the original author notes
# that the training function should be able to do so, since it can take both images and text.
text_path = "./input.txt"
with open(text_path, 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(set(text))

# create a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

# The padding token gets the first id after the real characters. This used to be the
# literal 65, which is only right for a corpus with exactly 65 unique characters: a
# larger corpus would have its character 65 overwritten in `itos`, and a smaller one
# would leave the pad id outside the range [0, vocab_size).
pad_token_id = len(chars)
stoi['<pad>'] = pad_token_id
itos[pad_token_id] = '<pad>'


def encode(s):
    """Encoder: take a string, output a list of integers (one id per character)."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string."""
    return ''.join([itos[i] for i in l])


vocab_size = len(stoi)
class PatchEmbeddings(nn.Module):
    """Cut an image batch into non-overlapping patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size`` turns
    every patch of a 3-channel image into one ``hidden_dim``-wide vector.
    """

    def __init__(self, img_size=96, patch_size=16, hidden_dim=512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2
        # One output channel per embedding dimension; stride == kernel means no overlap.
        self.conv = nn.Conv2d(3, hidden_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        """Return the patch embeddings as [B, num_patches, hidden_dim]."""
        feature_map = self.conv(X)
        # Collapse the two spatial axes into one, then move it in front of the channels.
        return feature_map.flatten(2).transpose(1, 2)
#Feed-forward network built from plain nn.Linear layers. nn.LazyLinear is not used here, so the input dimension (n_embd) must be passed in explicitly
class MLP(nn.Module):
    """Feed-forward network: expand to 4 * n_embd, activate, project back, then dropout.

    The activation is ReLU when ``is_decoder`` is true and GELU otherwise.
    """

    def __init__(self, n_embd, dropout=0.1, is_decoder=True):
        super().__init__()
        hidden = 4 * n_embd
        expand = nn.Linear(n_embd, hidden)
        activation = nn.ReLU() if is_decoder else nn.GELU()
        contract = nn.Linear(hidden, n_embd)
        self.net = nn.Sequential(expand, activation, contract, nn.Dropout(dropout))

    def forward(self, x):
        """Apply the network to ``x``; the last dimension stays ``n_embd``."""
        out = self.net(x)
        return out
class Head(nn.Module):
    """A single head of scaled dot-product self-attention.

    Args:
        n_embd: size of the last dimension of the input.
        head_size: size of the key/query/value projections (and of the output).
        dropout: dropout probability applied to the attention weights.
        is_decoder: when true, a causal mask stops each position from attending
            to later positions.
    """

    def __init__(self, n_embd, head_size, dropout=0.1, is_decoder=False):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, n_embd) and return (B, T, head_size)."""
        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        # Compute attention scores. The scale is 1/sqrt(d_k), where d_k is the size of
        # the key vectors (head_size), not the width of the input embedding. Scaling by
        # n_embd shrank the scores by an extra sqrt(num_heads) factor.
        wei = q @ k.transpose(-2, -1) * (k.size(-1) ** -0.5)
        if self.is_decoder:
            # Build the mask for the current sequence length on the same device as the
            # input, so the module works wherever it has been moved to.
            causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            wei = wei.masked_fill(~causal, float('-inf'))
        # Apply softmax to get probabilities
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        # Perform weighted aggregation of values
        return wei @ v
class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` attention heads side by side and merge their outputs.

    Each head has size ``n_embd // num_heads``, so the concatenated result is
    ``n_embd`` wide again before the output projection and dropout.
    """

    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
        super().__init__()
        #Using assert statements for this type of checks is a good idea in general in your code
        assert n_embd % num_heads == 0, "n_embd must be divisible by num_heads"
        per_head = n_embd // num_heads
        self.heads = nn.ModuleList(
            Head(n_embd, per_head, dropout, is_decoder) for _ in range(num_heads)
        )
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the projected concatenation of every head's output."""
        head_outputs = [head(x) for head in self.heads]
        merged = torch.cat(head_outputs, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
class Block(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then a feed-forward network.

    Args:
        n_embd: width of the residual stream.
        num_heads: number of attention heads.
        dropout: dropout probability handed to the attention layer (the
            feed-forward network defined here has no dropout).
        is_decoder: when true the attention is causally masked.
    """

    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        """Apply both sub-layers, each wrapped in a residual connection."""
        # Each sub-layer reads a normalised copy of the stream and adds its output back
        # onto the un-normalised stream.
        x = x + self.attn(self.ln1(x))
        # The feed-forward residual used to be added to the LayerNorm output instead of
        # the stream itself, so the skip path was re-normalised in every block. It now
        # mirrors the attention sub-layer.
        x = x + self.ffn(self.ln2(x))
        return x
class ViT(nn.Module):
    """Vision transformer encoder that summarises an image as its class-token embedding.

    Patch embeddings are prefixed with a learnable class token, given learned
    position embeddings, passed through ``num_blks`` non-causal blocks, and the
    layer-normalised class-token position is returned.
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        # One position per patch plus one for the class token.
        seq_len = (img_size // patch_size) ** 2 + 1
        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        encoder_layers = [
            Block(num_hiddens, num_heads, blk_dropout, is_decoder=False)
            for _ in range(num_blks)
        ]
        self.blocks = nn.ModuleList(encoder_layers)
        self.layer_norm = nn.LayerNorm(num_hiddens)

    def forward(self, X):
        """Encode the image batch ``X`` and return one ``num_hiddens`` vector per image."""
        patches = self.patch_embedding(X)
        cls = self.cls_token.expand(patches.shape[0], -1, -1)
        tokens = torch.cat((cls, patches), dim=1)
        tokens = self.dropout(tokens + self.pos_embedding)
        for layer in self.blocks:
            tokens = layer(tokens)
        # Only the class token (position 0) is kept as the image representation.
        return self.layer_norm(tokens[:, 0])
class MultiModalProjector(nn.Module):
    """Map image embeddings of width ``image_embed_dim`` to the decoder width ``n_embd``.

    The mapping is Linear -> GELU -> Linear -> Dropout, with a hidden layer four
    times as wide as the image embedding.
    """

    def __init__(self, n_embd, image_embed_dim, dropout=0.1):
        super().__init__()
        widened = 4 * image_embed_dim
        stages = [
            nn.Linear(image_embed_dim, widened),
            nn.GELU(),
            nn.Linear(widened, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Project ``x``; the last dimension changes from image_embed_dim to n_embd."""
        return self.net(x)
class DecoderLanguageModel(nn.Module):
    """Causal transformer language model that can be conditioned on an image embedding.

    When ``use_images`` is true and image embeddings are supplied, they are projected
    to ``n_embd`` and placed in front of the token embeddings as one extra position.
    The position table has 1000 entries, so the total sequence length (image position
    included) must not exceed 1000.
    """

    def __init__(self, n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=False):
        super().__init__()
        self.use_images = use_images
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(1000, n_embd)
        if use_images:
            self.image_projection = MultiModalProjector(n_embd, image_embed_dim)
        self.blocks = nn.Sequential(*[Block(n_embd, num_heads, is_decoder=True) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def _embed(self, idx, image_embeds):
        """Embed the tokens in ``idx``, prefixed by the projected image when one is used."""
        tok_emb = self.token_embedding_table(idx)
        if self.use_images and image_embeds is not None:
            img_emb = self.image_projection(image_embeds).unsqueeze(1)
            tok_emb = torch.cat([img_emb, tok_emb], dim=1)
        return tok_emb

    def _logits(self, emb):
        """Add position embeddings to ``emb`` and run blocks, final norm and LM head."""
        # Positions are created on the input's own device rather than a module-level
        # global, so the model works wherever it has been moved to.
        positions = torch.arange(emb.size(1), device=emb.device)
        x = emb + self.position_embedding_table(positions).unsqueeze(0)
        x = self.blocks(x)
        x = self.ln_f(x)
        return self.lm_head(x)

    def forward(self, idx, image_embeds=None, targets=None):
        """Return ``logits``, or ``(logits, loss)`` when ``targets`` is given.

        With an image prefix the logits have one more position than ``idx``; the
        targets are then left-padded with -100 so the image position is ignored by
        the cross-entropy loss.
        """
        logits = self._logits(self._embed(idx, image_embeds))
        if targets is None:
            return logits
        if self.use_images and image_embeds is not None:
            num_rows = idx.size(0)
            ignored = torch.full((num_rows, 1), -100, dtype=torch.long, device=targets.device)
            targets = torch.cat([ignored, targets], dim=1)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
        return logits, loss

    def generate(self, idx, image_embeds, max_new_tokens):
        """Sample ``max_new_tokens`` tokens after ``idx`` and return the extended ids.

        Each step recomputes the same function as ``forward`` on the raw input
        embeddings. The previous loop fed the blocks' outputs back in as inputs,
        re-added position embeddings to them on every step, and skipped the final
        LayerNorm, so sampling did not match the network that was trained.
        """
        generated = idx
        # Sampling produces integer ids, so no gradient can flow; skip graph building.
        with torch.no_grad():
            seq_emb = self._embed(idx, image_embeds)
            for _ in range(max_new_tokens):
                logits = self._logits(seq_emb)[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                generated = torch.cat((generated, idx_next), dim=1)
                # Keep only raw embeddings in seq_emb; positions are added per step.
                seq_emb = torch.cat((seq_emb, self.token_embedding_table(idx_next)), dim=1)
        return generated
class VisionLanguageModel(nn.Module):
    """A ViT image encoder feeding an image-conditioned decoder language model.

    ``image_embed_dim`` is used as the ViT hidden size, and ``num_heads`` is
    shared by the vision encoder and the decoder.
    """

    def __init__(self, n_embd, image_embed_dim, vocab_size, n_layer, img_size, patch_size, num_heads, num_blks, emb_dropout, blk_dropout):
        super().__init__()
        num_hiddens = image_embed_dim  # Set num_hiddens equal to image_embed_dim
        assert num_hiddens % num_heads == 0, "num_hiddens must be divisible by num_heads"
        self.vision_encoder = ViT(img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout)
        self.decoder = DecoderLanguageModel(n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=True)

    def _encode_images(self, img_array):
        """Run the vision encoder and reject an empty result.

        Raises:
            ValueError: if the encoder output has no elements or its second
                dimension is empty.
        """
        image_embeds = self.vision_encoder(img_array)
        if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
            raise ValueError(
                "Something is wrong with the ViT model: it returned an empty tensor "
                "or an empty embedding dimension"
            )
        return image_embeds

    def forward(self, img_array, idx, targets=None):
        """Return ``logits``, or ``(logits, loss)`` when ``targets`` is given."""
        image_embeds = self._encode_images(img_array)
        if targets is None:
            return self.decoder(idx, image_embeds)
        logits, loss = self.decoder(idx, image_embeds, targets)
        return logits, loss

    def generate(self, img_array, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens after ``idx``, conditioned on the images."""
        image_embeds = self._encode_images(img_array)
        return self.decoder.generate(idx, image_embeds, max_new_tokens)
def base64_to_tensor(base64_str, img_size=96):
    """Decode a base64-encoded image into a normalised tensor with a batch dimension.

    The image is converted to RGB if needed, resized to ``img_size`` x ``img_size``,
    turned into a tensor and normalised per channel with the mean/std listed below.
    """
    raw_bytes = base64.b64decode(base64_str)
    image = Image.open(io.BytesIO(raw_bytes))
    if image.mode != 'RGB':
        image = image.convert('RGB')
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    return tensor.unsqueeze(0)  # Add batch dimension
#Adjusting the data loader from makemore for multimodal data
def get_batch(df, batch_size, split='train', img_size=96, val_batch_size=8):
    """Sample one batch of images, padded caption ids and next-token targets.

    Args:
        df: DataFrame with 'b64string_images' and 'caption' columns.
        batch_size: number of rows sampled for the 'train' split.
        split: 'train' uses the first 90% of ``df`` (sampled without replacement);
            any other value uses the remaining rows (sampled with replacement).
        img_size: side length images are resized to.
        val_batch_size: number of rows sampled for the non-train split.

    Returns:
        ``(images, padded_text, targets)`` on ``device``. ``padded_text`` holds the
        caption ids right-padded with the '<pad>' id; ``targets`` is ``padded_text``
        shifted left by one position with '<pad>' in the last column.
    """
    from torch.nn.utils.rnn import pad_sequence

    # Split data into training and validation sets
    n = int(0.9 * len(df))  # first 90% will be train, rest val
    if split == 'train':
        data, replace = df.iloc[:n], False
    else:
        data, replace = df.iloc[n:], True
        batch_size = val_batch_size
    batch = data.sample(n=batch_size, replace=replace)

    images = torch.cat([base64_to_tensor(img, img_size) for img in batch['b64string_images']], dim=0).to(device)

    pad_id = stoi['<pad>']
    text_indices = [torch.tensor(encode(desc), dtype=torch.long) for desc in batch['caption']]
    # Right-pad every caption to the longest one in the batch, then move it once.
    padded_text = pad_sequence(text_indices, batch_first=True, padding_value=pad_id).to(device)

    # Shift left by one. Writing into a pad-filled tensor of the same shape means the
    # targets always match padded_text, so no truncate/pad fix-up is needed afterwards.
    targets = torch.full_like(padded_text, pad_id)
    targets[:, :-1] = padded_text[:, 1:]
    return images, padded_text, targets
#Adjusting the training loop from makemore for multimodal data
def train_model(model, df, epochs, vocab_size, img_size=96):
    """Train ``model`` on batches drawn from ``df`` and print progress.

    Runs ``max_iters`` Adam steps per epoch, printing the training loss every
    ``eval_interval`` steps and the validation loss once at the end of each epoch.
    ``vocab_size`` is accepted but not used by this function.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.to(device)
    for epoch in range(epochs):
        # estimate_loss switches the model to eval mode, so re-enable training here.
        model.train()
        for step in range(max_iters):
            images, idx, targets = get_batch(df, batch_size, 'train', img_size)
            _, loss = model(images, idx, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % eval_interval != 0:
                continue
            print(f"Loss at iteration {step}: {loss.item()}")
        val_loss = estimate_loss(model, df, 'val', img_size, val_batch_size=8)
        print(f"Validation Loss after epoch {epoch}: {val_loss}")
def estimate_loss(model, df, split, img_size=96, val_batch_size=8):
    """Return the mean loss of ``model`` over ``eval_iters`` batches of ``split``.

    The model is switched to eval mode and is left in eval mode on return; callers
    that continue training must call ``model.train()`` again.
    """
    model.eval()
    losses = []
    # Evaluation needs no gradients; without no_grad every forward pass would build
    # an autograd graph that is never used.
    with torch.no_grad():
        for _ in range(eval_iters):
            images, idx, targets = get_batch(df, batch_size, split, img_size, val_batch_size=val_batch_size)
            _, loss = model(images, idx, targets)
            losses.append(loss.item())
    return sum(losses) / len(losses)
def main():
    """Load the caption dataset, build the vision-language model and train it."""
    # Load the dataset
    raw_df = pd.read_csv("./inputs.csv")
    #Expanding dataframe so that there's enough data to test. This is just duplicating data. A real dataset would have more rows
    columns = ['b64string_images', 'caption']
    df = pd.concat([raw_df] * 30)[columns]

    # Initialize the model
    model = VisionLanguageModel(
        n_embd=n_embd,
        image_embed_dim=image_embed_dim,
        vocab_size=vocab_size,
        n_layer=n_layer,
        img_size=img_size,
        patch_size=patch_size,
        num_heads=n_head,
        num_blks=num_blks,
        emb_dropout=emb_dropout,
        blk_dropout=blk_dropout,
    )
    model.to(device)

    # One throw-away forward pass on random inputs before training. The model defined
    # in this file has no lazy modules, so this only exercises the forward path.
    dummy_img = torch.randn(1, 3, img_size, img_size).to(device)
    dummy_idx = torch.randint(0, vocab_size, (1, block_size)).to(device)
    model(dummy_img, dummy_idx)

    # Train the model
    train_model(model, df, epochs, vocab_size, img_size)


if __name__ == "__main__":
    main()