thanks (it's 10x faster than JAX)! #25

Open · Birch-san opened this issue Jun 29, 2022 · 14 comments
Birch-san commented Jun 29, 2022

I've been trying to get dalle-playground running performantly on M1, but there's a lot of work remaining to make the JAX model work via IREE/Vulkan.

So I tried out your PyTorch model,

with a recent nightly of PyTorch:

pip install --pre "torch>1.13.0.dev20220610" "torchvision>0.14.0.dev20220609" --extra-index-url https://download.pytorch.org/whl/nightly/cpu

…and it's 10x faster at dalle-mega than dalle-playground was on JAX/XLA!

Using dalle-mega full:

wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1:latest

Generating 1 image took 27 minutes on dalle-playground (using 117% CPU), whereas this PyTorch model runs in 2.7 minutes (using 145% CPU)!
The GPU looks less than half utilized; I haven't checked whether PyTorch is the process that's using the GPU.

These measurements are from an M1 Max.

Bonus: "crystal maiden and lina enjoying a pint together at a tavern" (generated image attached)

kuprel (Owner) commented Jun 29, 2022

Awesome!

pcuenca commented Jul 2, 2022

> generating 1 image took 27 mins on dalle-playground (using 117% CPU), whereas this pytorch model runs in 2.7 mins (using 145% CPU)! GPU looks less-than-half utilized. haven't checked whether pytorch is the process that's using the GPU.

I think the model runs on CPU by default. I tried to move all models and tensors to the mps device and fix some incompatibilities (a few ops are not yet supported by the MPS backend). Inference was faster and GPU utilization was close to 100%, but generation did not work properly. I'm still trying to identify what the problem could be.
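
For reference, a minimal, generic sketch of the "move everything to the mps device" pattern described above; the tiny nn.Linear is a stand-in for the real encoder/decoder, so treat it as an illustration rather than this repo's actual code:

```python
import torch
import torch.nn as nn

# pick the Metal (MPS) backend when available, otherwise stay on CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = nn.Linear(16, 16).to(device).eval()   # stand-in for the real model
inputs = torch.randn(1, 16).to(device)        # inputs must live on the same device

with torch.no_grad():
    out = model(inputs)

print(out.device)   # mps:0 on Apple-silicon builds, otherwise cpu
```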

Birch-san (Author) commented Jul 2, 2022

@pcuenca Wait, you got it running on-GPU? And it was faster? That's massively different from the result I got.

Here's how I made it run on MPS:
Birch-san/min-dalle@Birch-san:min-dalle:main...Birch-san:min-dalle:mps
There's other stuff in that branch too, like generating multiple images, re-using the text encoding between images, and measuring how long each step takes.

What I found was that it ran way slower. I left it overnight and it didn't finish generating even 1 image (it got to about the 145th token of 255).
And, to be honest, the CPU usage (~117%) and GPU usage (less than half) looked identical to when it ran on-CPU.

Did I do something wrong? I just slapped device_type on everything I could.
I'm using torch==1.13.0.dev20220628 (a recent nightly).
I ran with PYTORCH_ENABLE_MPS_FALLBACK=1 and --mega --torch --text='kunkka playing basketball with Shrek' --copies=3, with dalle-mega proper, not the fp16 version.
Only one operation had to use the fallback to CPU: aten::sort.values_stable.
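
For anyone reproducing this, a hedged sketch of the fallback setup described above; setting the env var before importing torch (or exporting it in the shell, as done here) is the commonly recommended approach:

```python
import os

# enable CPU fallback for the few ops MPS lacks (e.g. aten::sort.values_stable)
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"running on {device}")
```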

Birch-san (Author) commented Jul 2, 2022

> generation did not work properly

It's worth knowing that the MPS backend does have some silent errors where it will produce incorrect output (or at least transfer the wrong result to CPU). Here's the really wacky phenomenon that I found:
pytorch/pytorch#79383

pcuenca commented Jul 2, 2022

@Birch-san These are my changes so far: main...pcuenca:min-dalle:mps-device

I tried to use workarounds for unsupported ops, except for multinomial. You need to use PYTORCH_ENABLE_MPS_FALLBACK=1 for the backend to automatically fall back to the CPU when it encounters that operation. I also tried to replace it with argmax, which should produce something reasonable, but it did not help with generation.

I may have introduced a problem somewhere, but if you disable the MPS device by returning self here, everything works right.
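
A hedged sketch of the multinomial workaround described above; the vocabulary size is made up, and the greedy path only approximates true sampling:

```python
import torch

def sample_token(logits: torch.Tensor, greedy: bool = True) -> torch.Tensor:
    """Pick the next image token from decoder logits."""
    probs = torch.softmax(logits, dim=-1)
    if greedy:
        # argmax is supported on mps, but it is deterministic, so image
        # diversity (and possibly quality) differs from real sampling
        return probs.argmax(dim=-1, keepdim=True)
    # torch.multinomial is not implemented on mps; with
    # PYTORCH_ENABLE_MPS_FALLBACK=1 it runs on the CPU instead
    return torch.multinomial(probs, num_samples=1)

logits = torch.randn(1, 16384)   # stand-in for logits over the image-token vocabulary
print(sample_token(logits))
```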

pcuenca commented Jul 2, 2022

> it's worth knowing that the MPS backend does have some silent errors where it will produce incorrect output. here's the really wacky one I found: pytorch/pytorch#79383

That's very interesting. I'll try to debug generation tomorrow. Thanks!

kuprel (Owner) commented Jul 2, 2022

I was also looking into getting this model running on the phone. Apple says that for transformers in PyTorch, the dimensions aren't in the optimal order for the Neural Engine: https://machinelearning.apple.com/research/neural-engine-transformers

They also convert all the linear layers to convs and use a different einsum pattern.
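
For illustration, a small self-contained sketch of the linear-to-conv rewrite the Apple article describes (a 1x1 conv over a (batch, channels, 1, seq) layout); the sizes here are arbitrary, not this model's:

```python
import torch
import torch.nn as nn

linear = nn.Linear(1024, 4096)

# express the same projection as a 1x1 conv over a (B, C, 1, S) layout,
# which Apple's article says the Neural Engine prefers
conv = nn.Conv2d(1024, 4096, kernel_size=1)
conv.weight.data = linear.weight.data[:, :, None, None]   # (out, in, 1, 1)
conv.bias.data = linear.bias.data

x = torch.randn(2, 16, 1024)                  # (batch, seq, dim)
y_linear = linear(x)

x_conv = x.permute(0, 2, 1).unsqueeze(2)      # (batch, dim, 1, seq)
y_conv = conv(x_conv).squeeze(2).permute(0, 2, 1)

print(torch.allclose(y_linear, y_conv, atol=1e-5))   # same numbers, new layout
```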

Birch-san (Author) commented:

> I was also looking into getting this model on the phone. Apple says that for transformers in pytorch, the dimensions aren't in optimal order for the neural engine: https://machinelearning.apple.com/research/neural-engine-transformers
>
> They also convert all the linear layers to convs and use a different einsum pattern

That's just the Neural Engine. PyTorch's MPS backend targets the GPU, and JAX's IREE/Vulkan backend does too. I don't know what TensorFlow targets, but I'll definitely take "targeting 48 GPU cores" as a step up from "targeting 10 CPU cores".

It sounds like the Neural Engine is not suitable for training anyway, only inference:
pytorch/pytorch#47688 (comment)

pcuenca commented Jul 2, 2022

The Neural Engine is much faster than the GPU, so it makes sense to apply those optimizations. Not all operations are supported, however, and it's hard to know whether the system decided to run your model on the Neural Engine or the GPU.

I wasn't trying to do that yet, though. I just wanted to test inference on the MPS backend (GPU) of my M1 Mac to see how it compares with the CPU and with NVIDIA GPUs. If we did a conversion to Core ML, we would then be able to test Neural Engine inference speed vs PyTorch+MPS performance.
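
If someone wants to try that comparison, here is a rough, untested sketch of a Core ML conversion with coremltools; the tiny stand-in module and shapes are placeholders, not this model's real decoder:

```python
import torch
import torch.nn as nn
import coremltools as ct

module = nn.Sequential(nn.Linear(64, 64), nn.GELU()).eval()   # stand-in model
example = torch.randn(1, 64)

traced = torch.jit.trace(module, example)
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=example.shape)],
    convert_to="mlprogram",
    compute_units=ct.ComputeUnit.ALL,   # let Core ML schedule ANE/GPU/CPU
)
mlmodel.save("model.mlpackage")
```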

Birch-san (Author) commented Jul 2, 2022

@pcuenca

> That's very interesting. I'll try to debug generation tomorrow. Thanks!

If it is indeed a problem with transferring from MPS to CPU, then we should try @qqaatw's idea of transferring as contiguous memory:

pytorch/pytorch#79383 (comment)

Birch-san (Author) commented:

@pcuenca If I slap .contiguous() at the end of every torch.{reshape,view,unsqueeze,permute}() (i.e. functions which perform reshaping, and which may use a view to do so), we get an image that is merely bad rather than pitch-black:

"kunkka playing basketball with Shrek" (two generated images attached)

Birch-san@8b83231
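
A hedged, self-contained illustration of the workaround being tried in that commit: forcing a real copy with .contiguous() after view-producing ops so a later mps-to-cpu transfer can't pick up a mis-strided view (see pytorch/pytorch#79383). The shapes here are arbitrary:

```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

x = torch.randn(1, 16, 64, device=device)
y = x.permute(0, 2, 1).contiguous()   # materialize a copy instead of a strided view
z = y.reshape(1, -1).contiguous()

print(z.cpu().shape)                  # transfer back to CPU, e.g. for decoding/display
```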

Birch-san (Author) commented Jul 2, 2022

Oh, there's one final reshape() that I missed, but adding .contiguous() to that makes things worse rather than better:

"kunkka playing basketball with Shrek" (generated image attached)

Birch-san@43e7e92

Birch-san (Author) commented:

I also tried using .contiguous() on any tensor that would be transferred to the MPS device:
Birch-san@b1cf6c2

Still black.

woctezuma commented Jul 4, 2022

Even faster these days: you get a 4x4 grid instead of a 3x3 grid on Replicate, after the same duration.

However, this is based on DALL-E Mega instead of DALL-E Mini, so results might differ; I'm not sure whether they are better or worse.
