@@ -56,10 +56,13 @@ def __init__(
5656 self .pop_size = pop_size
5757 self .n_objs = n_objs
5858 device = torch .get_default_device () if device is None else device
59+
5960 # check
6061 assert lb .shape == ub .shape and lb .ndim == 1 and ub .ndim == 1
6162 assert lb .dtype == ub .dtype and lb .device == ub .device
63+
6264 self .dim = lb .size (0 )
65+
6366 # write to self
6467 self .lb = lb .unsqueeze (0 ).to (device = device )
6568 self .ub = ub .unsqueeze (0 ).to (device = device )
@@ -68,8 +71,6 @@ def __init__(
6871 self .fr = Parameter (fr )
6972 self .max_gen = Parameter (max_gen )
7073
71- self .rv_adapt_every = Mutable (torch .max (torch .round (1 / self .fr ), torch .tensor (1.0 )))
72-
7374 self .selection = selection_op
7475 self .mutation = mutation_op
7576 self .crossover = crossover_op
@@ -80,31 +81,38 @@ def __init__(
8081 self .mutation = polynomial_mutation
8182 if self .crossover is None :
8283 self .crossover = simulated_binary
83- sampling , _ = uniform_sampling (self .pop_size , self .n_objs )
8484
85+ sampling , _ = uniform_sampling (self .pop_size , self .n_objs )
8586 v = sampling .to (device = device )
8687
8788 v0 = v .clone ()
8889 self .pop_size = v .size (0 )
90+
8991 length = self .ub - self .lb
9092 population = torch .rand (self .pop_size , self .dim , device = device )
9193 population = length * population + self .lb
94+
9295 v1 = torch .rand (self .pop_size , self .n_objs , device = device )
9396 v = torch .cat ([v , v1 ], dim = 0 )
9497
9598 self .pop = Mutable (population )
96- self .fit = Mutable (torch .empty ((self .pop_size , self .n_objs ), device = device ).fill_ (torch .inf ))
97- self .reference_vector = Mutable (v )
98- self .init_v = v0
99- self .gen = Mutable (torch .tensor (0 , dtype = int , device = device ))
99+ self .fit = Mutable (torch .full ((self .pop_size , self .n_objs ), torch .inf , device = device ))
100+ self .reference_vector = Mutable (v .clone ())
101+ self .init_v = v0 .clone ()
102+
103+ self .gen = Mutable (torch .tensor (0 , dtype = torch .long , device = device ))
104+ self .rv_adapt_every = Mutable (torch .tensor (1 , dtype = torch .long , device = device ))
100105
101106 def init_step (self ):
102107 """
103108 Perform the initialization step of the workflow.
104109
105110 Calls the `init_step` of the algorithm if overwritten; otherwise, its `step` method will be invoked.
106111 """
107- self .rv_adapt_every = torch .max (torch .round (1 / self .fr ), torch .tensor (1.0 ))
112+ rv_adapt_every = torch .round (1.0 / self .fr ).to (device = self .pop .device )
113+ rv_adapt_every = torch .clamp (rv_adapt_every , min = 1 )
114+ self .rv_adapt_every = rv_adapt_every .to (dtype = torch .long )
115+
108116 self .fit = self .evaluate (self .pop )
109117
110118 def _rv_adaptation (self , pop_obj : torch .Tensor ):
@@ -118,13 +126,31 @@ def _no_rv_adaptation(self, pop_obj: torch.Tensor):
118126 def _mating_pool (self ):
119127 valid_mask = ~ torch .isnan (self .pop ).all (dim = 1 )
120128 num_valid = torch .sum (valid_mask , dtype = torch .int32 )
129+
121130 mating_pool = randint (0 , num_valid , (self .pop_size ,), device = self .pop .device )
122- sorted_indices = torch .where (valid_mask , torch .arange (self .pop .size (0 ), device = self .pop .device ), torch .iinfo (torch .int32 ).max )
131+
132+ sorted_indices = torch .where (
133+ valid_mask ,
134+ torch .arange (self .pop .size (0 ), device = self .pop .device ),
135+ torch .iinfo (torch .int32 ).max ,
136+ )
123137 sorted_indices = torch .argsort (sorted_indices , stable = True )
124138 pop = self .pop [sorted_indices [mating_pool ]]
125139 return pop
126140
127141 def _rv_regeneration (self , pop_obj : torch .Tensor , v : torch .Tensor ):
142+ valid_mask = ~ torch .isnan (pop_obj ).all (dim = 1 )
143+ valid_obj = pop_obj [valid_mask ]
144+
145+ if valid_obj .size (0 ) == 0 :
146+ return v .clone ()
147+
148+ rank = non_dominate_rank (valid_obj )
149+ pop_obj = valid_obj [rank == 0 ]
150+
151+ if pop_obj .size (0 ) == 0 :
152+ return v .clone ()
153+
128154 pop_obj = pop_obj - nanmin (pop_obj , dim = 0 ).values
129155 cosine = F .cosine_similarity (pop_obj .unsqueeze (1 ), v .unsqueeze (0 ), dim = - 1 )
130156
@@ -133,29 +159,37 @@ def _rv_regeneration(self, pop_obj: torch.Tensor, v: torch.Tensor):
133159 associate = input_tensor .max (dim = 1 , keepdim = False ).indices
134160 associate = torch .where (input_tensor [:, 0 ] == - torch .inf , - 1 , associate )
135161
136- invalid = torch .sum ((associate .unsqueeze (1 ) == torch .arange (v .size (0 ), device = pop_obj .device )), dim = 0 )
162+ invalid = torch .sum (
163+ associate .unsqueeze (1 ) == torch .arange (v .size (0 ), device = pop_obj .device ),
164+ dim = 0 ,
165+ )
137166 rand = torch .rand ((v .size (0 ), v .size (1 )), device = pop_obj .device ) * nanmax (pop_obj , dim = 0 ).values
138167 new_v = torch .where ((invalid == 0 ).unsqueeze (1 ), rand , v )
139168
140169 return new_v
141170
142171 def _batch_truncation (self , pop : torch .Tensor , obj : torch .Tensor ):
143- n = pop .size (0 ) // 2
144- cosine = F .cosine_similarity (obj .unsqueeze (1 ), obj .unsqueeze (0 ), dim = - 1 )
145- not_all_nan_rows = ~ torch .isnan (cosine ).all (dim = 1 )
146- mask = torch .eye (cosine .size (0 ), dtype = torch .bool , device = pop .device ) & not_all_nan_rows .unsqueeze (1 )
147- cosine = torch .where (mask , 0 , cosine )
172+ valid_mask = ~ torch .isnan (obj ).all (dim = 1 )
173+ valid_pop = pop [valid_mask ]
174+ valid_obj = obj [valid_mask ]
148175
149- sorted_values , _ = torch .sort (- cosine , dim = 1 )
150- sorted_values = torch .where (torch .isnan (sorted_values [:, 0 ]), - torch .inf , sorted_values [:, 0 ])
151- rank = torch .argsort (sorted_values )
176+ if valid_obj .size (0 ) == 0 :
177+ new_pop = torch .full_like (pop , torch .nan )
178+ new_obj = torch .full_like (obj , torch .nan )
179+ return new_pop , new_obj
152180
153- mask = torch .ones (rank .size (0 ), dtype = torch .bool , device = pop .device )
154- mask = torch .where (torch .arange (rank .size (0 ), device = pop .device ) < n , torch .tensor (0 , dtype = torch .bool , device = pop .device ), mask )
155- mask = mask .unsqueeze (1 )
181+ rank = non_dominate_rank (valid_obj )
182+ nd_mask = rank == 0
156183
157- new_pop = torch .where (mask , pop , torch .nan )
158- new_obj = torch .where (mask , obj , torch .nan )
184+ nd_pop = valid_pop [nd_mask ]
185+ nd_obj = valid_obj [nd_mask ]
186+
187+ new_pop = torch .full_like (pop , torch .nan )
188+ new_obj = torch .full_like (obj , torch .nan )
189+
190+ keep_n = min (nd_pop .size (0 ), pop .size (0 ))
191+ new_pop [:keep_n ] = nd_pop [:keep_n ]
192+ new_obj [:keep_n ] = nd_obj [:keep_n ]
159193
160194 return new_pop , new_obj
161195
@@ -164,43 +198,38 @@ def _no_batch_truncation(self, pop: torch.Tensor, obj: torch.Tensor):
164198
165199 def _update_pop_and_rv (self , survivor : torch .Tensor , survivor_fit : torch .Tensor ):
166200 v_regen = self ._rv_regeneration (survivor_fit , self .reference_vector [self .pop_size :])
167- if torch .compiler .is_compiling ():
168- v_adapt = torch .cond (
169- self .gen % self .rv_adapt_every == 0 , self ._rv_adaptation , self ._no_rv_adaptation , (survivor_fit ,)
170- )
171- self .pop , self .fit = torch .cond (self .gen == self .max_gen , self ._batch_truncation , self ._no_batch_truncation , (survivor , survivor_fit ))
201+
202+ if (self .gen % self .rv_adapt_every ) == 0 :
203+ v_adapt = self ._rv_adaptation (survivor_fit )
172204 else :
173- if self .gen % self .rv_adapt_every == 0 :
174- v_adapt = self ._rv_adaptation (survivor_fit )
175- else :
176- v_adapt = self ._no_rv_adaptation (survivor_fit )
177- if self .gen == self .max_gen :
178- self .pop , self .fit = self ._batch_truncation (survivor , survivor_fit )
179- else :
180- self .pop , self .fit = self ._no_batch_truncation (survivor , survivor_fit )
205+ v_adapt = self ._no_rv_adaptation (survivor_fit )
206+
207+ if self .gen == self .max_gen :
208+ self .pop , self .fit = self ._batch_truncation (survivor , survivor_fit )
209+ else :
210+ self .pop , self .fit = self ._no_batch_truncation (survivor , survivor_fit )
211+
181212 self .reference_vector = torch .cat ([v_adapt , v_regen ], dim = 0 )
182213
183214 def step (self ):
184215 """Perform a single optimization step."""
185216
186- self .gen = self .gen + 1
217+ self .gen = self .gen + torch .tensor (1 , dtype = self .gen .dtype , device = self .gen .device )
218+
187219 pop = self ._mating_pool ()
188220 crossovered = self .crossover (pop )
189221 offspring = self .mutation (crossovered , self .lb , self .ub )
190222 offspring = clamp (offspring , self .lb , self .ub )
191223 off_fit = self .evaluate (offspring )
224+
192225 merge_pop = torch .cat ([self .pop , offspring ], dim = 0 )
193226 merge_fit = torch .cat ([self .fit , off_fit ], dim = 0 )
194227
195- rank = non_dominate_rank (merge_fit )
196- merge_fit = torch .where (rank .unsqueeze (1 ) == 0 , merge_fit , torch .nan )
197- merge_pop = torch .where (rank .unsqueeze (1 ) == 0 , merge_pop , torch .nan )
198-
199228 survivor , survivor_fit = self .selection (
200229 merge_pop ,
201230 merge_fit ,
202231 self .reference_vector ,
203232 (self .gen / self .max_gen ) ** self .alpha ,
204233 )
205234
206- self ._update_pop_and_rv (survivor , survivor_fit )
235+ self ._update_pop_and_rv (survivor , survivor_fit )
0 commit comments