Skip to content

Commit bdd961b

Browse files
committed
Control the validation steps separately from the checkpointing steps.
1 parent 21441f5 commit bdd961b

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

examples/text_to_image/train_text_to_image.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,14 @@ def parse_args():
297297
default=4,
298298
help="Number of images that should be generated during validation with `validation_prompt`.",
299299
)
300+
parser.add_argument(
301+
"--validation_steps",
302+
type=int,
303+
default=500,
304+
help=(
305+
"Sample a validation image every X updates."
306+
),
307+
)
300308

301309
args = parser.parse_args()
302310
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -710,6 +718,8 @@ def collate_fn(examples):
710718
accelerator.save_state(save_path)
711719
logger.info(f"Saved state to {save_path}")
712720

721+
if global_step % args.validation_steps == 0:
722+
if accelerator.is_main_process:
713723
if args.validation_prompt:
714724
pipeline = StableDiffusionPipeline.from_pretrained(
715725
args.pretrained_model_name_or_path,

0 commit comments

Comments
 (0)