Skip to content

add: train to text image with sdxl script. #4505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 39 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0b7fe23
add: train to text image with sdxl script.
sayakpaul Aug 7, 2023
fb14d0d
fix: partial func.
sayakpaul Aug 7, 2023
47ca92f
fix: default value of output_dir.
sayakpaul Aug 7, 2023
9edca86
make style
sayakpaul Aug 7, 2023
166bc1d
set num inference steps to 25.
sayakpaul Aug 7, 2023
f4030d0
remove mentions of LoRA.
sayakpaul Aug 7, 2023
2aa700d
up min version
sayakpaul Aug 7, 2023
ba674d4
add: ema cli arg
sayakpaul Aug 7, 2023
2d24cf7
run device placement while running step.
sayakpaul Aug 7, 2023
be43b12
Merge branch 'main' into feat/training-sdxl-text-to-image
patrickvonplaten Aug 7, 2023
3c94498
Merge branch 'main' into feat/training-sdxl-text-to-image
sayakpaul Aug 7, 2023
f9a2785
Merge branch 'main' into feat/training-sdxl-text-to-image
sayakpaul Aug 7, 2023
e36f1f0
Merge branch 'main' into feat/training-sdxl-text-to-image
sayakpaul Aug 9, 2023
3b1f4d6
Merge branch 'main' into feat/training-sdxl-text-to-image
sayakpaul Aug 9, 2023
cdb99a1
precompute vae encodings too.
sayakpaul Aug 9, 2023
2d12d18
fix
sayakpaul Aug 9, 2023
58f7c29
debug
sayakpaul Aug 9, 2023
97f2c75
should work now.
sayakpaul Aug 9, 2023
fda6c55
debug
sayakpaul Aug 9, 2023
f5e3bca
debug
sayakpaul Aug 9, 2023
1b9e15a
goes alright?
sayakpaul Aug 9, 2023
98cf98e
style
sayakpaul Aug 9, 2023
ad062d3
debugging
sayakpaul Aug 10, 2023
cde4d9d
debugging
sayakpaul Aug 10, 2023
a5045b9
debugging
sayakpaul Aug 10, 2023
ae96cb5
debugging
sayakpaul Aug 10, 2023
5a070cf
fix
sayakpaul Aug 10, 2023
bc0cd57
reinit scheduler if prediction_type was passed.
sayakpaul Aug 10, 2023
ff4bae2
akways cast vae in float32
sayakpaul Aug 10, 2023
41a2580
better handling of snr.
sayakpaul Aug 11, 2023
a4785bf
the vae should be also passed
sayakpaul Aug 11, 2023
1c17f3d
add: docs.
sayakpaul Aug 11, 2023
0347176
add: sdlx t2i tests
sayakpaul Aug 11, 2023
a2b7e8b
save the pipeline
sayakpaul Aug 11, 2023
6053c51
Merge branch 'main' into feat/training-sdxl-text-to-image
sayakpaul Aug 12, 2023
314523e
autocast.
sayakpaul Aug 12, 2023
273b36b
fix: save_model_card
sayakpaul Aug 12, 2023
1fcaf9d
Merge branch 'main' into feat/training-sdxl-text-to-image
sayakpaul Aug 13, 2023
d64ff81
fix: save_model_card.
sayakpaul Aug 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/training/dreambooth.md
Original file line number Diff line number Diff line change
Expand Up @@ -707,4 +707,4 @@ accelerate launch train_dreambooth.py \

## Stable Diffusion XL

We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md).
We support fine-tuning of the UNet and text encoders shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md).
6 changes: 6 additions & 0 deletions docs/source/en/training/text2image.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,9 @@ image.save("yoda-pokemon.png")
```
</jax>
</frameworkcontent>


## Stable Diffusion XL

* We support fine-tuning the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) via the `train_text_to_image_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
* We also support fine-tuning of the UNet and Text Encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
24 changes: 24 additions & 0 deletions examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,30 @@ def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_ch
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)

def test_text_to_image_sdxl(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/text_to_image/train_text_to_image_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
Expand Down
3 changes: 2 additions & 1 deletion examples/text_to_image/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -319,4 +319,5 @@ According to [this issue](https://github.com/huggingface/diffusers/issues/2234#i

## Stable Diffusion XL

We support fine-tuning of the UNet and Text Encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_xl.py` script. Please refer to the docs [here](./README_sdxl.md).
* We support fine-tuning the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) via the `train_text_to_image_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
* We also support fine-tuning of the UNet and Text Encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
78 changes: 66 additions & 12 deletions examples/text_to_image/README_sdxl.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
# LoRA training example for Stable Diffusion XL (SDXL)
# Stable Diffusion XL text-to-image fine-tuning

Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
The `train_text_to_image_sdxl.py` script shows how to fine-tune Stable Diffusion XL (SDXL) on your own dataset.

In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:

- Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter.

[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.

With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset
on consumer GPUs like Tesla T4, Tesla V100.
🚨 This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset. 🚨

## Running locally with PyTorch

Expand Down Expand Up @@ -57,6 +48,69 @@ When running `accelerate config`, if we specify torch compile mode to True there

### Training

```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_sdxl.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--pretrained_vae_model_name_or_path=$VAE \
--dataset_name=$DATASET_NAME \
--enable_xformers_memory_efficient_attention \
--resolution=512 --center_crop --random_flip \
--proportion_empty_prompts=0.2 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=10000 \
--use_8bit_adam \
--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
--mixed_precision="fp16" \
--report_to="wandb" \
--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
--checkpointing_steps=5000 \
--output_dir="sdxl-pokemon-model" \
--push_to_hub
```

**Notes**:

* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.
* The training command shown above performs intermediate quality validation in between the training epochs and logs the results to Weights and Biases. `--report_to`, `--validation_prompt`, and `--validation_epochs` are the relevant CLI arguments here.

### Inference

```python
from diffusers import DiffusionPipeline
import torch

model_path = "you-model-id-goes-here" # <-- change this
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

prompt = "A pokemon with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
```

## LoRA training example for Stable Diffusion XL (SDXL)

Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:

- Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter.

[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.

With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset
on consumer GPUs like Tesla T4, Tesla V100.

### Training

First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).

**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training to automatically log images.___**
Expand Down
2 changes: 1 addition & 1 deletion examples/text_to_image/train_text_to_image_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
"""Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA."""

import argparse
import itertools
Expand Down
Loading