Skip to content

Commit dd38bfb

Browse files
committed
try to make truncation GPU-friendly
1 parent fac25d3 commit dd38bfb

File tree

1 file changed

+33
-16
lines changed

1 file changed

+33
-16
lines changed

src/factorizations/truncation.jl

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,23 @@ function _sort_and_perm(values::SectorVector; by = identity, rev::Bool = false)
191191
return values_sorted, perms
192192
end
193193

194+
function _findtruncvalue_order(values::SectorVector, n::Int; by = identity, rev::Bool = false)
195+
I = sectortype(values)
196+
p = sortperm(parent(values); by, rev)
197+
198+
if FusionStyle(I) isa UniqueFusion # dimensions are all 1
199+
return n <= 0 ? nothing : p[min(n, length(p))]
200+
else
201+
dims = similar(values, Base.promote_op(dim, I))
202+
for (c, v) in pairs(dims)
203+
fill!(v, dim(c))
204+
end
205+
cumulative_dim = cumsum(Base.permute!(parent(dims), p))
206+
k = findlast(<=(n), cumulative_dim)
207+
return isnothing(k) ? k : p[k]
208+
end
209+
end
210+
194211
# findtruncated
195212
# -------------
196213
# Generic fallback
@@ -202,25 +219,25 @@ function MAK.findtruncated(values::SectorVector, ::NoTruncation)
202219
return SectorDict(c => Colon() for c in keys(values))
203220
end
204221

222+
# TruncationByOrder strategy:
223+
# - find the howmany'th value of the input sorted according to the strategy
224+
# - discard everything that is ordered after that value
225+
205226
function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder)
206-
values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev)
207-
inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany))
208-
return SectorDict(c => perms[c][I] for (c, I) in inds)
209-
end
210-
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder)
211-
I = keytype(values)
212-
truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in pairs(values))
213-
totaldim = sum(dim(c) * d for (c, d) in truncdim; init = 0)
214-
while totaldim > strategy.howmany
215-
next = _findnexttruncvalue(values, truncdim; strategy.by, strategy.rev)
216-
isnothing(next) && break
217-
_, cmin = next
218-
truncdim[cmin] -= 1
219-
totaldim -= dim(cmin)
220-
truncdim[cmin] == 0 && delete!(truncdim, cmin)
227+
k = _findtruncvalue_order(values, strategy.howmany; strategy.by, strategy.rev)
228+
229+
if isnothing(k)
230+
# discard everything
231+
return SectorDict{sectortype(values), UnitRange{Int}}()
232+
else
233+
val = strategy.by(values[k])
234+
strategy = trunctol(; atol = val, strategy.by, keep_below = !strategy.rev)
235+
return MAK.findtruncated_svd(values, strategy)
221236
end
222-
return SectorDict(c => Base.OneTo(d) for (c, d) in truncdim)
223237
end
238+
# disambiguate
239+
MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) =
240+
MAK.findtruncated(values, strategy)
224241

225242
function MAK.findtruncated(values::SectorVector, strategy::TruncationByFilter)
226243
return SectorDict(c => findall(strategy.filter, d) for (c, d) in pairs(values))

0 commit comments

Comments
 (0)