@@ -191,6 +191,23 @@ function _sort_and_perm(values::SectorVector; by = identity, rev::Bool = false)
191191 return values_sorted, perms
192192end
193193
194+ function _findtruncvalue_order (values:: SectorVector , n:: Int ; by = identity, rev:: Bool = false )
195+ I = sectortype (values)
196+ p = sortperm (parent (values); by, rev)
197+
198+ if FusionStyle (I) isa UniqueFusion # dimensions are all 1
199+ return n <= 0 ? nothing : p[min (n, length (p))]
200+ else
201+ dims = similar (values, Base. promote_op (dim, I))
202+ for (c, v) in pairs (dims)
203+ fill! (v, dim (c))
204+ end
205+ cumulative_dim = cumsum (Base. permute! (parent (dims), p))
206+ k = findlast (<= (n), cumulative_dim)
207+ return isnothing (k) ? k : p[k]
208+ end
209+ end
210+
194211# findtruncated
195212# -------------
196213# Generic fallback
@@ -202,25 +219,25 @@ function MAK.findtruncated(values::SectorVector, ::NoTruncation)
202219 return SectorDict (c => Colon () for c in keys (values))
203220end
204221
222+ # TruncationByOrder strategy:
223+ # - find the howmany'th value of the input sorted according to the strategy
224+ # - discard everything that is ordered after that value
225+
205226function MAK. findtruncated (values:: SectorVector , strategy:: TruncationByOrder )
206- values_sorted, perms = _sort_and_perm (values; strategy. by, strategy. rev)
207- inds = MAK. findtruncated_svd (values_sorted, truncrank (strategy. howmany))
208- return SectorDict (c => perms[c][I] for (c, I) in inds)
209- end
210- function MAK. findtruncated_svd (values:: SectorVector , strategy:: TruncationByOrder )
211- I = keytype (values)
212- truncdim = SectorDict {I, Int} (c => length (d) for (c, d) in pairs (values))
213- totaldim = sum (dim (c) * d for (c, d) in truncdim; init = 0 )
214- while totaldim > strategy. howmany
215- next = _findnexttruncvalue (values, truncdim; strategy. by, strategy. rev)
216- isnothing (next) && break
217- _, cmin = next
218- truncdim[cmin] -= 1
219- totaldim -= dim (cmin)
220- truncdim[cmin] == 0 && delete! (truncdim, cmin)
227+ k = _findtruncvalue_order (values, strategy. howmany; strategy. by, strategy. rev)
228+
229+ if isnothing (k)
230+ # discard everything
231+ return SectorDict {sectortype(values), UnitRange{Int}} ()
232+ else
233+ val = strategy. by (values[k])
234+ strategy = trunctol (; atol = val, strategy. by, keep_below = ! strategy. rev)
235+ return MAK. findtruncated_svd (values, strategy)
221236 end
222- return SectorDict (c => Base. OneTo (d) for (c, d) in truncdim)
223237end
238+ # disambiguate
239+ MAK. findtruncated_svd (values:: SectorVector , strategy:: TruncationByOrder ) =
240+ MAK. findtruncated (values, strategy)
224241
225242function MAK. findtruncated (values:: SectorVector , strategy:: TruncationByFilter )
226243 return SectorDict (c => findall (strategy. filter, d) for (c, d) in pairs (values))
0 commit comments