@@ -86,7 +86,16 @@ def _shuffle(lis):
8686 return random .sample (lis , len (lis ))
8787
8888
89- def _get_cutout_holes (height , width , min_holes = 8 , max_holes = 32 , min_height = 16 , max_height = 128 , min_width = 16 , max_width = 128 ):
89+ def _get_cutout_holes (
90+ height ,
91+ width ,
92+ min_holes = 8 ,
93+ max_holes = 32 ,
94+ min_height = 16 ,
95+ max_height = 128 ,
96+ min_width = 16 ,
97+ max_width = 128 ,
98+ ):
9099 holes = []
91100 for _n in range (random .randint (min_holes , max_holes )):
92101 hole_height = random .randint (min_height , max_height )
@@ -103,12 +112,13 @@ def _generate_random_mask(image):
103112 mask = zeros_like (image [:1 ])
104113 holes = _get_cutout_holes (mask .shape [1 ], mask .shape [2 ])
105114 for (x1 , y1 , x2 , y2 ) in holes :
106- mask [:, y1 :y2 , x1 :x2 ] = 1.
115+ mask [:, y1 :y2 , x1 :x2 ] = 1.0
107116 if random .uniform (0 , 1 ) < 0.25 :
108- mask .fill_ (1. )
117+ mask .fill_ (1.0 )
109118 masked_image = image * (mask < 0.5 )
110119 return mask , masked_image
111120
121+
112122class PivotalTuningDatasetCapation (Dataset ):
113123 """
114124 A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
@@ -274,7 +284,10 @@ def __getitem__(self, index):
274284 example ["instance_images" ] = self .image_transforms (instance_image )
275285
276286 if self .train_inpainting :
277- example ["instance_masks" ], example ["instance_masked_images" ] = _generate_random_mask (example ["instance_images" ])
287+ (
288+ example ["instance_masks" ],
289+ example ["instance_masked_images" ],
290+ ) = _generate_random_mask (example ["instance_images" ])
278291
279292 if self .use_template :
280293 assert self .token_map is not None
@@ -296,7 +309,7 @@ def __getitem__(self, index):
296309 Image .open (self .mask_path [index % self .num_instance_images ])
297310 )
298311 * 0.5
299- + 0.5
312+ + 1.0
300313 )
301314
302315 if self .h_flip and random .random () > 0.5 :
@@ -321,7 +334,10 @@ def __getitem__(self, index):
321334 class_image = class_image .convert ("RGB" )
322335 example ["class_images" ] = self .image_transforms (class_image )
323336 if self .train_inpainting :
324- example ["class_masks" ], example ["class_masked_images" ] = _generate_random_mask (example ["class_images" ])
337+ (
338+ example ["class_masks" ],
339+ example ["class_masked_images" ],
340+ ) = _generate_random_mask (example ["class_images" ])
325341 example ["class_prompt_ids" ] = self .tokenizer (
326342 self .class_prompt ,
327343 padding = "do_not_pad" ,
0 commit comments