Hello, Flax team. When I tried to port PyTorch models to Flax, I found that Flax consumes more GPU memory than PyTorch. For example, a ResNet50 model uses 4 GB of GPU memory in PyTorch but 6 GB in Flax. What causes this difference in memory consumption between PyTorch and Flax, and what can I do to reduce memory usage in Flax? Thanks!
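One commonly cited factor, offered here as an assumption rather than a diagnosis of this specific ResNet50 comparison, is that JAX (which Flax runs on) preallocates a large fraction of GPU memory when the first JAX operation runs (75% in recent versions, per the JAX GPU memory allocation docs). Tools like `nvidia-smi` then report that reservation, not what the model actually needs. The documented environment variables `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` control this. The helper below, `configure_jax_gpu_memory`, is a hypothetical name used only for illustration; it just sets those variables and must run before JAX initializes its GPU backend (simplest: before `import jax`).

```python
import os
import sys
import warnings
from typing import Dict, Optional


def configure_jax_gpu_memory(preallocate: bool = False,
                             mem_fraction: Optional[float] = None) -> Dict[str, str]:
    """Hypothetical helper: set JAX's XLA client memory env vars.

    Must be called before JAX initializes its GPU backend, i.e. in
    practice before `import jax` (or before any Flax/JAX computation).
    """
    if "jax" in sys.modules:
        # JAX reads these variables at backend startup; if it is already
        # imported (and possibly initialized), the settings may be ignored.
        warnings.warn("jax is already imported; memory settings may not take effect.")

    # "false" disables the up-front reservation, so memory is allocated on demand.
    settings = {"XLA_PYTHON_CLIENT_PREALLOCATE": "true" if preallocate else "false"}

    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Caps the fraction of GPU memory JAX may preallocate.
        settings["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"

    os.environ.update(settings)
    return settings
```

Usage: call `configure_jax_gpu_memory(preallocate=False)` at the very top of the training script, then `import jax` and `import flax`. Disabling preallocation makes reported usage closer to actual usage but can increase fragmentation, so comparing peak memory this way is a measurement aid rather than a guaranteed reduction.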