Skip to content

Commit f35d893

Browse files
committed
Fix the RVEA and RVEAa implementation.
Avoid the use of torch.cond in the RVEA and RVEAa implementations. torch.cond does not support pytree outputs, which would force us to stack the outputs into a single tensor; replacing torch.cond with explicit branching/torch.where works around this limitation.
1 parent 59408dc commit f35d893

3 files changed

Lines changed: 101 additions & 63 deletions

File tree

src/evox/algorithms/mo/rvea.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class RVEA(Algorithm):
2323
[2] Z. Liang, T. Jiang, K. Sun, and R. Cheng, "GPU-accelerated Evolutionary Multiobjective Optimization
2424
Using Tensorized RVEA," in Proceedings of the Genetic and Evolutionary Computation Conference,
2525
ser. GECCO ’24, 2024, pp. 566–575. Available: https://doi.org/10.1145/3638529.3654223
26-
"""
26+
"""
2727

2828
def __init__(
2929
self,
@@ -69,8 +69,6 @@ def __init__(
6969
self.fr = Parameter(fr)
7070
self.max_gen = Parameter(max_gen)
7171

72-
self.rv_adapt_every = Mutable(torch.max(torch.round(1 / self.fr), torch.tensor(1.0)))
73-
7472
self.selection = selection_op
7573
self.mutation = mutation_op
7674
self.crossover = crossover_op
@@ -82,44 +80,52 @@ def __init__(
8280
self.mutation = polynomial_mutation
8381
if self.crossover is None:
8482
self.crossover = simulated_binary
85-
sampling, _ = uniform_sampling(self.pop_size, self.n_objs)
8683

84+
sampling, _ = uniform_sampling(self.pop_size, self.n_objs)
8785
v = sampling.to(device=device)
8886

89-
v0 = v
87+
self.init_v = v.clone()
9088
self.pop_size = v.size(0)
91-
length = ub - lb
89+
90+
length = self.ub - self.lb
9291
population = torch.rand(self.pop_size, self.dim, device=device)
93-
population = length * population + lb
92+
population = length * population + self.lb
9493

9594
self.pop = Mutable(population)
9695
self.fit = Mutable(torch.full((self.pop_size, self.n_objs), torch.inf, device=device))
97-
self.reference_vector = Mutable(v)
98-
self.init_v = v0
99-
self.gen = Mutable(torch.tensor(0, dtype=int, device=device))
96+
self.reference_vector = Mutable(v.clone())
97+
98+
self.gen = Mutable(torch.tensor(0, dtype=torch.long, device=device))
99+
self.rv_adapt_every = Mutable(torch.tensor(1, dtype=torch.long, device=device))
100100

101101
def init_step(self):
102102
"""
103103
Perform the initialization step of the workflow.
104104
105105
Calls the `init_step` of the algorithm if overwritten; otherwise, its `step` method will be invoked.
106106
"""
107-
self.rv_adapt_every = torch.max(torch.round(1 / self.fr), torch.tensor(1.0))
107+
rv_adapt_every = torch.round(1.0 / self.fr).to(device=self.device)
108+
rv_adapt_every = torch.clamp(rv_adapt_every, min=1)
109+
self.rv_adapt_every = rv_adapt_every.to(dtype=torch.long)
110+
108111
self.fit = self.evaluate(self.pop)
109112

110113
def _rv_adaptation(self, pop_obj: torch.Tensor):
111114
max_vals = nanmax(pop_obj, dim=0)[0]
112115
min_vals = nanmin(pop_obj, dim=0)[0]
113116
return self.init_v * (max_vals - min_vals)
114117

115-
def _no_rv_adaptation(self, pop_obj: torch.Tensor):
116-
return self.reference_vector.clone()
117-
118118
def _mating_pool(self):
119119
valid_mask = ~torch.isnan(self.pop).all(dim=1)
120120
num_valid = torch.sum(valid_mask, dtype=torch.int32)
121+
121122
mating_pool = randint(0, num_valid, (self.pop_size,), device=self.pop.device)
122-
sorted_indices = torch.where(valid_mask, torch.arange(self.pop_size, device=self.device), torch.iinfo(torch.int32).max)
123+
124+
sorted_indices = torch.where(
125+
valid_mask,
126+
torch.arange(self.pop_size, device=self.device),
127+
torch.iinfo(torch.int32).max,
128+
)
123129
sorted_indices = torch.argsort(sorted_indices, stable=True)
124130
pop = self.pop[sorted_indices[mating_pool]]
125131
return pop
@@ -128,14 +134,17 @@ def _update_pop_and_rv(self, survivor: torch.Tensor, survivor_fit: torch.Tensor)
128134
self.pop = survivor
129135
self.fit = survivor_fit
130136

131-
self.reference_vector = torch.cond(
132-
self.gen % self.rv_adapt_every == 0, self._rv_adaptation, self._no_rv_adaptation, (survivor_fit,)
133-
)
137+
adapted_rv = self._rv_adaptation(survivor_fit)
138+
keep_rv = self.reference_vector.clone()
139+
140+
adapt_flag = (self.gen % self.rv_adapt_every) == 0
141+
self.reference_vector = torch.where(adapt_flag, adapted_rv, keep_rv)
134142

135143
def step(self):
136144
"""Perform a single optimization step."""
137145

138-
self.gen = self.gen + torch.tensor(1)
146+
self.gen = self.gen + torch.tensor(1, dtype=self.gen.dtype, device=self.device)
147+
139148
pop = self._mating_pool()
140149
crossovered = self.crossover(pop)
141150
offspring = self.mutation(crossovered, self.lb, self.ub)
@@ -151,4 +160,4 @@ def step(self):
151160
(self.gen / self.max_gen) ** self.alpha,
152161
)
153162

154-
self._update_pop_and_rv(survivor, survivor_fit)
163+
self._update_pop_and_rv(survivor, survivor_fit)

src/evox/algorithms/mo/rveaa.py

Lines changed: 71 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,13 @@ def __init__(
5656
self.pop_size = pop_size
5757
self.n_objs = n_objs
5858
device = torch.get_default_device() if device is None else device
59+
5960
# check
6061
assert lb.shape == ub.shape and lb.ndim == 1 and ub.ndim == 1
6162
assert lb.dtype == ub.dtype and lb.device == ub.device
63+
6264
self.dim = lb.size(0)
65+
6366
# write to self
6467
self.lb = lb.unsqueeze(0).to(device=device)
6568
self.ub = ub.unsqueeze(0).to(device=device)
@@ -68,8 +71,6 @@ def __init__(
6871
self.fr = Parameter(fr)
6972
self.max_gen = Parameter(max_gen)
7073

71-
self.rv_adapt_every = Mutable(torch.max(torch.round(1 / self.fr), torch.tensor(1.0)))
72-
7374
self.selection = selection_op
7475
self.mutation = mutation_op
7576
self.crossover = crossover_op
@@ -80,31 +81,38 @@ def __init__(
8081
self.mutation = polynomial_mutation
8182
if self.crossover is None:
8283
self.crossover = simulated_binary
83-
sampling, _ = uniform_sampling(self.pop_size, self.n_objs)
8484

85+
sampling, _ = uniform_sampling(self.pop_size, self.n_objs)
8586
v = sampling.to(device=device)
8687

8788
v0 = v.clone()
8889
self.pop_size = v.size(0)
90+
8991
length = self.ub - self.lb
9092
population = torch.rand(self.pop_size, self.dim, device=device)
9193
population = length * population + self.lb
94+
9295
v1 = torch.rand(self.pop_size, self.n_objs, device=device)
9396
v = torch.cat([v, v1], dim=0)
9497

9598
self.pop = Mutable(population)
96-
self.fit = Mutable(torch.empty((self.pop_size, self.n_objs), device=device).fill_(torch.inf))
97-
self.reference_vector = Mutable(v)
98-
self.init_v = v0
99-
self.gen = Mutable(torch.tensor(0, dtype=int, device=device))
99+
self.fit = Mutable(torch.full((self.pop_size, self.n_objs), torch.inf, device=device))
100+
self.reference_vector = Mutable(v.clone())
101+
self.init_v = v0.clone()
102+
103+
self.gen = Mutable(torch.tensor(0, dtype=torch.long, device=device))
104+
self.rv_adapt_every = Mutable(torch.tensor(1, dtype=torch.long, device=device))
100105

101106
def init_step(self):
102107
"""
103108
Perform the initialization step of the workflow.
104109
105110
Calls the `init_step` of the algorithm if overwritten; otherwise, its `step` method will be invoked.
106111
"""
107-
self.rv_adapt_every = torch.max(torch.round(1 / self.fr), torch.tensor(1.0))
112+
rv_adapt_every = torch.round(1.0 / self.fr).to(device=self.pop.device)
113+
rv_adapt_every = torch.clamp(rv_adapt_every, min=1)
114+
self.rv_adapt_every = rv_adapt_every.to(dtype=torch.long)
115+
108116
self.fit = self.evaluate(self.pop)
109117

110118
def _rv_adaptation(self, pop_obj: torch.Tensor):
@@ -118,13 +126,31 @@ def _no_rv_adaptation(self, pop_obj: torch.Tensor):
118126
def _mating_pool(self):
119127
valid_mask = ~torch.isnan(self.pop).all(dim=1)
120128
num_valid = torch.sum(valid_mask, dtype=torch.int32)
129+
121130
mating_pool = randint(0, num_valid, (self.pop_size,), device=self.pop.device)
122-
sorted_indices = torch.where(valid_mask, torch.arange(self.pop.size(0), device=self.pop.device), torch.iinfo(torch.int32).max)
131+
132+
sorted_indices = torch.where(
133+
valid_mask,
134+
torch.arange(self.pop.size(0), device=self.pop.device),
135+
torch.iinfo(torch.int32).max,
136+
)
123137
sorted_indices = torch.argsort(sorted_indices, stable=True)
124138
pop = self.pop[sorted_indices[mating_pool]]
125139
return pop
126140

127141
def _rv_regeneration(self, pop_obj: torch.Tensor, v: torch.Tensor):
142+
valid_mask = ~torch.isnan(pop_obj).all(dim=1)
143+
valid_obj = pop_obj[valid_mask]
144+
145+
if valid_obj.size(0) == 0:
146+
return v.clone()
147+
148+
rank = non_dominate_rank(valid_obj)
149+
pop_obj = valid_obj[rank == 0]
150+
151+
if pop_obj.size(0) == 0:
152+
return v.clone()
153+
128154
pop_obj = pop_obj - nanmin(pop_obj, dim=0).values
129155
cosine = F.cosine_similarity(pop_obj.unsqueeze(1), v.unsqueeze(0), dim=-1)
130156

@@ -133,29 +159,37 @@ def _rv_regeneration(self, pop_obj: torch.Tensor, v: torch.Tensor):
133159
associate = input_tensor.max(dim=1, keepdim=False).indices
134160
associate = torch.where(input_tensor[:, 0] == -torch.inf, -1, associate)
135161

136-
invalid = torch.sum((associate.unsqueeze(1) == torch.arange(v.size(0), device=pop_obj.device)), dim=0)
162+
invalid = torch.sum(
163+
associate.unsqueeze(1) == torch.arange(v.size(0), device=pop_obj.device),
164+
dim=0,
165+
)
137166
rand = torch.rand((v.size(0), v.size(1)), device=pop_obj.device) * nanmax(pop_obj, dim=0).values
138167
new_v = torch.where((invalid == 0).unsqueeze(1), rand, v)
139168

140169
return new_v
141170

142171
def _batch_truncation(self, pop: torch.Tensor, obj: torch.Tensor):
143-
n = pop.size(0) // 2
144-
cosine = F.cosine_similarity(obj.unsqueeze(1), obj.unsqueeze(0), dim=-1)
145-
not_all_nan_rows = ~torch.isnan(cosine).all(dim=1)
146-
mask = torch.eye(cosine.size(0), dtype=torch.bool, device=pop.device) & not_all_nan_rows.unsqueeze(1)
147-
cosine = torch.where(mask, 0, cosine)
172+
valid_mask = ~torch.isnan(obj).all(dim=1)
173+
valid_pop = pop[valid_mask]
174+
valid_obj = obj[valid_mask]
148175

149-
sorted_values, _ = torch.sort(-cosine, dim=1)
150-
sorted_values = torch.where(torch.isnan(sorted_values[:, 0]), -torch.inf, sorted_values[:, 0])
151-
rank = torch.argsort(sorted_values)
176+
if valid_obj.size(0) == 0:
177+
new_pop = torch.full_like(pop, torch.nan)
178+
new_obj = torch.full_like(obj, torch.nan)
179+
return new_pop, new_obj
152180

153-
mask = torch.ones(rank.size(0), dtype=torch.bool, device=pop.device)
154-
mask = torch.where(torch.arange(rank.size(0), device=pop.device) < n, torch.tensor(0, dtype=torch.bool, device=pop.device), mask)
155-
mask = mask.unsqueeze(1)
181+
rank = non_dominate_rank(valid_obj)
182+
nd_mask = rank == 0
156183

157-
new_pop = torch.where(mask, pop, torch.nan)
158-
new_obj = torch.where(mask, obj, torch.nan)
184+
nd_pop = valid_pop[nd_mask]
185+
nd_obj = valid_obj[nd_mask]
186+
187+
new_pop = torch.full_like(pop, torch.nan)
188+
new_obj = torch.full_like(obj, torch.nan)
189+
190+
keep_n = min(nd_pop.size(0), pop.size(0))
191+
new_pop[:keep_n] = nd_pop[:keep_n]
192+
new_obj[:keep_n] = nd_obj[:keep_n]
159193

160194
return new_pop, new_obj
161195

@@ -164,43 +198,38 @@ def _no_batch_truncation(self, pop: torch.Tensor, obj: torch.Tensor):
164198

165199
def _update_pop_and_rv(self, survivor: torch.Tensor, survivor_fit: torch.Tensor):
166200
v_regen = self._rv_regeneration(survivor_fit, self.reference_vector[self.pop_size :])
167-
if torch.compiler.is_compiling():
168-
v_adapt = torch.cond(
169-
self.gen % self.rv_adapt_every == 0, self._rv_adaptation, self._no_rv_adaptation, (survivor_fit,)
170-
)
171-
self.pop, self.fit = torch.cond(self.gen == self.max_gen, self._batch_truncation, self._no_batch_truncation, (survivor, survivor_fit))
201+
202+
if (self.gen % self.rv_adapt_every) == 0:
203+
v_adapt = self._rv_adaptation(survivor_fit)
172204
else:
173-
if self.gen % self.rv_adapt_every == 0:
174-
v_adapt = self._rv_adaptation(survivor_fit)
175-
else:
176-
v_adapt = self._no_rv_adaptation(survivor_fit)
177-
if self.gen == self.max_gen:
178-
self.pop, self.fit = self._batch_truncation(survivor, survivor_fit)
179-
else:
180-
self.pop, self.fit = self._no_batch_truncation(survivor, survivor_fit)
205+
v_adapt = self._no_rv_adaptation(survivor_fit)
206+
207+
if self.gen == self.max_gen:
208+
self.pop, self.fit = self._batch_truncation(survivor, survivor_fit)
209+
else:
210+
self.pop, self.fit = self._no_batch_truncation(survivor, survivor_fit)
211+
181212
self.reference_vector = torch.cat([v_adapt, v_regen], dim=0)
182213

183214
def step(self):
184215
"""Perform a single optimization step."""
185216

186-
self.gen = self.gen + 1
217+
self.gen = self.gen + torch.tensor(1, dtype=self.gen.dtype, device=self.gen.device)
218+
187219
pop = self._mating_pool()
188220
crossovered = self.crossover(pop)
189221
offspring = self.mutation(crossovered, self.lb, self.ub)
190222
offspring = clamp(offspring, self.lb, self.ub)
191223
off_fit = self.evaluate(offspring)
224+
192225
merge_pop = torch.cat([self.pop, offspring], dim=0)
193226
merge_fit = torch.cat([self.fit, off_fit], dim=0)
194227

195-
rank = non_dominate_rank(merge_fit)
196-
merge_fit = torch.where(rank.unsqueeze(1) == 0, merge_fit, torch.nan)
197-
merge_pop = torch.where(rank.unsqueeze(1) == 0, merge_pop, torch.nan)
198-
199228
survivor, survivor_fit = self.selection(
200229
merge_pop,
201230
merge_fit,
202231
self.reference_vector,
203232
(self.gen / self.max_gen) ** self.alpha,
204233
)
205234

206-
self._update_pop_and_rv(survivor, survivor_fit)
235+
self._update_pop_and_rv(survivor, survivor_fit)

src/evox/operators/selection/rvea_selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def apd_fn(
2323
:return: A tensor containing the APD values for each solution.
2424
"""
2525
selected_z = torch.gather(z, 0, torch.relu(x))
26-
left = (1 + obj.size(1) * theta * selected_z) / y[None, :]
26+
left = 1 + obj.size(1) * theta * selected_z / y[None, :]
2727
norm_obj = torch.linalg.vector_norm(obj, dim=1)
2828
right = norm_obj[x]
2929
return left * right

0 commit comments

Comments
 (0)