Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyomo/contrib/solver/solvers/knitro/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ def _restore_var_values(self) -> None:
var.set_value(self._saved_var_values[id(var)])
StaleFlagManager.mark_all_as_stale()

def _warm_start(self) -> None:
variables = []
for var in self._get_vars():
if var.value is not None:
variables.append(var)
self._engine.set_initial_values(variables)

@abstractmethod
def _presolve(
self, model: BlockData, config: KnitroConfig, timer: HierarchicalTimer
Expand Down
12 changes: 12 additions & 0 deletions pyomo/contrib/solver/solvers/knitro/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ def __init__(
visibility=visibility,
)

self.use_start: bool = self.declare(
"use_start",
ConfigValue(
domain=Bool,
default=False,
doc=(
"If True, KNITRO solver will use the the current values "
"of variables as starting points for the optimization."
),
),
)

self.rebuild_model_on_remove_var: bool = self.declare(
"rebuild_model_on_remove_var",
ConfigValue(
Expand Down
5 changes: 5 additions & 0 deletions pyomo/contrib/solver/solvers/knitro/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def _solve(self, config: KnitroConfig, timer: HierarchicalTimer) -> None:
self._engine.set_options(**config.solver_options)
timer.stop("load_options")

if config.use_start:
timer.start("warm_start")
self._warm_start()
timer.stop("warm_start")

timer.start("solve")
self._engine.solve()
timer.stop("solve")
5 changes: 5 additions & 0 deletions pyomo/contrib/solver/solvers/knitro/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,11 @@ def get_idxs(
idx_map = self.maps[item_type]
return [idx_map[id(item)] for item in items]

def set_initial_values(self, variables: Iterable[VarData]) -> None:
values = [value(var) for var in variables]
idxs = self.get_idxs(VarData, variables)
self.execute(knitro.KN_set_var_primal_init_values, idxs, values)

def get_values(
self,
item_type: type[ItemType],
Expand Down
89 changes: 89 additions & 0 deletions pyomo/contrib/solver/tests/solvers/test_knitro_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,16 @@ def test_default_instantiation(self):
self.assertIsNone(config.timer)
self.assertIsNone(config.threads)
self.assertIsNone(config.time_limit)
self.assertFalse(config.use_start)

def test_custom_instantiation(self):
config = KnitroConfig(description="A description")
config.tee = True
config.use_start = True
self.assertTrue(config.tee)
self.assertEqual(config._description, "A description")
self.assertIsNone(config.time_limit)
self.assertTrue(config.use_start)


@unittest.skipIf(not avail, "KNITRO solver is not available")
Expand Down Expand Up @@ -486,3 +489,89 @@ def test_solve_HS071(self):
self.assertAlmostEqual(pyo.value(m.x[2]), 4.743, 3)
self.assertAlmostEqual(pyo.value(m.x[3]), 3.821, 3)
self.assertAlmostEqual(pyo.value(m.x[4]), 1.379, 3)


@unittest.skipIf(not avail, "KNITRO solver is not available")
class TestKnitroWarmStart(unittest.TestCase):
"""Test cases for KNITRO warm start (use_start) functionality."""

def setUp(self):
self.opt = KnitroDirectSolver()

def test_warm_start_reduces_iterations(self):
"""Test that providing a good starting point reduces the number of iterations."""
m = pyo.ConcreteModel()
m.x = pyo.Var(bounds=(-5, 5))
m.y = pyo.Var(bounds=(-5, 5))
m.obj = pyo.Objective(
expr=(1.0 - m.x) ** 2 + 100.0 * (m.y - m.x**2) ** 2, sense=pyo.minimize
)

m.x.set_value(None)
m.y.set_value(None)
res_no_start = self.opt.solve(m, use_start=False)
iters_no_start = res_no_start.extra_info.number_iters

m.x.set_value(0.9)
m.y.set_value(0.9)
res_with_start = self.opt.solve(m, use_start=True)
iters_with_start = res_with_start.extra_info.number_iters

self.assertAlmostEqual(pyo.value(m.x), 1.0, 3)
self.assertAlmostEqual(pyo.value(m.y), 1.0, 3)

self.assertLessEqual(iters_with_start, iters_no_start)

def test_warm_start_uses_initial_values(self):
"""Test that warm start uses the current variable values."""
m = pyo.ConcreteModel()
m.x = pyo.Var(bounds=(0, 10))
m.y = pyo.Var(bounds=(0, 10))
m.obj = pyo.Objective(expr=(m.x - 3) ** 2 + (m.y - 4) ** 2, sense=pyo.minimize)
m.x.set_value(3.0)
m.y.set_value(4.0)
res = self.opt.solve(m, use_start=True)
self.assertAlmostEqual(pyo.value(m.x), 3.0, 5)
self.assertAlmostEqual(pyo.value(m.y), 4.0, 5)
self.assertAlmostEqual(res.incumbent_objective, 0.0, 5)
self.assertLessEqual(res.extra_info.number_iters, 1)

def test_warm_start_with_subset_variables(self):
"""Test warm start when only a subset of variables have initial values."""
m = pyo.ConcreteModel()
m.x = pyo.Var(bounds=(0, 10))
m.y = pyo.Var(bounds=(0, 10))
m.obj = pyo.Objective(expr=(m.x - 3) ** 2 + (m.y - 4) ** 2, sense=pyo.minimize)
m.x.set_value(3.0)
m.y.set_value(None)
res = self.opt.solve(m, use_start=True)
self.assertAlmostEqual(pyo.value(m.x), 3.0, 5)
self.assertAlmostEqual(pyo.value(m.y), 4.0, 5)
self.assertAlmostEqual(res.incumbent_objective, 0.0, 5)

def test_warm_start_disabled(self):
"""Test that use_start=False disables warm start."""
m = pyo.ConcreteModel()
m.x = pyo.Var(bounds=(0, 10))
m.y = pyo.Var(bounds=(0, 10))
m.obj = pyo.Objective(expr=(m.x - 3) ** 2 + (m.y - 4) ** 2, sense=pyo.minimize)
m.x.set_value(3.0)
m.y.set_value(4.0)
res = self.opt.solve(m, use_start=False)
self.assertAlmostEqual(pyo.value(m.x), 3.0, 5)
self.assertAlmostEqual(pyo.value(m.y), 4.0, 5)
self.assertAlmostEqual(res.incumbent_objective, 0.0, 5)

def test_warm_start_with_constraints(self):
"""Test warm start with constrained optimization."""
m = pyo.ConcreteModel()
m.x = pyo.Var(bounds=(0, None))
m.y = pyo.Var(bounds=(0, None))
m.obj = pyo.Objective(expr=m.x + m.y, sense=pyo.minimize)
m.c1 = pyo.Constraint(expr=m.x + 2 * m.y >= 4)
m.c2 = pyo.Constraint(expr=2 * m.x + m.y >= 4)
m.x.set_value(1.3)
m.y.set_value(1.3)
self.opt.solve(m, use_start=True)
self.assertAlmostEqual(pyo.value(m.x), 4.0 / 3.0, 3)
self.assertAlmostEqual(pyo.value(m.y), 4.0 / 3.0, 3)
Loading