@@ -1131,6 +1131,7 @@ def main(args):
11311131 torch_dtype = torch .float16
11321132 elif args .prior_generation_precision == "bf16" :
11331133 torch_dtype = torch .bfloat16
1134+
11341135 pipeline = FluxPipeline .from_pretrained (
11351136 args .pretrained_model_name_or_path ,
11361137 torch_dtype = torch_dtype ,
@@ -1151,16 +1152,16 @@ def main(args):
11511152 for example in tqdm (
11521153 sample_dataloader , desc = "Generating class images" , disable = not accelerator .is_local_main_process
11531154 ):
1154- images = pipeline (example ["prompt" ]).images
1155+ with torch .autocast (device_type = accelerator .device .type , dtype = torch_dtype ):
1156+ images = pipeline (prompt = example ["prompt" ]).images
11551157
11561158 for i , image in enumerate (images ):
11571159 hash_image = insecure_hashlib .sha1 (image .tobytes ()).hexdigest ()
11581160 image_filename = class_images_dir / f"{ example ['index' ][i ] + cur_class_images } -{ hash_image } .jpg"
11591161 image .save (image_filename )
11601162
11611163 del pipeline
1162- if torch .cuda .is_available ():
1163- torch .cuda .empty_cache ()
1164+ free_memory ()
11641165
11651166 # Handle the repository creation
11661167 if accelerator .is_main_process :
@@ -1728,6 +1729,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17281729 device = accelerator .device ,
17291730 prompt = args .instance_prompt ,
17301731 )
1732+ else :
1733+ prompt_embeds , pooled_prompt_embeds , text_ids = compute_text_embeddings (
1734+ prompts , text_encoders , tokenizers
1735+ )
17311736
17321737 # Convert images to latent space
17331738 if args .cache_latents :
0 commit comments