Minimum repro:
```python
import torch
from min_dalle import MinDalle
from concurrent.futures import ThreadPoolExecutor

USE_GPU = True

def f(text: str, root: str):
    return MinDalle(
        models_root=f'./{root}',
        dtype=torch.float32,
        device="cuda",
        is_mega=False,
        is_reusable=True,
    ).generate_image(
        text,
        seed=-1,
        grid_size=1,
        is_seamless=False,
        temperature=1,
        top_k=256,
        supercondition_factor=32,
    )

# Without threading: works
f("hello", "root1")

# With threading: does not work
tpe = ThreadPoolExecutor()
tpe.submit(f, "hello2", "root2").result()  # GPU OOMs here
```
The last line fails with `OutOfMemoryError: CUDA out of memory.`
<details>
<summary>Click for full stack trace</summary>

```
using device cuda
downloading tokenizer params
intializing TextTokenizer
downloading encoder params
initializing DalleBartEncoder
downloading decoder params
initializing DalleBartDecoder
downloading detokenizer params
initializing VQGanDetokenizer
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
<ipython-input-11-7b2fe8b33ac0> in <module>
      5 fut = tpe.submit(f, "abc", "def")
      6 
----> 7 fut.result()

12 frames
/usr/lib/python3.8/concurrent/futures/_base.py in result(self, timeout)
    442                 raise CancelledError()
    443             elif self._state == FINISHED:
--> 444                 return self.__get_result()
    445             else:
    446                 raise TimeoutError()

/usr/lib/python3.8/concurrent/futures/_base.py in __get_result(self)
    387         if self._exception:
    388             try:
--> 389                 raise self._exception
    390             finally:
    391                 # Break a reference cycle with the exception in self._exception

/usr/lib/python3.8/concurrent/futures/thread.py in run(self)
     55 
     56         try:
---> 57             result = self.fn(*self.args, **self.kwargs)
     58         except BaseException as exc:
     59             self.future.set_exception(exc)

<ipython-input-9-7e6467e7527a> in f(text, dir)
      5 USE_GPU = True
      6 def f(text: str, dir: str) -> PIL.Image.Image:
----> 7     return MinDalle(
      8         models_root=f'./{dir}',
      9         dtype=torch.float32,

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image(self, *args, **kwargs)
    279             progressive_outputs=False
    280         )
--> 281         return next(image_stream)
    282 
    283 

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image_stream(self, *args, **kwargs)
    259     def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
    260         image_stream = self.generate_raw_image_stream(*args, **kwargs)
--> 261         for image in image_stream:
    262             image = image.to(torch.uint8).to('cpu').numpy()
    263             yield Image.fromarray(image)

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_raw_image_stream(self, text, seed, grid_size, progressive_outputs, is_seamless, temperature, top_k, supercondition_factor, is_verbose)
    238             torch.cuda.empty_cache()
    239             with torch.cuda.amp.autocast(dtype=self.dtype):
--> 240                 image_tokens[:, i + 1], attention_state = self.decoder.sample_tokens(
    241                     settings=settings,
    242                     attention_mask=attention_mask,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in sample_tokens(self, settings, **kwargs)
    175 
    176     def sample_tokens(self, settings, **kwargs) -> Tuple[LongTensor, FloatTensor]:
--> 177         logits, attention_state = self.forward(**kwargs)
    178         image_count = logits.shape[0] // 2
    179         temperature = settings[[0]]

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, attention_mask, encoder_state, attention_state, prev_tokens, token_index)
    162         decoder_state = self.layernorm_embedding.forward(decoder_state)
    163         for i in range(self.layer_count):
--> 164             decoder_state, attention_state[i] = self.layers[i].forward(
    165                 decoder_state,
    166                 encoder_state,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, encoder_state, attention_state, attention_mask, token_index)
     88         residual = decoder_state
     89         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
---> 90         decoder_state, attention_state = self.self_attn.forward(
     91             decoder_state=decoder_state,
     92             attention_state=attention_state,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, attention_state, attention_mask, token_index)
     43         values = attention_state[batch_count:]
     44 
---> 45         decoder_state = super().forward(keys, values, queries, attention_mask)
     46         return decoder_state, attention_state
     47 

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_encoder.py in forward(self, keys, values, queries, attention_mask)
     47         queries /= queries.shape[-1] ** 0.5
     48         attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
---> 49         attention_weights: FloatTensor = torch.einsum(
     50             'bqhc,bkhc->bhqk',
     51             queries,

/usr/local/lib/python3.8/dist-packages/torch/functional.py in einsum(*args)
    376         # the path for contracting 0 or 1 time(s) is already optimized
    377         # or the user has disabled using opt_einsum
--> 378         return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    379 
    380     path = None

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 14.76 GiB total capacity; 13.67 GiB already allocated; 17.88 MiB free; 13.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```

</details>
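The error message itself points at `PYTORCH_CUDA_ALLOC_CONF`. A minimal sketch of setting it (the value `128` is an arbitrary example, not a tuned recommendation); note that in this repro the failure comes from holding two full model copies on the GPU, so fragmentation tuning alone may not be enough:

```python
import os

# PYTORCH_CUDA_ALLOC_CONF must be set before the first CUDA allocation
# (safest: before importing torch). max_split_size_mb:128 is an example
# value, not a recommendation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
```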
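Since each call to `f` constructs a fresh `MinDalle` (and the first instance stays resident on the GPU because `is_reusable=True`), a likely workaround is to load the model once and share it across threads. A minimal sketch of that pattern, assuming the model's forward pass is only called under the lock; `load_model` is a hypothetical stand-in for the `MinDalle(...)` constructor from the repro, stubbed out so the sketch runs without a GPU:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

_model = None
_model_lock = threading.Lock()

def load_model():
    # Stand-in for MinDalle(models_root=..., device="cuda", is_reusable=True, ...);
    # stubbed out here so the sketch runs without a GPU.
    return object()

def get_model():
    # Lazy init under a lock: the weights are loaded exactly once,
    # no matter how many threads call this.
    global _model
    with _model_lock:
        if _model is None:
            _model = load_model()
        return _model

def f(text: str):
    model = get_model()
    # With the real model this would be:
    #   return model.generate_image(text, seed=-1, grid_size=1, ...)
    return id(model)

tpe = ThreadPoolExecutor()
a = tpe.submit(f, "hello").result()
b = tpe.submit(f, "hello2").result()
assert a == b  # both threads reuse the single instance
```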