File tree Expand file tree Collapse file tree 1 file changed +10
-0
lines changed
Expand file tree Collapse file tree 1 file changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -297,6 +297,14 @@ def parse_args():
297297 default = 4 ,
298298 help = "Number of images that should be generated during validation with `validation_prompt`." ,
299299 )
300+ parser .add_argument (
301+ "--validation_steps" ,
302+ type = int ,
303+ default = 500 ,
304+ help = (
305+ "Sample a validation image every X updates."
306+ ),
307+ )
300308
301309 args = parser .parse_args ()
302310 env_local_rank = int (os .environ .get ("LOCAL_RANK" , - 1 ))
@@ -710,6 +718,8 @@ def collate_fn(examples):
710718 accelerator .save_state (save_path )
711719 logger .info (f"Saved state to { save_path } " )
712720
721+ if global_step % args .validation_steps == 0 :
722+ if accelerator .is_main_process :
713723 if args .validation_prompt :
714724 pipeline = StableDiffusionPipeline .from_pretrained (
715725 args .pretrained_model_name_or_path ,
You can’t perform that action at this time.
0 commit comments