Skip to content

Commit a8764cf

Browse files
committed
Swap order of assertion generation and minimization
1 parent 5d850a0 commit a8764cf

File tree

3 files changed

+186
-22
lines changed

3 files changed

+186
-22
lines changed

src/pynguin/ga/postprocess.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,53 @@
1919
import pynguin.ga.testsuitechromosome as tsc
2020
import pynguin.testcase.testcase as tc
2121
import pynguin.testcase.testcasevisitor as tcv
22-
from pynguin.assertion.assertion import Assertion, ExceptionAssertion
22+
from pynguin.assertion.assertion import (
23+
Assertion,
24+
ExceptionAssertion,
25+
ReferenceAssertion,
26+
)
2327
from pynguin.testcase.statement import StatementVisitor
2428
from pynguin.utils.orderedset import OrderedSet
2529

2630
if TYPE_CHECKING:
2731
import pynguin.ga.computations as ff
32+
import pynguin.testcase.variablereference as vr
2833
from pynguin.testcase.execution import SubprocessTestCaseExecutor
2934

3035
_LOGGER = logging.getLogger(__name__)
3136

3237

38+
def get_assertion_protected_variables(test_case: tc.TestCase) -> set[vr.VariableReference]:
39+
"""Get all variables that should be protected due to assertions.
40+
41+
Variables are protected if they are:
42+
- Directly referenced by a ReferenceAssertion's source, OR
43+
- In the backward dependency chain of an asserted variable
44+
45+
ExceptionAssertions are skipped (no source variable).
46+
47+
Args:
48+
test_case: Test case to analyze
49+
50+
Returns:
51+
Set of variable references that should not be removed during minimization
52+
"""
53+
protected: set[vr.VariableReference] = set()
54+
for stmt in test_case.statements:
55+
for assertion in stmt.assertions:
56+
# Skip ExceptionAssertion - has no source
57+
if isinstance(assertion, ExceptionAssertion):
58+
continue
59+
# Handle ReferenceAssertion
60+
if isinstance(assertion, ReferenceAssertion):
61+
var_ref = assertion.source.get_variable_reference()
62+
if var_ref is not None:
63+
protected.add(var_ref)
64+
# Add all backward dependencies
65+
protected.update(test_case.get_dependencies(var_ref))
66+
return protected
67+
68+
3369
class ExceptionTruncation(cv.ChromosomeVisitor):
3470
"""Truncates test cases after an exception-raising statement."""
3571

@@ -201,7 +237,9 @@ class ForwardIterativeMinimizationVisitor(IterativeMinimizationVisitor):
201237
4. If fitness remains the same or improves, remove the statement from the original test case
202238
"""
203239

204-
def visit_default_test_case(self, test_case: tc.TestCase) -> None: # noqa: D102
240+
def visit_default_test_case( # noqa: D102, PLR0914
241+
self, test_case: tc.TestCase
242+
) -> None:
205243
original_test_case = tcc.TestCaseChromosome(test_case=test_case)
206244
original_test_suite = tsc.TestSuiteChromosome()
207245
original_test_suite.add_test_case_chromosome(original_test_case)
@@ -235,13 +273,18 @@ def visit_default_test_case(self, test_case: tc.TestCase) -> None: # noqa: D102
235273
for fitness_function in self._fitness_functions
236274
]
237275
if all(map(math.isclose, original_coverages, minimized_coverages)):
238-
removed = test_case.remove_statement_with_forward_dependencies(stmt)
239-
self._removed_statements += len(removed)
276+
protected_vars = get_assertion_protected_variables(test_case)
277+
is_statement_protected = stmt.ret_val in protected_vars
278+
if not is_statement_protected:
279+
removed = test_case.remove_statement_with_forward_dependencies(stmt)
280+
self._removed_statements += len(removed)
240281

241-
# Update the statements list to reflect the changes in the test case
242-
statements = list(test_case.statements)
243-
# Don't increment i since we've removed elements and the list has shifted
244-
statements_changed = True
282+
# Update the statements list to reflect the changes in the test case
283+
statements = list(test_case.statements)
284+
# Don't increment i since we've removed elements and the list has shifted
285+
statements_changed = True
286+
else:
287+
i += 1
245288
else:
246289
i += 1
247290

@@ -295,11 +338,16 @@ def visit_default_test_case(self, test_case: tc.TestCase) -> None: # noqa: D102
295338
for fitness_function in self._fitness_functions
296339
]
297340
if all(map(math.isclose, original_coverages, minimized_coverages)):
298-
removed = test_case.remove_statement_with_forward_dependencies(stmt)
299-
self._removed_statements += len(removed)
300-
statements_changed = True
301-
break
302-
i -= 1
341+
protected_vars = get_assertion_protected_variables(test_case)
342+
is_statement_protected = stmt.ret_val in protected_vars
343+
if not is_statement_protected:
344+
removed = test_case.remove_statement_with_forward_dependencies(stmt)
345+
self._removed_statements += len(removed)
346+
statements_changed = True
347+
break
348+
i -= 1
349+
else:
350+
i -= 1
303351

304352
_LOGGER.debug(
305353
"Removed %s statement(s) from test case using backward iterative minimization",

src/pynguin/generator.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -543,13 +543,6 @@ def _track_final_metrics(
543543
runtime_variable, generation_result.get_coverage_for(coverage_ff)
544544
)
545545

546-
ass_gen = config.configuration.test_case_output.assertion_generation
547-
if (
548-
ass_gen == config.AssertionGenerator.CHECKED_MINIMIZING
549-
and RuntimeVariable.AssertionCheckedCoverage in output_variables
550-
):
551-
_minimize_assertions(generation_result)
552-
553546
# Collect other final stats on result
554547
stat.track_output_variable(RuntimeVariable.FinalLength, generation_result.length())
555548
stat.track_output_variable(RuntimeVariable.FinalSize, generation_result.size())
@@ -626,12 +619,27 @@ def _run() -> ReturnCode: # noqa: C901
626619
executor.clear_remote_observers()
627620

628621
_track_search_metrics(algorithm, generation_result, coverage_metrics)
622+
623+
# Generate assertions FIRST
624+
_generate_assertions(executor, generation_result, test_cluster)
625+
626+
# Minimize assertions if configured (requires re-instrumentation for checked_instructions)
627+
ass_gen = config.configuration.test_case_output.assertion_generation
628+
if (
629+
ass_gen == config.AssertionGenerator.CHECKED_MINIMIZING
630+
and RuntimeVariable.AssertionCheckedCoverage
631+
in config.configuration.statistics_output.output_variables
632+
):
633+
_prepare_for_assertion_minimization(executor, generation_result, constant_provider)
634+
_LOGGER.info("Minimizing assertions")
635+
_minimize_assertions(generation_result)
636+
637+
# Statement minimization LAST (now assertion-aware)
629638
try:
630639
_LOGGER.info("Minimizing test cases")
631640
_minimize(generation_result, algorithm)
632641
except Exception as ex:
633642
_LOGGER.exception("Minimization failed: %s", ex)
634-
_generate_assertions(executor, generation_result, test_cluster)
635643

636644
if (
637645
tracked_metrics := _track_final_metrics(
@@ -841,6 +849,55 @@ def _minimize_assertions(generation_result: tsc.TestSuiteChromosome):
841849
)
842850

843851

852+
def _prepare_for_assertion_minimization(
853+
executor: TestCaseExecutor,
854+
generation_result: tsc.TestSuiteChromosome,
855+
constant_provider: ConstantProvider,
856+
) -> None:
857+
"""Prepare for assertion minimization by populating checked_instructions.
858+
859+
Re-instruments the module with RemoteAssertionExecutionObserver,
860+
executes the test suite to populate checked_instructions on assertions,
861+
then returns executor to non-instrumented state.
862+
863+
Args:
864+
executor: Test case executor
865+
generation_result: Test suite with assertions to minimize
866+
constant_provider: Constant provider for re-instrumentation
867+
"""
868+
# Add assertion execution observer
869+
executor.set_instrument(True)
870+
executor.add_remote_observer(RemoteAssertionExecutionObserver())
871+
872+
# Re-instrument with CHECKED metric
873+
metrics_for_reinstrumentation = {config.CoverageMetric.CHECKED}
874+
dynamic_constant_provider = None
875+
if isinstance(constant_provider, DynamicConstantProvider):
876+
dynamic_constant_provider = constant_provider
877+
878+
if not _reload_instrumentation_loader(
879+
metrics_for_reinstrumentation,
880+
dynamic_constant_provider,
881+
executor.subject_properties,
882+
):
883+
_LOGGER.warning("Failed to reload instrumentation for assertion minimization")
884+
return
885+
886+
# Force re-execution to populate checked_instructions
887+
_reset_cache_for_result(generation_result)
888+
889+
# Execute to populate checked_instructions
890+
# This happens automatically when coverage is computed
891+
assertion_checked_ff = ff.TestSuiteAssertionCheckedCoverageFunction(executor)
892+
generation_result.add_coverage_function(assertion_checked_ff)
893+
_ = generation_result.get_coverage_for(assertion_checked_ff)
894+
895+
# Return executor to non-instrumented state
896+
executor.set_instrument(False)
897+
executor.clear_observers()
898+
executor.clear_remote_observers()
899+
900+
844901
_strategies: dict[config.MutationStrategy, Callable[[int], ms.HOMStrategy]] = {
845902
config.MutationStrategy.FIRST_TO_LAST: ms.FirstToLastHOMStrategy,
846903
config.MutationStrategy.BETWEEN_OPERATORS: ms.BetweenOperatorsHOMStrategy,

tests/ga/test_postprocess.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
import pynguin.testcase.defaulttestcase as dtc
1818
import pynguin.testcase.statement as stmt
1919
from pynguin.analyses.module import ModuleTestCluster, generate_test_cluster
20-
from pynguin.assertion.assertion import ExceptionAssertion
20+
from pynguin.assertion.assertion import ExceptionAssertion, ObjectAssertion
2121
from pynguin.ga.computations import TestSuiteBranchCoverageFunction, TestSuiteLineCoverageFunction
22+
from pynguin.ga.postprocess import get_assertion_protected_variables
2223
from pynguin.ga.testsuitechromosome import TestSuiteChromosome
2324
from pynguin.instrumentation.machinery import install_import_hook
2425
from pynguin.instrumentation.tracer import SubjectProperties
@@ -122,6 +123,64 @@ def test_test_case_assertion_minimization_does_not_remove_empty_assertion(
122123
assert default_test_case.get_assertions() == [assertion_1]
123124

124125

126+
def test_get_assertion_protected_variables_no_assertions(default_test_case):
127+
"""Test that a test case with no assertions returns empty protected set."""
128+
statement = stmt.IntPrimitiveStatement(default_test_case)
129+
default_test_case.add_statement(statement)
130+
131+
result = get_assertion_protected_variables(default_test_case)
132+
133+
assert result == set()
134+
135+
136+
def test_get_assertion_protected_variables_single_reference_assertion(default_test_case):
137+
"""Test that ReferenceAssertion source variable is protected."""
138+
statement = stmt.IntPrimitiveStatement(default_test_case, value=42)
139+
default_test_case.add_statement(statement)
140+
141+
# Create ObjectAssertion on the statement's return value
142+
var_ref = statement.ret_val
143+
assertion = ObjectAssertion(var_ref, 42) # Asserts var_ref == 42
144+
statement.add_assertion(assertion)
145+
146+
result = get_assertion_protected_variables(default_test_case)
147+
148+
assert var_ref in result
149+
150+
151+
def test_get_assertion_protected_variables_skips_exception_assertion(default_test_case):
152+
"""Test that ExceptionAssertion (no source) doesn't add to protected set."""
153+
statement = stmt.IntPrimitiveStatement(default_test_case)
154+
default_test_case.add_statement(statement)
155+
156+
# ExceptionAssertion has no source variable
157+
exc_assertion = MagicMock(spec=ExceptionAssertion)
158+
statement.add_assertion(exc_assertion)
159+
160+
result = get_assertion_protected_variables(default_test_case)
161+
162+
assert result == set()
163+
164+
165+
def test_get_assertion_protected_variables_with_dependencies(default_test_case):
166+
"""Test that assertion source and its dependencies are protected."""
167+
# Create a statement
168+
statement = stmt.IntPrimitiveStatement(default_test_case, value=42)
169+
default_test_case.add_statement(statement)
170+
var_ref = statement.ret_val
171+
172+
# Add assertion
173+
assertion = ObjectAssertion(var_ref, 42)
174+
statement.add_assertion(assertion)
175+
176+
result = get_assertion_protected_variables(default_test_case)
177+
178+
# var_ref should be protected
179+
assert var_ref in result
180+
# Since IntPrimitiveStatement has no dependencies, only var_ref is protected
181+
# This is correct behavior - primitives have no backward deps
182+
183+
125184
def test_test_case_postprocessor_suite():
126185
dummy_visitor = MagicMock()
127186
tcpp = pp.TestCasePostProcessor([dummy_visitor])

0 commit comments

Comments
 (0)