Skip to content

Commit 91bf8eb

Browse files
committed
Add multivariate rational interpolation
1 parent 46c3dd4 commit 91bf8eb

File tree

12 files changed

+813
-0
lines changed

12 files changed

+813
-0
lines changed

src/AlgebraicSolving.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ include("algorithms/param-curve.jl")
1818
include("algorithms/hilbert.jl")
1919
#= siggb =#
2020
include("siggb/siggb.jl")
21+
#= progress =#
22+
include("progress/main.jl")
23+
#= interp =#
24+
include("interp/main.jl")
2125
#= examples =#
2226
include("examples/katsura.jl")
2327
include("examples/cyclic.jl")

src/interp/cuyt_lee.jl

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
mutable struct Atomic{T}
2+
@atomic x::T
3+
end
4+
5+
struct CuytLeeError <: Exception
6+
msg::String
7+
end
8+
9+
Base.show(io::IO, e::CuytLeeError) = print(io, "CuytLeeError: ", e.msg)
10+
11+
@doc Markdown.doc"""
12+
_random_point(n::Int)
13+
14+
Generates a random point in $\mathbb{Z}^n$ with coordinates between 1 and 99.
15+
16+
**Note**: This is an internal function.
17+
"""
18+
19+
function _random_point(n::Int)::Vector{Int}
20+
map(a -> 1 + abs(a) % 99, rand(Int, n))
21+
end
22+
23+
@doc Markdown.doc"""
24+
_estimate_total_degree(R::QQMPolyRing, bb::Function; samples=5, show_progress=false)
25+
26+
Estimate the total degree of the rational function corresponding to the black box function `bb` by evaluating it at random points and performing univariate Thiele interpolation.
27+
28+
**Note**: This is an internal function.
29+
"""
30+
31+
function _estimate_total_degree(
32+
R::QQMPolyRing,
33+
bb::Function;
34+
samples::Int=5,
35+
show_progress::Bool=false
36+
)::Tuple{Int,Int}
37+
t = length(gens(R))
38+
@assert t > 0
39+
R_z, _ = polynomial_ring(QQ, :z)
40+
total_degree_counts = Dict{Tuple{Int,Int},Int}()
41+
total = 0
42+
while total < samples
43+
x = _random_point(t)
44+
try
45+
f = thiele(R_z, k -> bb(k .* x); show_progress=show_progress, offset=1)
46+
total_degree = (degree(denominator(f)), degree(numerator(f)))
47+
total_degree_counts[total_degree] = get(total_degree_counts, total_degree, 0) + 1
48+
total += 1
49+
if findmax(total_degree_counts)[1] > samples ÷ 2
50+
break
51+
end
52+
catch e
53+
if isa(e, ThieleError)
54+
continue
55+
else
56+
rethrow(e)
57+
end
58+
end
59+
end
60+
return findmax(total_degree_counts)[2]
61+
end
62+
63+
@doc Markdown.doc"""
64+
_homogenize(f::QQMPolyRingElem, d::Int)
65+
66+
Homogenize the given polynomial `f` to total degree `d` by multiplying each term by the appropriate power of the first variable.
67+
68+
**Note**: This is an internal function.
69+
"""
70+
71+
function _homogenize(
72+
f::QQMPolyRingElem,
73+
d::Int
74+
)::QQMPolyRingElem
75+
R = parent(f)
76+
C = MPolyBuildCtx(R)
77+
for i in 1:f.length
78+
exp = exponent_vector(f, i)
79+
@assert exp[1] == 0
80+
total_deg = sum(exp)
81+
exp[1] += d - total_deg
82+
@assert exp[1] >= 0
83+
push_term!(C, coeff(f, i), exp)
84+
end
85+
return finish(C)
86+
end
87+
88+
@doc Markdown.doc"""
89+
cuyt_lee_shifted(R::QQMPolyRing, bb::Function; retry=10, nr_thrds=1, show_progress=false, desc="Multivariate rational interpolation")
90+
91+
Compute the multivariate rational function corresponding to the black box function `bb` using Cuyt and Lee's interpolation algorithm.
92+
This function assumes that a random shift has already been applied to the input of the black box function, and does not perform any retries if the interpolation fails.
93+
94+
**Note**: This is an internal function. For a user-facing function that automatically applies random shifts, see `cuyt_lee`.
95+
"""
96+
97+
function cuyt_lee_shifted(
98+
R::QQMPolyRing,
99+
bb::Function;
100+
retry::Int=10,
101+
nr_thrds::Int=1,
102+
show_progress::Bool=false,
103+
desc::String="Multivariate rational interpolation"
104+
)::FracFieldElem{QQMPolyRingElem}
105+
# https://arxiv.org/pdf/1608.01902
106+
t = length(gens(R))
107+
if t == 0
108+
return R(bb(Vector{QQFieldElem}())) // one(R)
109+
end
110+
R_z, _ = polynomial_ring(QQ, :z)
111+
d_den, d_num = _estimate_total_degree(R, bb; show_progress=show_progress)
112+
d = [0; fill(max(d_den, d_num), t - 1)...]
113+
x = Vector{Vector{ZZRingElem}}()
114+
coeffs_den = []
115+
coeffs_num = []
116+
data_lock = ReentrantLock()
117+
prog = ProgressBar(total=prod(d .+ 1); desc=desc, enabled=show_progress)
118+
update!(prog, 0)
119+
function populate(cur::Vector{ZZRingElem}, dim::Int; num_threads::Int=1, offset=1)::Bool
120+
if dim > t
121+
f = nothing
122+
try
123+
f = thiele(R_z, k -> bb(k .* cur); retry=retry, show_progress=show_progress, offset=offset)
124+
catch e
125+
if isa(e, BoundsError) || isa(e, ThieleError)
126+
return false
127+
else
128+
rethrow(e)
129+
end
130+
end
131+
if degree(denominator(f)) != d_den || degree(numerator(f)) != d_num
132+
return false
133+
end
134+
c = constant_coefficient(denominator(f))
135+
if c == 0
136+
return false
137+
end
138+
lock(data_lock) do
139+
push!(x, copy(cur))
140+
push!(coeffs_den, collect(coefficients(denominator(f))) ./ c)
141+
push!(coeffs_num, collect(coefficients(numerator(f))) ./ c)
142+
update!(prog, length(x))
143+
end
144+
return true
145+
end
146+
total = Atomic(0)
147+
failures = Atomic(0)
148+
i = 1
149+
while total.x < d[dim] + 1
150+
num_threads_chunk = min(num_threads, d[dim] + 1 - total.x)
151+
Threads.@threads for j in 0:num_threads_chunk-1
152+
if populate([cur; ZZ(i + j)], dim + 1; num_threads=1, offset=max(offset, j + 1))
153+
@atomic total.x += 1
154+
@atomic failures.x = 0
155+
else
156+
@atomic failures.x += 1
157+
end
158+
end
159+
i += num_threads_chunk
160+
if failures.x >= retry
161+
return false
162+
end
163+
end
164+
return true
165+
end
166+
res = populate([ZZ(1)], 2; num_threads=nr_thrds)
167+
if !res
168+
finish!(prog)
169+
throw(CuytLeeError("Failed to collect enough data points for interpolation. This could happen if the black box function is singular at zero, or if the expected total degree is incorrect."))
170+
end
171+
perm = sortperm(x)
172+
x = x[perm]
173+
coeffs_den = coeffs_den[perm]
174+
coeffs_num = coeffs_num[perm]
175+
# We interpolate the denominator and numerator separately
176+
den = zero(R)
177+
num = zero(R)
178+
for i in 0:d_den
179+
y = [coeffs_den[j][i+1] for j in 1:length(x)]
180+
den += _homogenize(newton(R, x, y, d), i)
181+
end
182+
for i in 0:d_num
183+
y = [coeffs_num[j][i+1] for j in 1:length(x)]
184+
num += _homogenize(newton(R, x, y, d), i)
185+
end
186+
finish!(prog)
187+
return num // den
188+
end
189+
190+
@doc Markdown.doc"""
191+
cuyt_lee_with_shift(R::QQMPolyRing, bb::Function, shift::Vector{Int}; retry=10, nr_thrds=1, show_progress=false, desc="Multivariate rational interpolation")
192+
193+
Compute the multivariate rational function corresponding to the black box function `bb` using Cuyt and Lee's interpolation algorithm,
194+
with a given shift applied to the input of the black box function.
195+
196+
**Note**: This is an internal function. For a user-facing function that automatically applies random shifts, see `cuyt_lee`.
197+
"""
198+
199+
function cuyt_lee_with_shift(
200+
R::QQMPolyRing,
201+
bb::Function,
202+
shift::Vector{Int};
203+
retry::Int=10,
204+
nr_thrds::Int=1,
205+
show_progress::Bool=false,
206+
desc::String="Multivariate rational interpolation"
207+
)::FracFieldElem{QQMPolyRingElem}
208+
t = length(gens(R))
209+
if t == 0
210+
return R(bb(Vector{QQFieldElem}())) // one(R)
211+
end
212+
f_shifted = cuyt_lee_shifted(R, z -> bb(z .+ shift); retry=retry, nr_thrds=nr_thrds, show_progress=show_progress, desc=desc)
213+
x = gens(R) .- shift
214+
num = evaluate(numerator(f_shifted), x)
215+
den = evaluate(denominator(f_shifted), x)
216+
return num // den
217+
end
218+
219+
@doc Markdown.doc"""
220+
cuyt_lee(R::QQMPolyRing, bb::Function; initial_shift=_random_point(length(gens(R))), retry=10, nr_thrds=1, show_progress=false, desc="Multivariate rational interpolation")
221+
222+
Compute the multivariate rational function corresponding to the black box function `bb` using Cuyt and Lee's interpolation algorithm.
223+
224+
# Arguments
225+
- `R::QQMPolyRing`: the multivariate polynomial ring over the rationals.
226+
- `bb::Function`: a black box function that takes a vector of rational numbers as input and returns a rational number as output.
227+
- `initial_shift::Vector{Int}=_random_point(length(gens(R)))`: the initial shift to use for the interpolation.
228+
- `retry::Int=10`: the maximum number of consecutive failures allowed when evaluating the black box function or interpolating the points.
229+
- `nr_thrds::Int=1`: the number of threads to use when evaluating the black box function.
230+
- `show_progress::Bool=false`: whether to show a progress bar while collecting points.
231+
- `desc::String="Multivariate rational interpolation"`: the description to show in the progress bar.
232+
"""
233+
234+
function cuyt_lee(
235+
R::QQMPolyRing,
236+
bb::Function;
237+
initial_shift=_random_point(length(gens(R))),
238+
retry::Int=10,
239+
nr_thrds::Int=1,
240+
show_progress::Bool=false,
241+
desc::String="Multivariate rational interpolation"
242+
)::FracFieldElem{QQMPolyRingElem}
243+
t = length(gens(R))
244+
if t == 0
245+
return R(bb(Vector{QQFieldElem}())) // one(R)
246+
end
247+
shift = initial_shift
248+
for i in 1:retry
249+
try
250+
return cuyt_lee_with_shift(R, bb, shift; retry=retry, nr_thrds=nr_thrds, show_progress=show_progress, desc=desc)
251+
catch e
252+
if isa(e, CuytLeeError)
253+
if show_progress
254+
@warn "Interpolation failed, retrying with a different shift... Retries left: $(retry - i)"
255+
end
256+
shift = _random_point(t)
257+
else
258+
rethrow(e)
259+
end
260+
end
261+
end
262+
throw(CuytLeeError("Interpolation failed after maximum number of retries."))
263+
end

src/interp/main.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
module Interpolation
2+
3+
using Markdown
4+
using Nemo
5+
import Nemo.Generic: FracFieldElem
6+
using ..Progress
7+
8+
export thiele, newton, cuyt_lee
9+
10+
# Interpolation algorithms
11+
include("thiele.jl")
12+
include("newton.jl")
13+
include("cuyt_lee.jl")
14+
15+
# Applications
16+
include("resultant.jl")
17+
18+
end # module Interpolation

0 commit comments

Comments
 (0)