From 3a793712d268af3268f53d0cb63365204d696afc Mon Sep 17 00:00:00 2001 From: Andreas Stefl Date: Thu, 19 Mar 2026 18:42:11 +0100 Subject: [PATCH 1/6] Proposal: CRTP track state creator --- .../Acts/TrackFinding/TrackStateCreator.hpp | 508 +++++++++++++++++- .../src/TrackFindingAlgorithm.cpp | 56 +- .../src/GeoMuonMockupExperiment.cpp | 7 +- 3 files changed, 556 insertions(+), 15 deletions(-) diff --git a/Core/include/Acts/TrackFinding/TrackStateCreator.hpp b/Core/include/Acts/TrackFinding/TrackStateCreator.hpp index 31410b0c246..9189c0d96ae 100644 --- a/Core/include/Acts/TrackFinding/TrackStateCreator.hpp +++ b/Core/include/Acts/TrackFinding/TrackStateCreator.hpp @@ -8,12 +8,508 @@ #pragma once -// for definitions of Calibrator, MeasurementSelector +#include "Acts/EventData/BoundTrackParameters.hpp" +#include "Acts/EventData/SubspaceHelpers.hpp" +#include "Acts/EventData/TrackStateType.hpp" +#include "Acts/EventData/Types.hpp" +#include "Acts/Surfaces/Surface.hpp" #include "Acts/TrackFinding/CombinatorialKalmanFilterExtensions.hpp" #include "Acts/TrackFitting/KalmanFitter.hpp" +#include "Acts/Utilities/Enumerate.hpp" namespace Acts { +namespace Experimental { + +// TODO surface token / detector element token +// TODO state object +// TODO cuts object +template +class TrackStateCreatorBase { + public: + /// Type alias for result of track states creation operation + using TrackStatesResult = Result>; + /// Type alias for bound state tuple containing parameters, jacobian and path + /// length + using BoundState = std::tuple; + + template + TrackStatesResult createTrackStates( + const GeometryContext& gctx, const CalibrationContext& cctx, + const Surface& surface, const BoundState& boundState, + const TrackIndexType prevTip, + std::vector& trackStateCandidates, + track_state_container_backend_t& trajectory, const Logger& logger) const { + static_cast(trackStateCandidates); + + TrackStatesResult result = TrackStatesResult::success({}); + + ACTS_VERBOSE("Perform measurement selection for surface " + << surface.geometryId()); + + const auto& [boundParams, jacobian, pathLength] = boundState; + + const auto measurementRange = + derived().measurementRange(surface, boundParams); + ACTS_VERBOSE("Found " << measurementRange.size() + << " measurements on surface " + << surface.geometryId()); + if (measurementRange.begin() == measurementRange.end()) { + ACTS_VERBOSE("No measurements on surface " + << surface.geometryId() << ". No track states created."); + return result; + } + + const std::uint32_t maxNumSelectedMeasurements = + derived().getMaxNumSelectedMeasurements(surface); + const float maxChi2Compatible = derived().getMaxChi2Compatible(surface); + const float maxChi2Outlier = derived().getMaxChi2Outlier(surface); + + ACTS_VERBOSE("Max number of selected measurements: " + << maxNumSelectedMeasurements + << ", max chi2 compatible: " << maxChi2Compatible + << ", max chi2 outlier: " << maxChi2Outlier); + + auto selectedMeasurements = derived().selectMeasurements( + gctx, cctx, surface, measurementRange, boundParams, + maxNumSelectedMeasurements, maxChi2Compatible, maxChi2Outlier); + + ACTS_VERBOSE("Selected " << selectedMeasurements.size() + << " measurements for surface " + << surface.geometryId()); + + derived().sortSelectedMeasurements(surface, selectedMeasurements, + maxNumSelectedMeasurements); + + derived().trimSelectedMeasurements(surface, selectedMeasurements, + maxNumSelectedMeasurements); + + ACTS_VERBOSE("Trimmed to " << selectedMeasurements.size() + << " measurements for surface " + << surface.geometryId()); + + derived().createTrackStates(gctx, cctx, surface, boundState, prevTip, + trajectory, measurementRange, + selectedMeasurements, logger); + + return result; + } + + protected: + enum class MeasurementClassification : std::uint8_t { + Incompatible, + Outlier, + Compatible + }; + + struct MeasurementCandidate { + std::uint32_t index{std::numeric_limits::max()}; + float chi2{std::numeric_limits::max()}; + MeasurementClassification classification{ + MeasurementClassification::Incompatible}; + }; + + struct SelectedMeasurements { + MeasurementCandidate bestCandidate; + CkfTypes::BranchVector nonIncompatible; + + std::size_t size() const { return nonIncompatible.size(); } + bool empty() const { return nonIncompatible.empty(); } + + auto begin() const { return nonIncompatible.begin(); } + auto end() const { return nonIncompatible.end(); } + + void push_back(const MeasurementCandidate& candidate) { + if (candidate.chi2 < bestCandidate.chi2) { + bestCandidate = candidate; + } + if (candidate.classification != MeasurementClassification::Incompatible) { + nonIncompatible.push_back(candidate); + } + } + + void emplace_back(std::uint32_t index, float chi2, + MeasurementClassification classification) { + push_back({index, chi2, classification}); + } + + void sort(std::uint32_t maxNumSelectedMeasurements) { + if (maxNumSelectedMeasurements < nonIncompatible.size()) { + nonIncompatible.resize(maxNumSelectedMeasurements); + } + + std::ranges::sort( + nonIncompatible, {}, + [](const MeasurementCandidate& candidate) { return candidate.chi2; }); + } + + void resize(std::uint32_t maxNumSelectedMeasurements) { + if (maxNumSelectedMeasurements < nonIncompatible.size()) { + nonIncompatible.resize(maxNumSelectedMeasurements); + } + } + }; + + template + SelectedMeasurements selectMeasurements( + const GeometryContext& gctx, const CalibrationContext& cctx, + const Surface& surface, const measurement_range_t& range, + const BoundTrackParameters& boundParams, + std::uint32_t maxNumSelectedMeasurements, float maxChi2Compatible, + float maxChi2Outlier) const { + static_cast(gctx); + static_cast(cctx); + static_cast(maxNumSelectedMeasurements); + + SelectedMeasurements result; + + for (const auto& [i, measurement] : + enumerate(range)) { + const float chi2 = + derived().computeChi2(surface, measurement, boundParams); + const MeasurementClassification classification = + derived().classifyMeasurement(surface, measurement, chi2, + maxChi2Compatible, maxChi2Outlier); + result.emplace_back(i, chi2, classification); + } + + return result; + } + + void sortSelectedMeasurements( + const Surface& surface, SelectedMeasurements& selectedMeasurements, + std::uint32_t maxNumSelectedMeasurements) const { + static_cast(surface); + + selectedMeasurements.sort(maxNumSelectedMeasurements); + } + + void trimSelectedMeasurements( + const Surface& surface, SelectedMeasurements& selectedMeasurements, + std::uint32_t maxNumSelectedMeasurements) const { + static_cast(surface); + + selectedMeasurements.resize(maxNumSelectedMeasurements); + } + + std::uint32_t getMaxNumSelectedMeasurements(const Surface& surface) const { + static_cast(surface); + + return std::numeric_limits::max(); + } + + float getMaxChi2Compatible(const Surface& surface) const { + static_cast(surface); + + return std::numeric_limits::max(); + } + + float getMaxChi2Outlier(const Surface& surface) const { + static_cast(surface); + + return std::numeric_limits::max(); + } + + template + float computeChi2(const Surface& surface, const measurement_t& measurement, + const BoundTrackParameters& boundParams) const { + static_cast(surface); + + const auto& subspaceHelper = + derived().measurementSubspace(surface, measurement); + return derived().computeChi2(surface, measurement, boundParams, + subspaceHelper); + } + + template + double computeChi2(const Surface& surface, const measurement_t& measurement, + const BoundTrackParameters& boundParams, + FixedBoundSubspaceHelper subspaceHelper) const { + static_cast(surface); + + const Vector predictedParameters = + subspaceHelper.projectVector(boundParams.parameters()); + const SquareMatrix predictedCovariance = + subspaceHelper.projectMatrix(*boundParams.covariance()); + + const Vector measuredParameters = + derived().measuredParameters(surface, measurement); + const SquareMatrix measurementCovariance = + derived().measurementCovariance(surface, measurement); + + const Vector residualParameters = + measuredParameters - predictedParameters; + const SquareMatrix residualCovariance = + predictedCovariance + measurementCovariance; + + const double chi2 = (residualParameters.transpose() * + residualCovariance.inverse() * residualParameters) + .eval()(0, 0); + return chi2; + } + + template + double computeChi2(const Surface& surface, const measurement_t& measurement, + const BoundTrackParameters& boundParams, + VariableBoundSubspaceHelper subspaceHelper) const { + static_cast(surface); + + return visit_measurement(subspaceHelper.size(), [&](auto N) -> double { + constexpr std::size_t kDim = decltype(N)::value; + const FixedBoundSubspaceHelper fixedSubspaceHelper(subspaceHelper); + return derived().computeChi2(surface, measurement, boundParams, + fixedSubspaceHelper); + }); + } + + template + MeasurementClassification classifyMeasurement( + const Surface& surface, const measurement_t& measurement, + const float chi2, const float maxChi2Compatible, + const float maxChi2Outlier) const { + static_cast(surface); + static_cast(measurement); + + if (chi2 < maxChi2Compatible) { + return MeasurementClassification::Compatible; + } + if (chi2 < maxChi2Outlier) { + return MeasurementClassification::Outlier; + } + return MeasurementClassification::Incompatible; + } + + template + auto selectedMeasurementSourceLink( + const Surface& surface, const measurement_range_t& measurements, + const MeasurementCandidate& selectedMeasurement) const { + return derived().measurementSourceLink( + surface, measurements[selectedMeasurement.index]); + } + + template + auto selectedMeasurementSubspace( + const Surface& surface, const measurement_range_t& measurements, + const MeasurementCandidate& selectedMeasurement) const { + return derived().measurementSubspace( + surface, measurements[selectedMeasurement.index]); + } + + template + auto selectedMeasurementParameters( + const Surface& surface, const measurement_range_t& measurements, + const MeasurementCandidate& selectedMeasurement) const { + return derived().measuredParameters( + surface, measurements[selectedMeasurement.index]); + } + + template + auto selectedMeasurementCovariance( + const Surface& surface, const measurement_range_t& measurements, + const MeasurementCandidate& selectedMeasurement) const { + return derived().measurementCovariance( + surface, measurements[selectedMeasurement.index]); + } + + template + void createTrackStates( + const GeometryContext& gctx, const CalibrationContext& cctx, + const Surface& surface, const BoundState& boundState, + const TrackIndexType prevTip, track_state_container_backend_t& trajectory, + const measurement_range_t& measurements, + const selected_measurement_range_t& selectedMeasurements, + const Logger& logger) const { + using TrackStateProxy = + typename track_state_container_backend_t::TrackStateProxy; + + std::optional firstTrackState; + + if (derived().hasHole(surface, boundState, selectedMeasurements, logger)) { + ACTS_VERBOSE("No compatible measurements on surface " + << surface.geometryId() << ". No track states created."); + derived().createHoleState(surface, boundState, prevTip, trajectory, + firstTrackState, logger); + } + + for (const auto& selectedMeasurement : selectedMeasurements) { + derived().createMeasurementState( + gctx, cctx, surface, boundState, prevTip, trajectory, firstTrackState, + measurements, selectedMeasurement, logger); + } + } + + bool hasHole(const Surface& surface, const BoundState& boundState, + const SelectedMeasurements& selectedMeasurements, + const Logger& logger) const { + static_cast(surface); + static_cast(boundState); + static_cast(logger); + + return selectedMeasurements.empty() && + (selectedMeasurements.bestCandidate.classification == + MeasurementClassification::Incompatible); + } + + TrackStateType determineTrackStateType(const Surface& surface, + const Logger& logger) const { + static_cast(logger); + + TrackStateType result; + + result.setHasParameters(); + + result.setHasMeasurement(); + + if (surface.surfaceMaterial() != nullptr) { + result.setHasMaterial(); + } + + return result; + } + + template + TrackStateType determineTrackStateType( + const Surface& surface, const measurement_range_t& measurements, + const MeasurementCandidate& selectedMeasurement, + const Logger& logger) const { + static_cast(measurements); + + TrackStateType result = derived().determineTrackStateType(surface, logger); + + if (selectedMeasurement.classification == + MeasurementClassification::Outlier) { + result.setIsOutlier(); + } + + return result; + } + + template + track_state_proxy_t createTrackState( + const Surface& surface, const BoundState& boundState, + const TrackIndexType prevTip, track_state_container_backend_t& trajectory, + std::optional& firstTrackState, + const TrackStateType trackStateType, const Logger& logger) const { + static_cast(logger); + + TrackStatePropMask mask = TrackStatePropMask::None; + if (!firstTrackState.has_value()) { + mask |= TrackStatePropMask::Predicted | TrackStatePropMask::Jacobian; + } + + track_state_proxy_t trackState = trajectory.makeTrackState(mask, prevTip); + + trackState.setReferenceSurface(surface.getSharedPtr()); + + trackState.typeFlags() = trackStateType; + + if (!firstTrackState.has_value()) { + const auto& [boundParams, jacobian, pathLength] = boundState; + trackState.predicted() = boundParams.parameters(); + trackState.predictedCovariance() = *boundParams.covariance(); + trackState.jacobian() = jacobian; + trackState.pathLength() = pathLength; + } else { + trackState.shareFrom(*firstTrackState, TrackStatePropMask::Predicted); + trackState.shareFrom(*firstTrackState, TrackStatePropMask::Jacobian); + } + + if (!firstTrackState.has_value()) { + firstTrackState = trackState; + } + return trackState; + } + + template + track_state_proxy_t createHoleState( + const Surface& surface, const BoundState& boundState, + const TrackIndexType prevTip, track_state_container_backend_t& trajectory, + std::optional& firstTrackState, + const Logger& logger) const { + TrackStateType trackStateType = + derived().determineTrackStateType(surface, logger); + trackStateType.setIsHole(); + + track_state_proxy_t trackState = + derived().createTrackState(surface, boundState, prevTip, trajectory, + firstTrackState, trackStateType, logger); + + return trackState; + } + + template + track_state_proxy_t createMeasurementState( + const GeometryContext& gctx, const CalibrationContext& cctx, + const Surface& surface, const BoundState& boundState, + const TrackIndexType prevTip, track_state_container_backend_t& trajectory, + std::optional& firstTrackState, + const measurement_range_t& measurements, + const selected_measurement_t& selectedMeasurement, + const Logger& logger) const { + const TrackStateType trackStateType = derived().determineTrackStateType( + surface, measurements, selectedMeasurement, logger); + + track_state_proxy_t trackState = + derived().createTrackState(surface, boundState, prevTip, trajectory, + firstTrackState, trackStateType, logger); + + trackState.addComponents(TrackStatePropMask::Calibrated); + + derived().postCalibrateTrackState(gctx, cctx, surface, boundState, + measurements, selectedMeasurement, + trackState); + + return trackState; + } + + template + void postCalibrateTrackState( + const GeometryContext& gctx, const CalibrationContext& cctx, + const Surface& surface, const BoundState& boundState, + const measurement_range_t& measurements, + const selected_measurement_t& selectedMeasurement, + track_state_proxy_t& trackState) const { + static_cast(gctx); + static_cast(cctx); + static_cast(boundState); + + trackState.setUncalibratedSourceLink( + derived().selectedMeasurementSourceLink(surface, measurements, + selectedMeasurement)); + + const auto& subspaceHelper = derived().selectedMeasurementSubspace( + surface, measurements, selectedMeasurement); + + visit_measurement(subspaceHelper.size(), [&](auto N) -> void { + constexpr std::size_t kDim = decltype(N)::value; + const FixedBoundSubspaceHelper fixedSubspaceHelper(subspaceHelper); + const Vector measuredParameters = + derived().selectedMeasurementParameters(surface, measurements, + selectedMeasurement); + const SquareMatrix measurementCovariance = + derived().selectedMeasurementCovariance(surface, measurements, + selectedMeasurement); + trackState.allocateCalibrated(measuredParameters, measurementCovariance); + }); + + trackState.setProjectorSubspaceIndices(subspaceHelper.indices()); + } + + private: + const derived_t& derived() const { + return static_cast(*this); + } +}; + +} // namespace Experimental + /// @brief Create track states for selected measurements associated to a surface. /// /// - First get a source link range covering relevant measurements associated to @@ -35,8 +531,7 @@ namespace Acts { template struct TrackStateCreator { /// Type alias for result of track states creation operation - using TrackStatesResult = - Acts::Result>; + using TrackStatesResult = Result>; /// Type alias for track state container backend from track container using TrackStateContainerBackend = typename track_container_t::TrackStateContainerBackend; @@ -84,7 +579,6 @@ struct TrackStateCreator { MeasurementSelector measurementSelector{ DelegateFuncTag{}}; - public: /// @brief extend the trajectory onto the given surface. /// /// @param gctx The geometry context to be used for this task @@ -146,8 +640,7 @@ struct TrackStateCreator { TrackStateContainerBackend& trajectory, const Logger& logger) const { using PM = TrackStatePropMask; - using ResultTrackStateList = - Acts::Result>; + using ResultTrackStateList = Result>; ResultTrackStateList resultTrackStateList{ CkfTypes::BranchVector()}; const auto& [boundParams, jacobian, pathLength] = boundState; @@ -237,8 +730,7 @@ struct TrackStateCreator { const Logger& logger) const { using PM = TrackStatePropMask; - using ResultTrackStateList = - Acts::Result>; + using ResultTrackStateList = Result>; ResultTrackStateList resultTrackStateList{ CkfTypes::BranchVector()}; CkfTypes::BranchVector& trackStateList = diff --git a/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp b/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp index 4b80a458b70..973506f3837 100644 --- a/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp +++ b/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp @@ -11,9 +11,11 @@ #include "Acts/Definitions/Algebra.hpp" #include "Acts/Definitions/Direction.hpp" #include "Acts/Definitions/TrackParametrization.hpp" +#include "Acts/EventData/BoundTrackParameters.hpp" #include "Acts/EventData/MultiTrajectory.hpp" #include "Acts/EventData/ProxyAccessor.hpp" #include "Acts/EventData/SourceLink.hpp" +#include "Acts/EventData/SubspaceHelpers.hpp" #include "Acts/EventData/TrackContainer.hpp" #include "Acts/EventData/VectorMultiTrajectory.hpp" #include "Acts/EventData/VectorTrackContainer.hpp" @@ -38,6 +40,7 @@ #include "ActsExamples/EventData/Track.hpp" #include "ActsExamples/Framework/AlgorithmContext.hpp" #include "ActsExamples/Framework/ProcessCode.hpp" +#include "ActsExamples/Utilities/Range.hpp" #include #include @@ -67,6 +70,49 @@ namespace ActsExamples { namespace { +struct TrackStateCreator final + : public Acts::Experimental::TrackStateCreatorBase { + const MeasurementContainer* measurements = nullptr; + + auto measurementRange(const Acts::Surface& surface, + const Acts::BoundTrackParameters& boundParams) const { + static_cast(boundParams); + + const auto rangePair = + measurements->orderedIndices().equal_range(surface.geometryId()); + + return std::ranges::subrange(rangePair.first, rangePair.second); + } + + Acts::SourceLink measurementSourceLink( + const Acts::Surface& surface, const IndexSourceLink& measurement) const { + static_cast(surface); + + return Acts::SourceLink{measurement}; + } + + Acts::VariableBoundSubspaceHelper measurementSubspace( + const Acts::Surface& surface, const IndexSourceLink& measurement) const { + static_cast(surface); + + return measurements->at(measurement.index()).subspaceHelper(); + } + + auto measuredParameters(const Acts::Surface& surface, + const IndexSourceLink& measurement) const { + static_cast(surface); + + return measurements->at(measurement.index()).parameters(); + } + + auto measurementCovariance(const Acts::Surface& surface, + const IndexSourceLink& measurement) const { + static_cast(surface); + + return measurements->at(measurement.index()).covariance(); + } +}; + class MeasurementSelector { public: using Traj = Acts::VectorMultiTrajectory; @@ -338,13 +384,21 @@ ProcessCode TrackFindingAlgorithm::execute(const AlgorithmContext& ctx) const { trackStateCreator.measurementSelector .template connect<&MeasurementSelector::select>(&measSel); + TrackStateCreator trackStateCreator2; + trackStateCreator2.measurements = &measurements; + Extensions extensions; extensions.updater.connect<&Acts::GainMatrixUpdater::operator()< typename TrackContainer::TrackStateContainerBackend>>(&kfUpdater); extensions.branchStopper.connect<&BranchStopper::operator()>(&branchStopper); extensions.createTrackStates - .template connect<&TrackStateCreatorType ::createTrackStates>( + .template connect<&TrackStateCreatorType::createTrackStates>( &trackStateCreator); + extensions.createTrackStates + .template connect<&TrackStateCreator::createTrackStates< + TrackStateProxy, + typename TrackContainer::TrackStateContainerBackend>>( + &trackStateCreator2); Acts::PropagatorPlainOptions firstPropOptions(ctx.geoContext, ctx.magFieldContext); diff --git a/Examples/Detectors/MuonSpectrometerMockupDetector/src/GeoMuonMockupExperiment.cpp b/Examples/Detectors/MuonSpectrometerMockupDetector/src/GeoMuonMockupExperiment.cpp index db474816b8e..6fef48accc5 100644 --- a/Examples/Detectors/MuonSpectrometerMockupDetector/src/GeoMuonMockupExperiment.cpp +++ b/Examples/Detectors/MuonSpectrometerMockupDetector/src/GeoMuonMockupExperiment.cpp @@ -170,12 +170,7 @@ ActsPlugins::GeoModelTree GeoMuonMockupExperiment::constructMS() { for (const auto& [fpV, pubKey] : m_publisher->getPublishedFPV()) { try { const auto key = [pubKey]() { - if constexpr (std::is_same_v, - std::string>) { - return std::any_cast(pubKey); - } else { - return std::get(pubKey); - } + return std::any_cast(pubKey); }(); if (!publishedVol From e80220e3a67d998ba9c59977ff203a6fc5dea1af Mon Sep 17 00:00:00 2001 From: Andreas Stefl Date: Fri, 20 Mar 2026 10:08:04 +0100 Subject: [PATCH 2/6] revert --- .../src/GeoMuonMockupExperiment.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/Examples/Detectors/MuonSpectrometerMockupDetector/src/GeoMuonMockupExperiment.cpp b/Examples/Detectors/MuonSpectrometerMockupDetector/src/GeoMuonMockupExperiment.cpp index 6fef48accc5..3b7e56614ad 100644 --- a/Examples/Detectors/MuonSpectrometerMockupDetector/src/GeoMuonMockupExperiment.cpp +++ b/Examples/Detectors/MuonSpectrometerMockupDetector/src/GeoMuonMockupExperiment.cpp @@ -169,8 +169,13 @@ ActsPlugins::GeoModelTree GeoMuonMockupExperiment::constructMS() { VolumeMap_t publishedVol{}; for (const auto& [fpV, pubKey] : m_publisher->getPublishedFPV()) { try { - const auto key = [pubKey]() { - return std::any_cast(pubKey); + const auto key = [&pubKey]() { + if constexpr (std::is_same_v, + std::any>) { + return std::any_cast(pubKey); + } else { + return std::get(pubKey); + } }(); if (!publishedVol From ec3bf9164bf6c25ae2afb9e247ee6488733cab1f Mon Sep 17 00:00:00 2001 From: Andreas Stefl Date: Fri, 20 Mar 2026 10:10:50 +0100 Subject: [PATCH 3/6] trim --- .../src/TrackFindingAlgorithm.cpp | 121 +++++++----------- 1 file changed, 43 insertions(+), 78 deletions(-) diff --git a/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp b/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp index 973506f3837..fa713570585 100644 --- a/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp +++ b/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp @@ -40,7 +40,6 @@ #include "ActsExamples/EventData/Track.hpp" #include "ActsExamples/Framework/AlgorithmContext.hpp" #include "ActsExamples/Framework/ProcessCode.hpp" -#include "ActsExamples/Utilities/Range.hpp" #include #include @@ -72,7 +71,45 @@ namespace { struct TrackStateCreator final : public Acts::Experimental::TrackStateCreatorBase { + using Traj = Acts::VectorMultiTrajectory; + const MeasurementContainer* measurements = nullptr; + std::optional seed; + + bool isSeedCandidate(const Traj::TrackStateProxy& candidate) const { + assert(candidate.hasUncalibratedSourceLink()); + assert(seed.has_value()); + + const Acts::SourceLink& sourceLink = candidate.getUncalibratedSourceLink(); + + for (const ConstSpacePointProxy sp : seed->spacePoints()) { + for (const Acts::SourceLink& sl : sp.sourceLinks()) { + if (sourceLink.get() == sl.get()) { + return true; + } + } + } + + return false; + } + + // TODO + void select(std::vector& candidates, bool& isOutlier, + const Acts::Logger& logger) const { + if (seed.has_value()) { + std::vector newCandidates; + + for (const auto& candidate : candidates) { + if (isSeedCandidate(candidate)) { + newCandidates.push_back(candidate); + } + } + + if (!newCandidates.empty()) { + candidates = std::move(newCandidates); + } + } + } auto measurementRange(const Acts::Surface& surface, const Acts::BoundTrackParameters& boundParams) const { @@ -113,59 +150,6 @@ struct TrackStateCreator final } }; -class MeasurementSelector { - public: - using Traj = Acts::VectorMultiTrajectory; - - explicit MeasurementSelector(Acts::MeasurementSelector selector) - : m_selector(std::move(selector)) {} - - void setSeed(const std::optional& seed) { m_seed = seed; } - - Acts::Result::iterator, - std::vector::iterator>> - select(std::vector& candidates, bool& isOutlier, - const Acts::Logger& logger) const { - if (m_seed.has_value()) { - std::vector newCandidates; - - for (const auto& candidate : candidates) { - if (isSeedCandidate(candidate)) { - newCandidates.push_back(candidate); - } - } - - if (!newCandidates.empty()) { - candidates = std::move(newCandidates); - } - } - - return m_selector.select(candidates, isOutlier, - logger); - } - - private: - Acts::MeasurementSelector m_selector; - std::optional m_seed; - - bool isSeedCandidate(const Traj::TrackStateProxy& candidate) const { - assert(candidate.hasUncalibratedSourceLink()); - assert(m_seed.has_value()); - - const Acts::SourceLink& sourceLink = candidate.getUncalibratedSourceLink(); - - for (const ConstSpacePointProxy sp : m_seed->spacePoints()) { - for (const Acts::SourceLink& sl : sp.sourceLinks()) { - if (sourceLink.get() == sl.get()) { - return true; - } - } - } - - return false; - } -}; - /// Source link indices of the bottom, middle, top measurements. /// In case of strip seeds only the first source link of the pair is used. using SeedIdentifier = std::array; @@ -367,38 +351,19 @@ ProcessCode TrackFindingAlgorithm::execute(const AlgorithmContext& ctx) const { using Extensions = Acts::CombinatorialKalmanFilterExtensions; BranchStopper branchStopper(m_cfg); - MeasurementSelector measSel{ - Acts::MeasurementSelector(m_cfg.measurementSelectorCfg)}; - - IndexSourceLinkAccessor slAccessor; - slAccessor.container = &measurements.orderedIndices(); - - using TrackStateCreatorType = - Acts::TrackStateCreator; - TrackStateCreatorType trackStateCreator; - trackStateCreator.sourceLinkAccessor - .template connect<&IndexSourceLinkAccessor::range>(&slAccessor); - trackStateCreator.calibrator - .template connect<&MeasurementCalibratorAdapter::calibrate>(&calibrator); - trackStateCreator.measurementSelector - .template connect<&MeasurementSelector::select>(&measSel); - - TrackStateCreator trackStateCreator2; - trackStateCreator2.measurements = &measurements; + + TrackStateCreator trackStateCreator; + trackStateCreator.measurements = &measurements; Extensions extensions; extensions.updater.connect<&Acts::GainMatrixUpdater::operator()< typename TrackContainer::TrackStateContainerBackend>>(&kfUpdater); extensions.branchStopper.connect<&BranchStopper::operator()>(&branchStopper); - extensions.createTrackStates - .template connect<&TrackStateCreatorType::createTrackStates>( - &trackStateCreator); extensions.createTrackStates .template connect<&TrackStateCreator::createTrackStates< TrackStateProxy, typename TrackContainer::TrackStateContainerBackend>>( - &trackStateCreator2); + &trackStateCreator); Acts::PropagatorPlainOptions firstPropOptions(ctx.geoContext, ctx.magFieldContext); @@ -523,7 +488,7 @@ ProcessCode TrackFindingAlgorithm::execute(const AlgorithmContext& ctx) const { } if (m_cfg.stayOnSeed) { - measSel.setSeed(seed); + trackStateCreator.seed = seed; } } From 12f9cc7f9177e73354af8f34556b0dd8cb3b0caf Mon Sep 17 00:00:00 2001 From: Andreas Stefl Date: Fri, 20 Mar 2026 15:34:42 +0100 Subject: [PATCH 4/6] fix --- .../Acts/TrackFinding/TrackStateCreator.hpp | 1 + .../src/TrackFindingAlgorithm.cpp | 34 ++++++++++--------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/Core/include/Acts/TrackFinding/TrackStateCreator.hpp b/Core/include/Acts/TrackFinding/TrackStateCreator.hpp index 9189c0d96ae..4d17f23e8d9 100644 --- a/Core/include/Acts/TrackFinding/TrackStateCreator.hpp +++ b/Core/include/Acts/TrackFinding/TrackStateCreator.hpp @@ -16,6 +16,7 @@ #include "Acts/TrackFinding/CombinatorialKalmanFilterExtensions.hpp" #include "Acts/TrackFitting/KalmanFitter.hpp" #include "Acts/Utilities/Enumerate.hpp" +#include "Acts/Utilities/Logger.hpp" namespace Acts { diff --git a/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp b/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp index fa713570585..552982f98ff 100644 --- a/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp +++ b/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp @@ -71,6 +71,7 @@ namespace { struct TrackStateCreator final : public Acts::Experimental::TrackStateCreatorBase { + using Base = Acts::Experimental::TrackStateCreatorBase; using Traj = Acts::VectorMultiTrajectory; const MeasurementContainer* measurements = nullptr; @@ -94,22 +95,23 @@ struct TrackStateCreator final } // TODO - void select(std::vector& candidates, bool& isOutlier, - const Acts::Logger& logger) const { - if (seed.has_value()) { - std::vector newCandidates; - - for (const auto& candidate : candidates) { - if (isSeedCandidate(candidate)) { - newCandidates.push_back(candidate); - } - } - - if (!newCandidates.empty()) { - candidates = std::move(newCandidates); - } - } - } + // void select(std::vector& candidates, bool& + // isOutlier, + // const Acts::Logger& logger) const { + // if (seed.has_value()) { + // std::vector newCandidates; + + // for (const auto& candidate : candidates) { + // if (isSeedCandidate(candidate)) { + // newCandidates.push_back(candidate); + // } + // } + + // if (!newCandidates.empty()) { + // candidates = std::move(newCandidates); + // } + // } + // } auto measurementRange(const Acts::Surface& surface, const Acts::BoundTrackParameters& boundParams) const { From 8ab5c28708833fc3386ba5bd34e43074ede87224 Mon Sep 17 00:00:00 2001 From: Andreas Stefl Date: Fri, 20 Mar 2026 16:32:25 +0100 Subject: [PATCH 5/6] use state --- .../Acts/TrackFinding/TrackStateCreator.hpp | 390 +++++++++--------- .../src/TrackFindingAlgorithm.cpp | 35 +- 2 files changed, 204 insertions(+), 221 deletions(-) diff --git a/Core/include/Acts/TrackFinding/TrackStateCreator.hpp b/Core/include/Acts/TrackFinding/TrackStateCreator.hpp index 4d17f23e8d9..60d698a1b45 100644 --- a/Core/include/Acts/TrackFinding/TrackStateCreator.hpp +++ b/Core/include/Acts/TrackFinding/TrackStateCreator.hpp @@ -44,15 +44,15 @@ class TrackStateCreatorBase { track_state_container_backend_t& trajectory, const Logger& logger) const { static_cast(trackStateCandidates); + auto state = derived().makeState(gctx, cctx, surface, boundState, prevTip, + trackStateCandidates, trajectory, logger); + TrackStatesResult result = TrackStatesResult::success({}); ACTS_VERBOSE("Perform measurement selection for surface " << surface.geometryId()); - const auto& [boundParams, jacobian, pathLength] = boundState; - - const auto measurementRange = - derived().measurementRange(surface, boundParams); + const auto measurementRange = derived().measurementRange(state); ACTS_VERBOSE("Found " << measurementRange.size() << " measurements on surface " << surface.geometryId()); @@ -62,42 +62,109 @@ class TrackStateCreatorBase { return result; } - const std::uint32_t maxNumSelectedMeasurements = - derived().getMaxNumSelectedMeasurements(surface); - const float maxChi2Compatible = derived().getMaxChi2Compatible(surface); - const float maxChi2Outlier = derived().getMaxChi2Outlier(surface); - - ACTS_VERBOSE("Max number of selected measurements: " - << maxNumSelectedMeasurements - << ", max chi2 compatible: " << maxChi2Compatible - << ", max chi2 outlier: " << maxChi2Outlier); - - auto selectedMeasurements = derived().selectMeasurements( - gctx, cctx, surface, measurementRange, boundParams, - maxNumSelectedMeasurements, maxChi2Compatible, maxChi2Outlier); + auto selectedMeasurements = + derived().selectMeasurements(state, measurementRange); ACTS_VERBOSE("Selected " << selectedMeasurements.size() << " measurements for surface " << surface.geometryId()); - derived().sortSelectedMeasurements(surface, selectedMeasurements, - maxNumSelectedMeasurements); + derived().sortSelectedMeasurements(state, selectedMeasurements); - derived().trimSelectedMeasurements(surface, selectedMeasurements, - maxNumSelectedMeasurements); + derived().trimSelectedMeasurements(state, selectedMeasurements); ACTS_VERBOSE("Trimmed to " << selectedMeasurements.size() << " measurements for surface " << surface.geometryId()); - derived().createTrackStates(gctx, cctx, surface, boundState, prevTip, - trajectory, measurementRange, - selectedMeasurements, logger); + derived().createTrackStatesImpl(state, measurementRange, + selectedMeasurements); return result; } protected: + template + struct State { + using TrackStateProxy = + typename track_state_container_backend_t::TrackStateProxy; + + State(const GeometryContext& gctx_, const CalibrationContext& cctx_, + const Surface& surface_, const BoundState& boundState_, + const TrackIndexType prevTip_, + std::vector& trackStateCandidates_, + track_state_container_backend_t& trajectory_, const Logger& logger_) + : gctx(&gctx_), + cctx(&cctx_), + surface(&surface_), + boundState(&boundState_), + prevTip(prevTip_), + trackStateCandidates(&trackStateCandidates_), + trajectory(&trajectory_), + logger(&logger_) {} + + const GeometryContext* gctx{}; + const CalibrationContext* cctx{}; + const Surface* surface{}; + const BoundState* boundState{}; + TrackIndexType prevTip{}; + std::vector* trackStateCandidates{}; + track_state_container_backend_t* trajectory{}; + const Logger* logger{}; + + std::uint32_t maxNumSelectedMeasurements{}; + float maxChi2Compatible{}; + float maxChi2Outlier{}; + + std::optional firstTrackState; + }; + + template + auto makeState(const GeometryContext& gctx, const CalibrationContext& cctx, + const Surface& surface, const BoundState& boundState, + const TrackIndexType prevTip, + std::vector& trackStateCandidates, + track_state_container_backend_t& trajectory, + const Logger& logger) const { + State state(gctx, cctx, surface, boundState, prevTip, trackStateCandidates, + trajectory, logger); + + state.maxNumSelectedMeasurements = + derived().getMaxNumSelectedMeasurements(state); + state.maxChi2Compatible = derived().getMaxChi2Compatible(state); + state.maxChi2Outlier = derived().getMaxChi2Outlier(state); + + ACTS_VERBOSE("Max number of selected measurements: " + << state.maxNumSelectedMeasurements + << ", max chi2 compatible: " << state.maxChi2Compatible + << ", max chi2 outlier: " << state.maxChi2Outlier); + + return state; + } + + template + std::uint32_t getMaxNumSelectedMeasurements(state_t& state) const { + static_cast(state); + + return std::numeric_limits::max(); + } + + template + float getMaxChi2Compatible(state_t& state) const { + static_cast(state); + + return std::numeric_limits::max(); + } + + template + float getMaxChi2Outlier(state_t& state) const { + static_cast(state); + + return std::numeric_limits::max(); + } + enum class MeasurementClassification : std::uint8_t { Incompatible, Outlier, @@ -152,82 +219,45 @@ class TrackStateCreatorBase { } }; - template + template SelectedMeasurements selectMeasurements( - const GeometryContext& gctx, const CalibrationContext& cctx, - const Surface& surface, const measurement_range_t& range, - const BoundTrackParameters& boundParams, - std::uint32_t maxNumSelectedMeasurements, float maxChi2Compatible, - float maxChi2Outlier) const { - static_cast(gctx); - static_cast(cctx); - static_cast(maxNumSelectedMeasurements); - + state_t& state, const measurement_range_t& range) const { SelectedMeasurements result; for (const auto& [i, measurement] : enumerate(range)) { - const float chi2 = - derived().computeChi2(surface, measurement, boundParams); + const float chi2 = derived().computeChi2(state, measurement); const MeasurementClassification classification = - derived().classifyMeasurement(surface, measurement, chi2, - maxChi2Compatible, maxChi2Outlier); + derived().classifyMeasurement(state, measurement, chi2); result.emplace_back(i, chi2, classification); } return result; } + template void sortSelectedMeasurements( - const Surface& surface, SelectedMeasurements& selectedMeasurements, - std::uint32_t maxNumSelectedMeasurements) const { - static_cast(surface); - - selectedMeasurements.sort(maxNumSelectedMeasurements); + state_t& state, SelectedMeasurements& selectedMeasurements) const { + selectedMeasurements.sort(state.maxNumSelectedMeasurements); } + template void trimSelectedMeasurements( - const Surface& surface, SelectedMeasurements& selectedMeasurements, - std::uint32_t maxNumSelectedMeasurements) const { - static_cast(surface); - - selectedMeasurements.resize(maxNumSelectedMeasurements); - } - - std::uint32_t getMaxNumSelectedMeasurements(const Surface& surface) const { - static_cast(surface); - - return std::numeric_limits::max(); + state_t& state, SelectedMeasurements& selectedMeasurements) const { + selectedMeasurements.resize(state.maxNumSelectedMeasurements); } - float getMaxChi2Compatible(const Surface& surface) const { - static_cast(surface); - - return std::numeric_limits::max(); - } - - float getMaxChi2Outlier(const Surface& surface) const { - static_cast(surface); - - return std::numeric_limits::max(); - } - - template - float computeChi2(const Surface& surface, const measurement_t& measurement, - const BoundTrackParameters& boundParams) const { - static_cast(surface); - + template + float computeChi2(state_t& state, const measurement_t& measurement) const { const auto& subspaceHelper = - derived().measurementSubspace(surface, measurement); - return derived().computeChi2(surface, measurement, boundParams, - subspaceHelper); + derived().measurementSubspace(state, measurement); + return derived().computeChi2Impl(state, measurement, subspaceHelper); } - template - double computeChi2(const Surface& surface, const measurement_t& measurement, - const BoundTrackParameters& boundParams, - FixedBoundSubspaceHelper subspaceHelper) const { - static_cast(surface); + template + double computeChi2Impl(state_t& state, const measurement_t& measurement, + FixedBoundSubspaceHelper subspaceHelper) const { + const auto& [boundParams, jacobian, pathLength] = *state.boundState; const Vector predictedParameters = subspaceHelper.projectVector(boundParams.parameters()); @@ -235,9 +265,9 @@ class TrackStateCreatorBase { subspaceHelper.projectMatrix(*boundParams.covariance()); const Vector measuredParameters = - derived().measuredParameters(surface, measurement); + derived().measuredParameters(state, measurement); const SquareMatrix measurementCovariance = - derived().measurementCovariance(surface, measurement); + derived().measurementCovariance(state, measurement); const Vector residualParameters = measuredParameters - predictedParameters; @@ -250,134 +280,115 @@ class TrackStateCreatorBase { return chi2; } - template - double computeChi2(const Surface& surface, const measurement_t& measurement, - const BoundTrackParameters& boundParams, - VariableBoundSubspaceHelper subspaceHelper) const { - static_cast(surface); - + template + double computeChi2Impl(state_t& state, const measurement_t& measurement, + VariableBoundSubspaceHelper subspaceHelper) const { return visit_measurement(subspaceHelper.size(), [&](auto N) -> double { constexpr std::size_t kDim = decltype(N)::value; const FixedBoundSubspaceHelper fixedSubspaceHelper(subspaceHelper); - return derived().computeChi2(surface, measurement, boundParams, - fixedSubspaceHelper); + return derived().computeChi2Impl(state, measurement, fixedSubspaceHelper); }); } - template + template MeasurementClassification classifyMeasurement( - const Surface& surface, const measurement_t& measurement, - const float chi2, const float maxChi2Compatible, - const float maxChi2Outlier) const { - static_cast(surface); + state_t& state, const measurement_t& measurement, + const float chi2) const { static_cast(measurement); - if (chi2 < maxChi2Compatible) { + if (chi2 < state.maxChi2Compatible) { return MeasurementClassification::Compatible; } - if (chi2 < maxChi2Outlier) { + if (chi2 < state.maxChi2Outlier) { return MeasurementClassification::Outlier; } return MeasurementClassification::Incompatible; } - template + template auto selectedMeasurementSourceLink( - const Surface& surface, const measurement_range_t& measurements, + state_t& state, const measurement_range_t& measurements, const MeasurementCandidate& selectedMeasurement) const { return derived().measurementSourceLink( - surface, measurements[selectedMeasurement.index]); + state, measurements[selectedMeasurement.index]); } - template + template auto selectedMeasurementSubspace( - const Surface& surface, const measurement_range_t& measurements, + state_t& state, const measurement_range_t& measurements, const MeasurementCandidate& selectedMeasurement) const { return derived().measurementSubspace( - surface, measurements[selectedMeasurement.index]); + state, measurements[selectedMeasurement.index]); } - template + template auto selectedMeasurementParameters( - const Surface& surface, const measurement_range_t& measurements, + state_t& state, const measurement_range_t& measurements, const MeasurementCandidate& selectedMeasurement) const { return derived().measuredParameters( - surface, measurements[selectedMeasurement.index]); + state, measurements[selectedMeasurement.index]); } - template + template auto selectedMeasurementCovariance( - const Surface& surface, const measurement_range_t& measurements, + state_t& state, const measurement_range_t& measurements, const MeasurementCandidate& selectedMeasurement) const { return derived().measurementCovariance( - surface, measurements[selectedMeasurement.index]); + state, measurements[selectedMeasurement.index]); } - template - void createTrackStates( - const GeometryContext& gctx, const CalibrationContext& cctx, - const Surface& surface, const BoundState& boundState, - const TrackIndexType prevTip, track_state_container_backend_t& trajectory, - const measurement_range_t& measurements, - const selected_measurement_range_t& selectedMeasurements, - const Logger& logger) const { - using TrackStateProxy = - typename track_state_container_backend_t::TrackStateProxy; - - std::optional firstTrackState; + template + void createTrackStatesImpl( + state_t& state, const measurement_range_t& measurements, + const selected_measurement_range_t& selectedMeasurements) const { + const Logger& logger = *state.logger; - if (derived().hasHole(surface, boundState, selectedMeasurements, logger)) { + if (derived().hasHole(state, selectedMeasurements)) { ACTS_VERBOSE("No compatible measurements on surface " - << surface.geometryId() << ". No track states created."); - derived().createHoleState(surface, boundState, prevTip, trajectory, - firstTrackState, logger); + << state.surface->geometryId() + << ". No track states created."); + derived().createHoleState(state); } for (const auto& selectedMeasurement : selectedMeasurements) { - derived().createMeasurementState( - gctx, cctx, surface, boundState, prevTip, trajectory, firstTrackState, - measurements, selectedMeasurement, logger); + derived().createMeasurementState(state, measurements, + selectedMeasurement); } } - bool hasHole(const Surface& surface, const BoundState& boundState, - const SelectedMeasurements& selectedMeasurements, - const Logger& logger) const { - static_cast(surface); - static_cast(boundState); - static_cast(logger); + template + bool hasHole(state_t& state, + const SelectedMeasurements& selectedMeasurements) const { + static_cast(state); return selectedMeasurements.empty() && (selectedMeasurements.bestCandidate.classification == MeasurementClassification::Incompatible); } - TrackStateType determineTrackStateType(const Surface& surface, - const Logger& logger) const { - static_cast(logger); - + template + TrackStateType determineTrackStateType(state_t& state) const { TrackStateType result; result.setHasParameters(); result.setHasMeasurement(); - if (surface.surfaceMaterial() != nullptr) { + if (state.surface->surfaceMaterial() != nullptr) { result.setHasMaterial(); } return result; } - template + template TrackStateType determineTrackStateType( - const Surface& surface, const measurement_range_t& measurements, - const MeasurementCandidate& selectedMeasurement, - const Logger& logger) const { + state_t& state, const measurement_range_t& measurements, + const MeasurementCandidate& selectedMeasurement) const { static_cast(measurements); - TrackStateType result = derived().determineTrackStateType(surface, logger); + TrackStateType result = derived().determineTrackStateType(state); if (selectedMeasurement.classification == MeasurementClassification::Outlier) { @@ -387,115 +398,88 @@ class TrackStateCreatorBase { return result; } - template - track_state_proxy_t createTrackState( - const Surface& surface, const BoundState& boundState, - const TrackIndexType prevTip, track_state_container_backend_t& trajectory, - std::optional& firstTrackState, - const TrackStateType trackStateType, const Logger& logger) const { - static_cast(logger); - + template + auto createTrackState(state_t& state, + const TrackStateType trackStateType) const { TrackStatePropMask mask = TrackStatePropMask::None; - if (!firstTrackState.has_value()) { + if (!state.firstTrackState.has_value()) { mask |= TrackStatePropMask::Predicted | TrackStatePropMask::Jacobian; } - track_state_proxy_t trackState = trajectory.makeTrackState(mask, prevTip); + auto trackState = state.trajectory->makeTrackState(mask, state.prevTip); - trackState.setReferenceSurface(surface.getSharedPtr()); + trackState.setReferenceSurface(state.surface->getSharedPtr()); trackState.typeFlags() = trackStateType; - if (!firstTrackState.has_value()) { - const auto& [boundParams, jacobian, pathLength] = boundState; + if (!state.firstTrackState.has_value()) { + const auto& [boundParams, jacobian, pathLength] = *state.boundState; trackState.predicted() = boundParams.parameters(); trackState.predictedCovariance() = *boundParams.covariance(); trackState.jacobian() = jacobian; trackState.pathLength() = pathLength; } else { - trackState.shareFrom(*firstTrackState, TrackStatePropMask::Predicted); - trackState.shareFrom(*firstTrackState, TrackStatePropMask::Jacobian); + trackState.shareFrom(*state.firstTrackState, + TrackStatePropMask::Predicted); + trackState.shareFrom(*state.firstTrackState, + TrackStatePropMask::Jacobian); } - if (!firstTrackState.has_value()) { - firstTrackState = trackState; + if (!state.firstTrackState.has_value()) { + state.firstTrackState = trackState; } return trackState; } - template - track_state_proxy_t createHoleState( - const Surface& surface, const BoundState& boundState, - const TrackIndexType prevTip, track_state_container_backend_t& trajectory, - std::optional& firstTrackState, - const Logger& logger) const { - TrackStateType trackStateType = - derived().determineTrackStateType(surface, logger); + template + auto createHoleState(state_t& state) const { + TrackStateType trackStateType = derived().determineTrackStateType(state); trackStateType.setIsHole(); - track_state_proxy_t trackState = - derived().createTrackState(surface, boundState, prevTip, trajectory, - firstTrackState, trackStateType, logger); + auto trackState = derived().createTrackState(state, trackStateType); return trackState; } - template - track_state_proxy_t createMeasurementState( - const GeometryContext& gctx, const CalibrationContext& cctx, - const Surface& surface, const BoundState& boundState, - const TrackIndexType prevTip, track_state_container_backend_t& trajectory, - std::optional& firstTrackState, - const measurement_range_t& measurements, - const selected_measurement_t& selectedMeasurement, - const Logger& logger) const { + auto createMeasurementState( + state_t& state, const measurement_range_t& measurements, + const selected_measurement_t& selectedMeasurement) const { const TrackStateType trackStateType = derived().determineTrackStateType( - surface, measurements, selectedMeasurement, logger); + state, measurements, selectedMeasurement); - track_state_proxy_t trackState = - derived().createTrackState(surface, boundState, prevTip, trajectory, - firstTrackState, trackStateType, logger); + auto trackState = derived().createTrackState(state, trackStateType); trackState.addComponents(TrackStatePropMask::Calibrated); - derived().postCalibrateTrackState(gctx, cctx, surface, boundState, - measurements, selectedMeasurement, + derived().postCalibrateTrackState(state, measurements, selectedMeasurement, trackState); return trackState; } - template void postCalibrateTrackState( - const GeometryContext& gctx, const CalibrationContext& cctx, - const Surface& surface, const BoundState& boundState, - const measurement_range_t& measurements, + state_t& state, const measurement_range_t& measurements, const selected_measurement_t& selectedMeasurement, - track_state_proxy_t& trackState) const { - static_cast(gctx); - static_cast(cctx); - static_cast(boundState); - + typename state_t::TrackStateProxy& trackState) const { trackState.setUncalibratedSourceLink( - derived().selectedMeasurementSourceLink(surface, measurements, + derived().selectedMeasurementSourceLink(state, measurements, selectedMeasurement)); const auto& subspaceHelper = derived().selectedMeasurementSubspace( - surface, measurements, selectedMeasurement); + state, measurements, selectedMeasurement); visit_measurement(subspaceHelper.size(), [&](auto N) -> void { constexpr std::size_t kDim = decltype(N)::value; const FixedBoundSubspaceHelper fixedSubspaceHelper(subspaceHelper); const Vector measuredParameters = - derived().selectedMeasurementParameters(surface, measurements, + derived().selectedMeasurementParameters(state, measurements, selectedMeasurement); const SquareMatrix measurementCovariance = - derived().selectedMeasurementCovariance(surface, measurements, + derived().selectedMeasurementCovariance(state, measurements, selectedMeasurement); trackState.allocateCalibrated(measuredParameters, measurementCovariance); }); diff --git a/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp b/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp index 552982f98ff..3277efa7c7f 100644 --- a/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp +++ b/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp @@ -72,12 +72,15 @@ namespace { struct TrackStateCreator final : public Acts::Experimental::TrackStateCreatorBase { using Base = Acts::Experimental::TrackStateCreatorBase; - using Traj = Acts::VectorMultiTrajectory; + using TrackStateContainer = Acts::VectorMultiTrajectory; + using TrackStateProxy = typename TrackStateContainer::TrackStateProxy; + using State = + typename Base::template State; const MeasurementContainer* measurements = nullptr; std::optional seed; - bool isSeedCandidate(const Traj::TrackStateProxy& candidate) const { + bool isSeedCandidate(const TrackStateProxy& candidate) const { assert(candidate.hasUncalibratedSourceLink()); assert(seed.has_value()); @@ -95,11 +98,10 @@ struct TrackStateCreator final } // TODO - // void select(std::vector& candidates, bool& - // isOutlier, + // void select(std::vector& candidates, bool& isOutlier, // const Acts::Logger& logger) const { // if (seed.has_value()) { - // std::vector newCandidates; + // std::vector newCandidates; // for (const auto& candidate : candidates) { // if (isSeedCandidate(candidate)) { @@ -113,40 +115,37 @@ struct TrackStateCreator final // } // } - auto measurementRange(const Acts::Surface& surface, - const Acts::BoundTrackParameters& boundParams) const { - static_cast(boundParams); - + auto measurementRange(State& state) const { const auto rangePair = - measurements->orderedIndices().equal_range(surface.geometryId()); + measurements->orderedIndices().equal_range(state.surface->geometryId()); return std::ranges::subrange(rangePair.first, rangePair.second); } Acts::SourceLink measurementSourceLink( - const Acts::Surface& surface, const IndexSourceLink& measurement) const { - static_cast(surface); + State& state, const IndexSourceLink& measurement) const { + static_cast(state); return Acts::SourceLink{measurement}; } Acts::VariableBoundSubspaceHelper measurementSubspace( - const Acts::Surface& surface, const IndexSourceLink& measurement) const { - static_cast(surface); + State& state, const IndexSourceLink& measurement) const { + static_cast(state); return measurements->at(measurement.index()).subspaceHelper(); } - auto measuredParameters(const Acts::Surface& surface, + auto measuredParameters(State& state, const IndexSourceLink& measurement) const { - static_cast(surface); + static_cast(state); return measurements->at(measurement.index()).parameters(); } - auto measurementCovariance(const Acts::Surface& surface, + auto measurementCovariance(State& state, const IndexSourceLink& measurement) const { - static_cast(surface); + static_cast(state); return measurements->at(measurement.index()).covariance(); } From 8863dac9067178f91a7b8393f7e7b377cce1439e Mon Sep 17 00:00:00 2001 From: Andreas Stefl Date: Fri, 20 Mar 2026 16:43:41 +0100 Subject: [PATCH 6/6] minor --- Core/include/Acts/TrackFinding/TrackStateCreator.hpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/Core/include/Acts/TrackFinding/TrackStateCreator.hpp b/Core/include/Acts/TrackFinding/TrackStateCreator.hpp index 60d698a1b45..189469f5c8d 100644 --- a/Core/include/Acts/TrackFinding/TrackStateCreator.hpp +++ b/Core/include/Acts/TrackFinding/TrackStateCreator.hpp @@ -69,8 +69,6 @@ class TrackStateCreatorBase { << " measurements for surface " << surface.geometryId()); - derived().sortSelectedMeasurements(state, selectedMeasurements); - derived().trimSelectedMeasurements(state, selectedMeasurements); ACTS_VERBOSE("Trimmed to " << selectedMeasurements.size() @@ -236,14 +234,10 @@ class TrackStateCreatorBase { } template - void sortSelectedMeasurements( + void trimSelectedMeasurements( state_t& state, SelectedMeasurements& selectedMeasurements) const { selectedMeasurements.sort(state.maxNumSelectedMeasurements); - } - template - void trimSelectedMeasurements( - state_t& state, SelectedMeasurements& selectedMeasurements) const { selectedMeasurements.resize(state.maxNumSelectedMeasurements); }