Skip to content

Commit 5335103

Browse files
committed
Add multivariate rational interpolation
1 parent 46c3dd4 commit 5335103

File tree

13 files changed

+805
-0
lines changed

13 files changed

+805
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1010
Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a"
1111
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
13+
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
1314
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1415
msolve_jll = "6d01cc9a-e8f6-580e-8c54-544227e08205"
1516

@@ -26,6 +27,7 @@ Markdown = "1.6"
2627
Nemo = "0.43, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.50, 0.51, 0.52, 0.53, 0.54"
2728
Printf = "1.6"
2829
Random = "1.6"
30+
REPL = "1.6"
2931
StaticArrays = "1"
3032
Test = "1.6"
3133
julia = "1.10"

src/AlgebraicSolving.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ include("algorithms/param-curve.jl")
1818
include("algorithms/hilbert.jl")
1919
#= siggb =#
2020
include("siggb/siggb.jl")
21+
#= progress =#
22+
include("progress/main.jl")
23+
#= interp =#
24+
include("interp/main.jl")
2125
#= examples =#
2226
include("examples/katsura.jl")
2327
include("examples/cyclic.jl")

src/interp/cuyt_lee.jl

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
mutable struct Atomic{T}
2+
@atomic x::T
3+
end
4+
5+
struct CuytLeeError <: Exception
6+
msg::String
7+
end
8+
9+
Base.show(io::IO, e::CuytLeeError) = print(io, "CuytLeeError: ", e.msg)
10+
11+
@doc Markdown.doc"""
12+
_random_point(n::Int)
13+
14+
Generates a random point in $\mathbb{Z}^n$ with coordinates between 1 and 99.
15+
16+
**Note**: This is an internal function.
17+
"""
18+
function _random_point(n::Int)::Vector{Int}
19+
map(a -> 1 + abs(a) % 99, rand(Int, n))
20+
end
21+
22+
@doc Markdown.doc"""
23+
_estimate_total_degree(R::QQMPolyRing, bb::Function; samples=5, show_progress=false)
24+
25+
Estimate the total degree of the rational function corresponding to the black box function `bb` by evaluating it at random points and performing univariate Thiele interpolation.
26+
27+
**Note**: This is an internal function.
28+
"""
29+
function _estimate_total_degree(
30+
R::QQMPolyRing,
31+
bb::Function;
32+
samples::Int=5,
33+
show_progress::Bool=false
34+
)::Tuple{Int,Int}
35+
t = length(gens(R))
36+
@assert t > 0
37+
R_z, _ = polynomial_ring(QQ, :z)
38+
total_degree_counts = Dict{Tuple{Int,Int},Int}()
39+
total = 0
40+
while total < samples
41+
x = _random_point(t)
42+
try
43+
f = thiele(R_z, k -> bb(k .* x); show_progress=show_progress, offset=1)
44+
total_degree = (degree(denominator(f)), degree(numerator(f)))
45+
total_degree_counts[total_degree] = get(total_degree_counts, total_degree, 0) + 1
46+
total += 1
47+
if findmax(total_degree_counts)[1] > samples ÷ 2
48+
break
49+
end
50+
catch e
51+
if isa(e, ThieleError)
52+
continue
53+
else
54+
rethrow(e)
55+
end
56+
end
57+
end
58+
return findmax(total_degree_counts)[2]
59+
end
60+
61+
@doc Markdown.doc"""
62+
_homogenize(f::QQMPolyRingElem, d::Int)
63+
64+
Homogenize the given polynomial `f` to total degree `d` by multiplying each term by the appropriate power of the first variable.
65+
66+
**Note**: This is an internal function.
67+
"""
68+
function _homogenize(
69+
f::QQMPolyRingElem,
70+
d::Int
71+
)::QQMPolyRingElem
72+
R = parent(f)
73+
C = MPolyBuildCtx(R)
74+
for i in 1:f.length
75+
exp = exponent_vector(f, i)
76+
@assert exp[1] == 0
77+
total_deg = sum(exp)
78+
exp[1] += d - total_deg
79+
@assert exp[1] >= 0
80+
push_term!(C, coeff(f, i), exp)
81+
end
82+
return finish(C)
83+
end
84+
85+
@doc Markdown.doc"""
86+
cuyt_lee_shifted(R::QQMPolyRing, bb::Function; retry=10, nr_thrds=1, show_progress=false, desc="Multivariate rational interpolation")
87+
88+
Compute the multivariate rational function corresponding to the black box function `bb` using Cuyt and Lee's interpolation algorithm.
89+
This function assumes that a random shift has already been applied to the input of the black box function, and does not perform any retries if the interpolation fails.
90+
91+
**Note**: This is an internal function. For a user-facing function that automatically applies random shifts, see `cuyt_lee`.
92+
"""
93+
function cuyt_lee_shifted(
94+
R::QQMPolyRing,
95+
bb::Function;
96+
retry::Int=10,
97+
nr_thrds::Int=1,
98+
show_progress::Bool=false,
99+
desc::String="Multivariate rational interpolation"
100+
)::FracFieldElem{QQMPolyRingElem}
101+
# https://arxiv.org/pdf/1608.01902
102+
t = length(gens(R))
103+
if t == 0
104+
return R(bb(Vector{QQFieldElem}())) // one(R)
105+
end
106+
R_z, _ = polynomial_ring(QQ, :z)
107+
d_den, d_num = _estimate_total_degree(R, bb; show_progress=show_progress)
108+
d = [0; fill(max(d_den, d_num), t - 1)...]
109+
x = Vector{Vector{ZZRingElem}}()
110+
coeffs_den = []
111+
coeffs_num = []
112+
data_lock = ReentrantLock()
113+
prog = ProgressBar(total=prod(d .+ 1); desc=desc, enabled=show_progress)
114+
update!(prog, 0)
115+
function populate(cur::Vector{ZZRingElem}, dim::Int; num_threads::Int=1, offset=1)::Bool
116+
if dim > t
117+
f = nothing
118+
try
119+
f = thiele(R_z, k -> bb(k .* cur); retry=retry, show_progress=show_progress, offset=offset)
120+
catch e
121+
if isa(e, BoundsError) || isa(e, ThieleError)
122+
return false
123+
else
124+
rethrow(e)
125+
end
126+
end
127+
if degree(denominator(f)) != d_den || degree(numerator(f)) != d_num
128+
return false
129+
end
130+
c = constant_coefficient(denominator(f))
131+
if c == 0
132+
return false
133+
end
134+
lock(data_lock) do
135+
push!(x, copy(cur))
136+
push!(coeffs_den, collect(coefficients(denominator(f))) ./ c)
137+
push!(coeffs_num, collect(coefficients(numerator(f))) ./ c)
138+
update!(prog, length(x))
139+
end
140+
return true
141+
end
142+
total = Atomic(0)
143+
failures = Atomic(0)
144+
i = 1
145+
while total.x < d[dim] + 1
146+
num_threads_chunk = min(num_threads, d[dim] + 1 - total.x)
147+
Threads.@threads for j in 0:num_threads_chunk-1
148+
if populate([cur; ZZ(i + j)], dim + 1; num_threads=1, offset=max(offset, j + 1))
149+
@atomic total.x += 1
150+
@atomic failures.x = 0
151+
else
152+
@atomic failures.x += 1
153+
end
154+
end
155+
i += num_threads_chunk
156+
if failures.x >= retry
157+
return false
158+
end
159+
end
160+
return true
161+
end
162+
res = populate([ZZ(1)], 2; num_threads=nr_thrds)
163+
if !res
164+
finish!(prog)
165+
throw(CuytLeeError("Failed to collect enough data points for interpolation. This could happen if the black box function is singular at zero, or if the expected total degree is incorrect."))
166+
end
167+
perm = sortperm(x)
168+
x = x[perm]
169+
coeffs_den = coeffs_den[perm]
170+
coeffs_num = coeffs_num[perm]
171+
# We interpolate the denominator and numerator separately
172+
den = zero(R)
173+
num = zero(R)
174+
for i in 0:d_den
175+
y = [coeffs_den[j][i+1] for j in 1:length(x)]
176+
den += _homogenize(newton(R, x, y, d), i)
177+
end
178+
for i in 0:d_num
179+
y = [coeffs_num[j][i+1] for j in 1:length(x)]
180+
num += _homogenize(newton(R, x, y, d), i)
181+
end
182+
finish!(prog)
183+
return num // den
184+
end
185+
186+
@doc Markdown.doc"""
187+
cuyt_lee_with_shift(R::QQMPolyRing, bb::Function, shift::Vector{Int}; retry=10, nr_thrds=1, show_progress=false, desc="Multivariate rational interpolation")
188+
189+
Compute the multivariate rational function corresponding to the black box function `bb` using Cuyt and Lee's interpolation algorithm,
190+
with a given shift applied to the input of the black box function.
191+
192+
**Note**: This is an internal function. For a user-facing function that automatically applies random shifts, see `cuyt_lee`.
193+
"""
194+
function cuyt_lee_with_shift(
195+
R::QQMPolyRing,
196+
bb::Function,
197+
shift::Vector{Int};
198+
retry::Int=10,
199+
nr_thrds::Int=1,
200+
show_progress::Bool=false,
201+
desc::String="Multivariate rational interpolation"
202+
)::FracFieldElem{QQMPolyRingElem}
203+
t = length(gens(R))
204+
if t == 0
205+
return R(bb(Vector{QQFieldElem}())) // one(R)
206+
end
207+
f_shifted = cuyt_lee_shifted(R, z -> bb(z .+ shift); retry=retry, nr_thrds=nr_thrds, show_progress=show_progress, desc=desc)
208+
x = gens(R) .- shift
209+
num = evaluate(numerator(f_shifted), x)
210+
den = evaluate(denominator(f_shifted), x)
211+
return num // den
212+
end
213+
214+
@doc Markdown.doc"""
215+
cuyt_lee(R::QQMPolyRing, bb::Function; initial_shift=_random_point(length(gens(R))), retry=10, nr_thrds=1, show_progress=false, desc="Multivariate rational interpolation")
216+
217+
Compute the multivariate rational function corresponding to the black box function `bb` using Cuyt and Lee's interpolation algorithm.
218+
219+
# Arguments
220+
- `R::QQMPolyRing`: the multivariate polynomial ring over the rationals.
221+
- `bb::Function`: a black box function that takes a vector of rational numbers as input and returns a rational number as output.
222+
- `initial_shift::Vector{Int}=_random_point(length(gens(R)))`: the initial shift to use for the interpolation.
223+
- `retry::Int=10`: the maximum number of consecutive failures allowed when evaluating the black box function or interpolating the points.
224+
- `nr_thrds::Int=1`: the number of threads to use when evaluating the black box function.
225+
- `show_progress::Bool=false`: whether to show a progress bar while collecting points.
226+
- `desc::String="Multivariate rational interpolation"`: the description to show in the progress bar.
227+
"""
228+
function cuyt_lee(
229+
R::QQMPolyRing,
230+
bb::Function;
231+
initial_shift=_random_point(length(gens(R))),
232+
retry::Int=10,
233+
nr_thrds::Int=1,
234+
show_progress::Bool=false,
235+
desc::String="Multivariate rational interpolation"
236+
)::FracFieldElem{QQMPolyRingElem}
237+
t = length(gens(R))
238+
if t == 0
239+
return R(bb(Vector{QQFieldElem}())) // one(R)
240+
end
241+
shift = initial_shift
242+
for i in 1:retry
243+
try
244+
return cuyt_lee_with_shift(R, bb, shift; retry=retry, nr_thrds=nr_thrds, show_progress=show_progress, desc=desc)
245+
catch e
246+
if isa(e, CuytLeeError)
247+
if show_progress
248+
@warn "Interpolation failed, retrying with a different shift... Retries left: $(retry - i)"
249+
end
250+
shift = _random_point(t)
251+
else
252+
rethrow(e)
253+
end
254+
end
255+
end
256+
throw(CuytLeeError("Interpolation failed after maximum number of retries."))
257+
end

src/interp/main.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
module Interpolation
2+
3+
using Markdown
4+
using Nemo
5+
import Nemo.Generic: FracFieldElem
6+
using ..Progress
7+
8+
export thiele, newton, cuyt_lee
9+
10+
# Interpolation algorithms
11+
include("thiele.jl")
12+
include("newton.jl")
13+
include("cuyt_lee.jl")
14+
15+
# Applications
16+
include("resultant.jl")
17+
18+
end # module Interpolation

0 commit comments

Comments
 (0)