Torch2.0 scaled_dot_product_attention processor #2303
Conversation
The documentation is not available anymore as the PR was closed or merged.
Very cool! Thanks a lot for working on this @patil-suraj. Nice to see all the results here :-) When adding this to […] `scaled_dot_product_attention` that we currently have?
Thanks @patil-suraj, @patrickvonplaten for the quick turnaround. I wonder if the plan is to merge when PyTorch 2.0 is released, or can we merge now and ask people to install the nightlies if they want to use it?
@HamidShojanazeri the plan is to merge it by the end of this week :)
In terms of functionality everything is supported, except when […]. My only concern would be reproducibility and numerical accuracy. From the docs: […] And: […]

I wonder if it might require us to change our tests slightly? cc @patrickvonplaten @williamberman @pcuenca
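For illustration only (the tensors and the tolerance below are made up, not taken from the diffusers test suite), the kind of tolerance-based comparison such tests might rely on:

```python
import torch

# Outputs produced through torch 2.0's scaled_dot_product_attention may differ
# from the default attention path at the numerical level, so exact-equality
# checks can fail even when results are visually identical.
reference = torch.randn(1, 4, 64, 64)
candidate = reference + 1e-3 * torch.randn_like(reference)

# A relaxed absolute tolerance keeps the test robust to such differences.
assert torch.allclose(reference, candidate, atol=1e-2)
```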
patrickvonplaten left a comment
This looks very nice to me! Would be nice to also get some numbers on fp32 for completeness and some quick training comparison stats.
patrickvonplaten left a comment
Cool, would be nice to add the perf table as well - the community would really like this, I think.
Added the benchmark table and changed the name to `AttnProcessor2_0`.
pcuenca left a comment
Awesome!
* add sdpa processor
* don't use it by default
* add some checks and style
* typo
* support torch sdpa in dreambooth example
* use torch attn proc by default when available
* typo
* add attn mask
* fix naming
* being doc
* doc
* Apply suggestions from code review
* polish
* torctree
* Apply suggestions from code review

  Co-authored-by: Sayak Paul <[email protected]>
  Co-authored-by: Patrick von Platen <[email protected]>

* better name
* style
* add benchamrk table
* Update docs/source/en/optimization/torch2.0.mdx
* up
* fix example
* check if processor is None
* Apply suggestions from code review

  Co-authored-by: Pedro Cuenca <[email protected]>

* add fp32 benchmakr
* Apply suggestions from code review

  Co-authored-by: Sayak Paul <[email protected]>

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
This PR adds support for the efficient attention from torch 2.0. Flash and memory-efficient attention (from xformers) are now built into the latest torch nightlies and can be used without any extra dependency (cf. https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention).
This PR adds the `AttnProcessor2_0` class, which implements efficient attention via `torch.nn.functional.scaled_dot_product_attention`. `AttnProcessor2_0` will be used by default whenever torch 2.0 is installed and `scaled_dot_product_attention` is available.

To test:
Install CUDA 11.7 and the latest PyTorch nightlies, and […]
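As a rough sketch of what a test run could look like (not the exact snippet from the PR; the import path for `AttnProcessor2_0` is an assumption based on the current diffusers layout and may differ across versions):

```python
import torch
from diffusers import StableDiffusionPipeline
# Assumed import path; older versions exposed the processor from
# diffusers.models.cross_attention instead.
from diffusers.models.attention_processor import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# With torch >= 2.0 this processor is picked up automatically; setting it
# explicitly just makes the choice visible.
pipe.unet.set_attn_processor(AttnProcessor2_0())

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")
```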
To benchmark:
The script used for benchmarking.
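The sketch below is not the actual script behind the reported numbers; it only illustrates what such a timing loop could look like (batch sizes, prompt, and warmup strategy are placeholders):

```python
import time

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

for batch_size in [1, 4, 8, 16]:  # illustrative batch sizes
    prompts = [prompt] * batch_size

    # Warmup run so one-time CUDA/kernel setup does not skew the timing.
    pipe(prompts, num_inference_steps=50)

    torch.cuda.synchronize()
    start = time.time()
    pipe(prompts, num_inference_steps=50)
    torch.cuda.synchronize()
    print(f"batch_size={batch_size}: {time.time() - start:.2f}s")
```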
To benchmark xFormers:
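For the xFormers side of the comparison, a minimal sketch (assumes `xformers` is installed alongside `torch==1.13.1`):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Requires `pip install xformers`; routes attention through xFormers'
# memory-efficient kernels instead of the default processor.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```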
WIP benchmark:
Benchmark is done using the following config:

- Model: `CompVis/stable-diffusion-v1-4`, steps=50, dtype=fp16
- The xFormers benchmark is done using `torch==1.13.1`
- The time reported is in seconds
(1) Batch size >= 32 requires `enable_vae_slicing()` because of pytorch/pytorch#81665. This is required for PyTorch 1.13.1, and also for PyTorch 2.0 with a batch size of 64.
(2) Got recompilation warnings from batch size 10 with `torch.compile`.
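A rough sketch covering notes (1) and (2), not code from the PR (model, prompt, and batch size are placeholders):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# (1) Decode latents one image at a time so large batches (>= 32 on
# PyTorch 1.13.1, 64 on 2.0) don't hit pytorch/pytorch#81665.
pipe.enable_vae_slicing()

# (2) torch.compile is available from PyTorch 2.0; changing the batch size
# between calls can trigger the recompilation warnings mentioned above.
pipe.unet = torch.compile(pipe.unet)

images = pipe(["a photo of an astronaut riding a horse"] * 32,
              num_inference_steps=50).images
```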