Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions Examples/Scripts/Python/gnn_module_map_odd.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def runGnnModuleMap(
moduleMapPath,
gnnModel,
outputDir,
events=100,
s=None,
):
"""
Expand Down Expand Up @@ -75,7 +74,7 @@ def runGnnModuleMap(
).exists(), f"Module map not found: {moduleMapPath}.triplets.root"
assert Path(gnnModel).exists(), f"Model file not found: {gnnModel}"

s = s or Sequencer(events=events, numThreads=1)
s = s or Sequencer(events=100, numThreads=1)

# Random number generator
rnd = acts.examples.RandomNumbers(seed=42)
Expand Down Expand Up @@ -240,7 +239,6 @@ def runGnnModuleMap(
moduleMapPath = str(ci_models_odd / "module_map_odd_2k_events.1e-03.float")
gnnModel = str(ci_models_odd / "gnn_odd_module_map.pt")
outputDir = Path.cwd()
events = 100

# Run the workflow
runGnnModuleMap(
Expand All @@ -252,5 +250,4 @@ def runGnnModuleMap(
moduleMapPath=moduleMapPath,
gnnModel=gnnModel,
outputDir=outputDir,
events=events,
)
36 changes: 21 additions & 15 deletions Python/Examples/python/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,7 @@ def addTrackWriters(
writeStates: bool = False,
writeFitterPerformance: bool = False,
writeFinderPerformance: bool = False,
writeFinderNTuple: bool = False,
logLevel: Optional[acts.logging.Level] = None,
writeCovMat=False,
):
Expand Down Expand Up @@ -1877,6 +1878,17 @@ def addTrackWriters(
)
s.addWriter(trackFinderPerfWriter)

if writeFinderNTuple:
trackFinderNTupleWriter = acts.examples.root.RootTrackFinderNTupleWriter(
level=customLogLevel(),
inputTracks=tracks,
inputParticles="particles_selected",
inputParticleMeasurementsMap="particle_measurements_map",
inputTrackParticleMatching="track_particle_matching",
filePath=str(Path(outputDirRoot) / f"track_finding_ntuple_{name}.root"),
)
s.addWriter(trackFinderNTupleWriter)

if outputDirCsv is not None:
outputDirCsv = Path(outputDirCsv)
if not outputDirCsv.exists():
Expand Down Expand Up @@ -2024,21 +2036,15 @@ def addGnn(
"particle_track_matching", matchAlg.config.outputParticleTrackMatching
)

# Optional performance writer
if outputDirRoot is not None:
assert (
ACTS_EXAMPLES_ROOT_AVAILABLE
), "ROOT output requested but ROOT is not available"
s.addWriter(
RootTrackFinderNTupleWriter(
level=customLogLevel(),
inputTracks="tracks",
inputParticles="particles",
inputParticleMeasurementsMap="particle_measurements_map",
inputTrackParticleMatching=matchAlg.config.outputTrackParticleMatching,
filePath=str(Path(outputDirRoot) / "performance_track_finding.root"),
)
)
addTrackWriters(
s,
name="gnn",
tracks="tracks",
outputDirRoot=outputDirRoot,
writeFinderPerformance=True,
writeFinderNTuple=True,
logLevel=logLevel,
)

return s

Expand Down
7 changes: 4 additions & 3 deletions Python/Examples/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,9 +1325,10 @@ def test_gnn_module_map(tmp_path, assert_root_hash, backend, hardware):
)

# Verify output
output_file = tmp_path / "performance_track_finding.root"
assert output_file.exists()
assert_root_hash("performance_track_finding.root", output_file)
for f in ["performance_finding_gnn.root", "track_finding_ntuple_gnn.root"]:
output_file = tmp_path / f
assert output_file.exists()
assert_root_hash(f, output_file)


@pytest.mark.odd
Expand Down
Loading