Conversation

@patil-suraj (Contributor) commented Feb 9, 2023

This PR adds support for the efficient attention from torch 2.0. Flash and memory-efficient attention (from xFormers) are now built into the latest torch nightlies and can be used without any extra dependency (c.f. https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention).
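
For reference, a minimal sketch of calling the fused op directly (the toy tensor shapes are my own assumption, not taken from this PR):

```python
import torch
import torch.nn.functional as F

# Toy tensors in the (batch, heads, seq_len, head_dim) layout the op expects.
q = torch.randn(2, 8, 64, 40, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 64, 40, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 64, 40, device="cuda", dtype=torch.float16)

# PyTorch dispatches to the fastest available backend
# (Flash, memory-efficient, or the math fallback).
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
```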

This PR adds the AttnProcessor2_0 class, which implements efficient attention via torch.nn.functional.scaled_dot_product_attention.
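
For intuition, here is a condensed sketch of what such a processor does (my own simplification; the real AttnProcessor2_0 also handles attention masks, norms, and dropout):

```python
import torch.nn.functional as F

class SDPAProcessorSketch:
    # `attn` is a diffusers cross-attention module exposing
    # to_q/to_k/to_v/to_out projections and a `heads` count.
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size = hidden_states.shape[0]
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        query, key, value = attn.to_q(hidden_states), attn.to_k(context), attn.to_v(context)

        head_dim = query.shape[-1] // attn.heads
        # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
        query, key, value = (
            t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            for t in (query, key, value)
        )

        # Fused attention: scaling and softmax happen inside the kernel.
        out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

        out = out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        return attn.to_out[0](out)  # output projection; dropout omitted in this sketch
```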

AttnProcessor2_0 will be used by default whenever torch 2.0 is installed and scaled_dot_product_attention is available.
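
The selection boils down to a feature test along these lines (a sketch of the idea; the exact condition in the diffusers code may differ):

```python
import torch.nn.functional as F

from diffusers.models.cross_attention import AttnProcessor2_0, CrossAttnProcessor

# Prefer the fused torch 2.0 path when the op exists, else fall back.
if hasattr(F, "scaled_dot_product_attention"):
    processor = AttnProcessor2_0()
else:
    processor = CrossAttnProcessor()
```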

To test:

Install CUDA 11.7 and the latest PyTorch nightlies, then run:

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.unet = torch.compile(pipe.unet)  # the first pipeline call triggers compilation and is slow

steps = 50
batch_size = 10
prompt = "A photo of an astronaut riding a horse on mars."
images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images

To benchmark:

  • Install CUDA 11.7
  • Install PyTorch 2.0 nightlies with
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu117
  • Install diffusers using
pip install git+https://github.com/huggingface/[email protected] transformers accelerate

The script used for benchmarking:

import torch
import torch.utils.benchmark as benchmark
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.cross_attention import CrossAttnProcessor, AttnProcessor2_0

def benchmark_torch_function(f, *args, **kwargs):
    # Time `f` with torch.utils.benchmark and return the mean runtime in seconds.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return round(t0.blocked_autorange(min_run_time=1).mean, 2)

# benchmark code
model_id = "CompVis/stable-diffusion-v1-4"
prompt = "A photo of an astronaut riding a horse on mars."
steps = 50
batch_size = 10
dtype = torch.float16

# load model
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, safety_checker=None).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.set_progress_bar_config(disable=True)
if batch_size >= 32:
    pipe.enable_vae_slicing()  # avoids OOM in the VAE decoder at large batch sizes

# Vanilla Cross Attention
print("Running benchmark for vanilla cross attention...")
pipe.unet.set_attn_processor(CrossAttnProcessor())
f = lambda: pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
time_vanilla = benchmark_torch_function(f)

# PyTorch sdpa
print("Running benchmark for PyTorch SDPA...")
pipe.unet.set_attn_processor(AttnProcessor2_0())
f = lambda: pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
time_sdpa = benchmark_torch_function(f)

# PyTorch sdpa with torch.compile
print("Running benchmark for PyTorch SDPA with torch.compile...")
pipe.unet = torch.compile(pipe.unet)
# warmup
pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
f = lambda: pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
time_sdpa_torch_compile = benchmark_torch_function(f)

# print results with nice formatting
print(f"Model: {model_id}, dtype: {dtype}, steps: {steps}, batch_size: {batch_size}")
print(f"Vanilla Cross Attention:         {time_vanilla} s")
print(f"PyTorch SDPA:                    {time_sdpa} s")
print(f"PyTorch SDPA with torch.compile: {time_sdpa_torch_compile} s")

To benchmark xFormers:

  • Install PyTorch 1.13.1 with
pip3 install torch torchvision
  • Install xFormers with
pip install xformers

Then run:
import torch
import torch.utils.benchmark as benchmark
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

def benchmark_torch_function(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return round(t0.blocked_autorange(min_run_time=1).mean, 2)

# benchmark code
model_id = "CompVis/stable-diffusion-v1-4"
prompt = "A photo of an astronaut riding a horse on mars."
steps = 50
batch_size = 16
dtype = torch.float16

# load model
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, safety_checker=None).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.set_progress_bar_config(disable=True)

# xFormers
print("Running benchmark for xFormers...")
pipe.enable_xformers_memory_efficient_attention()
f = lambda: pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
time_xformers = benchmark_torch_function(f)

# print results with nice formatting
print(f"Model: {model_id}, dtype: {dtype}, steps: {steps}, batch_size: {batch_size}")
print(f"xFormers: {time_xformers} s")

WIP benchmark:

The benchmark is done using the following config:

  • Model: CompVis/stable-diffusion-v1-4, steps=50, dtype=fp16

  • The xFormers benchmark is done using torch==1.13.1

The time reported is in seconds.

| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch 2.0 SDPA | SDPA + torch.compile | Speedup over xFormers (%) |
|---|---|---|---|---|---|---|
| A100 | 10 | 12.02 | 8.7 | 8.79 | 7.89 | 9.31 |
| A100 | 16 | 18.95 | 13.57 | 13.67 | 12.25 | 9.73 |
| A100 | 32 (1) | OOM | 26.56 | 26.68 | 24.08 | 9.34 |
| A100 | 64 (2) | | 52.51 | 53.03 | 47.81 | 8.95 |
| T4 | 4 | 38.81 | 30.09 | 29.74 | 27.55 | 8.44 |
| T4 | 8 | OOM | 55.71 | 55.99 | 53.85 | 3.34 |
| T4 | 10 | OOM | 68.96 | 69.86 | 65.35 | 5.23 |
| T4 | 16 | OOM | 111.47 | 113.26 | 106.93 | 4.07 |
| V100 | 4 | 9.84 | 8.16 | 8.09 | 7.65 | 6.25 |
| V100 | 8 | OOM | 15.62 | 15.44 | 14.59 | 6.59 |
| V100 | 10 | OOM | 19.52 | 19.28 | 18.18 | 6.86 |
| V100 | 16 | OOM | 30.29 | 29.84 | 28.22 | 6.83 |
| A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 |
| A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 |
| A10 | 10 | 33.69 | 23.53 | 24.19 | 22.52 | 4.29 |
| A10 | 16 | OOM | 37.55 | 38.31 | 36.81 | 1.97 |
| A10 | 32 (1) | | 77.19 | 78.43 | 76.64 | 0.71 |
| A10 | 64 (1) | | 173.59 | 158.99 | 155.14 | 10.63 |
| 3090 | 4 | 10.04 | 7.82 | 7.89 | 7.47 | 4.48 |
| 3090 | 8 | 19.27 | 14.97 | 15.04 | 14.22 | 5.01 |
| 3090 | 10 (2) | 24.08 | 18.7 | 18.7 | 17.69 | 5.40 |
| 3090 | 16 | OOM | 29.06 | 29.06 | 28.2 | 2.96 |
| 3090 | 32 (1) | | 58.05 | 58 | 54.88 | 5.46 |
| 3090 | 64 (1) | | 126.54 | 126.03 | 117.33 | 7.28 |
| 3090 Ti | 4 | 9.07 | 7.14 | 7.15 | 6.81 | 4.62 |
| 3090 Ti | 8 | 17.51 | 13.65 | 13.72 | 12.99 | 4.84 |
| 3090 Ti | 10 (2) | 21.79 | 16.85 | 16.93 | 16.02 | 4.93 |
| 3090 Ti | 16 | OOM | 26.1 | 26.28 | 25.46 | 2.45 |
| 3090 Ti | 32 (1) | | 51.78 | 52.04 | 49.15 | 5.08 |
| 3090 Ti | 64 (1) | | 112.02 | 112.33 | 103.91 | 7.24 |

(1) Batch size >= 32 requires enable_vae_slicing() because of pytorch/pytorch#81665. This is required for PyTorch 1.13.1, and also for PyTorch 2.0 with batch size 64.

(2) Got recompilation warnings from batch_size 10 with torch.compile:

[2023-02-10 13:02:03,854] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: 'forward' (/home/pedro/code/hf/diffusers/diffusers/src/diffusers/models/cross_attention.py:175)
reasons:  ___guarded_code.valid
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html
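
A possible mitigation (my assumption, not something verified in this PR) is to raise dynamo's recompilation budget before compiling:

```python
import torch._dynamo

# The default cache_size_limit is 64; raising it lets dynamo keep
# recompiling instead of bailing out, at the cost of extra compile time.
torch._dynamo.config.cache_size_limit = 128
```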

@HuggingFaceDocBuilderDev commented Feb 9, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor)

Very cool! Thanks a lot for working on this @patil-suraj. Nice to see all the results here :-)

When adding this to diffusers, can we make it the default attention processor class here:

processor = processor if processor is not None else CrossAttnProcessor()
if we detect PyTorch 2.0 is used? Or is there any functionality we currently have that is not yet supported by scaled_dot_product_attention?

@HamidShojanazeri

Thanks @patil-suraj, @patrickvonplaten for the quick turnaround. I wonder if the plan is to merge when PyTorch 2.0 is released, or can we merge it now and ask people to install the nightlies if they want to use it?

@patil-suraj (Contributor, Author)

@HamidShojanazeri the plan is to merge it by the end of this week :)

@patil-suraj (Contributor, Author) commented Feb 14, 2023

> Very cool! Thanks a lot for working on this @patil-suraj. Nice to see all the results here :-)
>
> When adding this to diffusers, can we make it the default attention processor class here:
>
> processor = processor if processor is not None else CrossAttnProcessor()
>
> if we detect PyTorch 2.0 is used? Or is there any functionality we currently have that is not yet supported by scaled_dot_product_attention?

In terms of functionality everything is supported, except that when an attention_mask is passed, sdpa will use the vanilla attention rather than the memory-efficient one. But we do the same for xFormers today: when using xFormers, the attention mask is not supported either. So I think we can definitely use it by default when available :)
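
A sketch of the fallback described above (my own illustration of the described behavior, not code from the PR; assumes an additive attention mask):

```python
import math
import torch
import torch.nn.functional as F

def attend(query, key, value, attention_mask=None):
    # Fused kernel when no mask is involved.
    if attention_mask is None:
        return F.scaled_dot_product_attention(query, key, value)
    # Unfused path: explicit scaled matmul + softmax with the mask added.
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
    scores = scores + attention_mask
    return torch.softmax(scores, dim=-1) @ value
```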

My only concern would be reproducibility and numerical accuracy. From the docs:

> Due to the nature of fusing floating point operations, the output of this function may be different depending on what backend kernel is chosen. The C++ implementation supports torch.float64 and can be used when higher precision is required. For more information please see Numerical accuracy

And:

> In some circumstances when given tensors on a CUDA device and using CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by setting torch.backends.cudnn.deterministic = True. See Reproducibility for more information.

Wonder if it might require us to change our tests slightly? cc @patrickvonplaten @williamberman @pcuenca
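
If a test does need bit-for-bit stability, one option (my sketch against the torch 2.0 API, not something discussed in this PR) is to pin SDPA to the deterministic math backend:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 64, 40, device="cuda", dtype=torch.float16)

# Disable the fused kernels so the reference "math" path is used;
# slower, but removes the backend-dependent numerics.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
    out = F.scaled_dot_product_attention(q, k, v)
```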

@patrickvonplaten (Contributor) left a comment

This looks very nice to me! Would be nice to also get some numbers for fp32 for completeness, and some quick training comparison stats.

patil-suraj marked this pull request as ready for review February 17, 2023 09:44
@patrickvonplaten (Contributor) left a comment

Cool, would be nice to add the perf table as well - the community will really like this, I think.

@patil-suraj (Contributor, Author) commented Feb 17, 2023

Added the benchmark table and changed the name to AttnProcessor2_0 :)

@pcuenca (Member) left a comment

Awesome!

patrickvonplaten merged commit 0c0bb08 into main Feb 17, 2023
patil-suraj deleted the torch2.0 branch February 17, 2023 14:18
mengfei25 pushed a commit to mengfei25/diffusers that referenced this pull request Mar 27, 2023
* add sdpa processor

* don't use it by default

* add some checks and style

* typo

* support torch sdpa in dreambooth example

* use torch attn proc by default when available

* typo

* add attn mask

* fix naming

* being doc

* doc

* Apply suggestions from code review

* polish

* torctree

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* better name

* style

* add benchamrk table

* Update docs/source/en/optimization/torch2.0.mdx

* up

* fix example

* check if processor is None

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

* add fp32 benchmakr

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024