@@ -125,8 +125,8 @@ def _update_covariance_matrix(
125125 y = (population - old_mean ) / self .sigma
126126 update = (
127127 (1 - self .c_1 - self .c_mu ) * C
128- + self .c_1 * (p_c @ p_c . T + (1 - h_sigma ) * self .c_c * (2 - self .c_c ) * C )
129- + self .c_mu * (y .T * self .weights ) @ y
128+ + self .c_1 * (p_c . dot ( p_c ) + (1 - h_sigma ) * self .c_c * (2 - self .c_c ) * C )
129+ + self .c_mu * (y .mT * self .weights ) @ y
130130 )
131131 return update
132132
@@ -161,20 +161,22 @@ def _conditional_decomposition(self, iteration: torch.Tensor, C: torch.Tensor):
161161 return B , D , C_invsqrt
162162
def _no_decomposition(
    self,
    C: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the cached decomposition without recomputing it.

    Args:
        C: Covariance matrix. Unused here; accepted so the signature
            matches ``_decomposition``.

    Returns:
        A tuple ``(B, D, C_invsqrt)`` holding copies of ``self.B``,
        ``self.D`` and ``self.C_invsqrt``.
    """
    # Clone so callers receive fresh tensors rather than aliases of the
    # stored state.
    return self.B.clone(), self.D.clone(), self.C_invsqrt.clone()
165165
def _decomposition(
    self,
    C: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Eigendecompose the (symmetrized) covariance matrix ``C``.

    Args:
        C: Square covariance matrix. It is not modified.

    Returns:
        A tuple ``(B_t, D, C_invsqrt)`` where, with ``B`` the eigenvectors
        and ``lam`` the eigenvalues (clamped to at least ``1e-8``) of
        ``(C + C^T) / 2``:

        * ``B_t`` is ``B`` transposed,
        * ``D`` is ``B @ diag(sqrt(lam))``,
        * ``C_invsqrt`` is ``B @ diag(1 / sqrt(lam)) @ B^T``.
    """
    # Work on a copy: per the original author's note, symmetrizing ``C``
    # directly fails to compile because PyTorch tries to modify ``C``
    # in place.
    sym_c = C.clone()
    sym_c = (sym_c + sym_c.mT) / 2

    eigvals, B = torch.linalg.eigh(sym_c)
    # Clamp so the square root and reciprocal below stay finite.
    eigvals = torch.clamp(eigvals, min=1e-8)
    sqrt_eigvals = torch.sqrt(eigvals)

    # Scaling the columns of B by a vector equals B @ diag(vector) but
    # avoids building the dense diagonal matrix.
    C_invsqrt = (B * (1.0 / sqrt_eigvals)) @ B.mT
    D = B * sqrt_eigvals
    return B.mT, D, C_invsqrt
178180
179181 def record_step (self ):
180182 return {
0 commit comments