GPU OOM if model ran in Python multithreading #101

Open
xcharleslin opened this issue Feb 22, 2023 · 0 comments
xcharleslin commented Feb 22, 2023

Minimum repro:

import torch
from min_dalle import MinDalle
from concurrent.futures import ThreadPoolExecutor

USE_GPU = True
def f(text: str, root: str):
    return MinDalle(
        models_root=f'./{root}',
        dtype=torch.float32,
        device="cuda",
        is_mega=False, 
        is_reusable=True,
    ).generate_image(
        text,
        seed=-1,
        grid_size=1,
        is_seamless=False,
        temperature=1,
        top_k=256,
        supercondition_factor=32,
    )

# Without threading: works
f("hello", "root1")

# With threading: fails
tpe = ThreadPoolExecutor()
tpe.submit(f, "hello2", "root2").result()  # GPU OOMs here

The last line fails with `OutOfMemoryError: CUDA out of memory`, even though the same call succeeds when run without threading.

Full stack trace:
using device cuda
downloading tokenizer params
intializing TextTokenizer
downloading encoder params
initializing DalleBartEncoder
downloading decoder params
initializing DalleBartDecoder
downloading detokenizer params
initializing VQGanDetokenizer
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
<ipython-input-11-7b2fe8b33ac0> in <module>
      5 fut = tpe.submit(f, "abc", "def")
      6 
----> 7 fut.result()

12 frames
/usr/lib/python3.8/concurrent/futures/_base.py in result(self, timeout)
    442                     raise CancelledError()
    443                 elif self._state == FINISHED:
--> 444                     return self.__get_result()
    445                 else:
    446                     raise TimeoutError()

/usr/lib/python3.8/concurrent/futures/_base.py in __get_result(self)
    387         if self._exception:
    388             try:
--> 389                 raise self._exception
    390             finally:
    391                 # Break a reference cycle with the exception in self._exception

/usr/lib/python3.8/concurrent/futures/thread.py in run(self)
     55 
     56         try:
---> 57             result = self.fn(*self.args, **self.kwargs)
     58         except BaseException as exc:
     59             self.future.set_exception(exc)

<ipython-input-9-7e6467e7527a> in f(text, dir)
      5 USE_GPU = True
      6 def f(text: str, dir: str) -> PIL.Image.Image:
----> 7     return MinDalle(
      8         models_root=f'./{dir}',
      9         dtype=torch.float32,

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image(self, *args, **kwargs)
    279             progressive_outputs=False
    280         )
--> 281         return next(image_stream)
    282 
    283 

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image_stream(self, *args, **kwargs)
    259     def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
    260         image_stream = self.generate_raw_image_stream(*args, **kwargs)
--> 261         for image in image_stream:
    262             image = image.to(torch.uint8).to('cpu').numpy()
    263             yield Image.fromarray(image)

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_raw_image_stream(self, text, seed, grid_size, progressive_outputs, is_seamless, temperature, top_k, supercondition_factor, is_verbose)
    238             torch.cuda.empty_cache()
    239             with torch.cuda.amp.autocast(dtype=self.dtype):
--> 240                 image_tokens[:, i + 1], attention_state = self.decoder.sample_tokens(
    241                     settings=settings,
    242                     attention_mask=attention_mask,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in sample_tokens(self, settings, **kwargs)
    175 
    176     def sample_tokens(self, settings, **kwargs) -> Tuple[LongTensor, FloatTensor]:
--> 177         logits, attention_state = self.forward(**kwargs)
    178         image_count = logits.shape[0] // 2
    179         temperature = settings[[0]]

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, attention_mask, encoder_state, attention_state, prev_tokens, token_index)
    162         decoder_state = self.layernorm_embedding.forward(decoder_state)
    163         for i in range(self.layer_count):
--> 164             decoder_state, attention_state[i] = self.layers[i].forward(
    165                 decoder_state,
    166                 encoder_state,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, encoder_state, attention_state, attention_mask, token_index)
     88         residual = decoder_state
     89         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
---> 90         decoder_state, attention_state = self.self_attn.forward(
     91             decoder_state=decoder_state,
     92             attention_state=attention_state,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, attention_state, attention_mask, token_index)
     43             values = attention_state[batch_count:]
     44 
---> 45         decoder_state = super().forward(keys, values, queries, attention_mask)
     46         return decoder_state, attention_state
     47 

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_encoder.py in forward(self, keys, values, queries, attention_mask)
     47         queries /= queries.shape[-1] ** 0.5
     48         attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
---> 49         attention_weights: FloatTensor = torch.einsum(
     50             'bqhc,bkhc->bhqk',
     51             queries,

/usr/local/lib/python3.8/dist-packages/torch/functional.py in einsum(*args)
    376         # the path for contracting 0 or 1 time(s) is already optimized
    377         # or the user has disabled using opt_einsum
--> 378         return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    379 
    380     path = None

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 14.76 GiB total capacity; 13.67 GiB already allocated; 17.88 MiB free; 13.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
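Not part of the original report, but one possible mitigation: since each call to `f` constructs a fresh `MinDalle` instance (and the first instance stays resident on the GPU), constructing the model once and sharing it across threads avoids holding two full copies in GPU memory. Below is a minimal sketch of that pattern using a hypothetical stand-in class in place of `MinDalle` (the class name, counter, and `generate_image` signature here are illustrative, not the real API):

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class FakeModel:
    """Stand-in for an expensive model like MinDalle; counts instantiations."""
    instances = 0

    def __init__(self):
        FakeModel.instances += 1

    def generate_image(self, text):
        return f"image for {text!r}"


_model = None
_lock = threading.Lock()


def get_model():
    # Lazily build a single shared model, guarded by a lock so that
    # concurrent threads never construct two copies.
    global _model
    with _lock:
        if _model is None:
            _model = FakeModel()
        return _model


def f(text):
    # Every caller, threaded or not, reuses the one shared instance.
    return get_model().generate_image(text)


tpe = ThreadPoolExecutor()
results = [tpe.submit(f, t).result() for t in ("hello", "hello2")]
```

With this pattern only one model is ever instantiated, regardless of how many threads call `f`; whether that resolves the OOM for `MinDalle` specifically would need to be verified.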