diff --git a/Core/include/Acts/TrackFinding/TrackStateCreator.hpp b/Core/include/Acts/TrackFinding/TrackStateCreator.hpp index 31410b0c246..189469f5c8d 100644 --- a/Core/include/Acts/TrackFinding/TrackStateCreator.hpp +++ b/Core/include/Acts/TrackFinding/TrackStateCreator.hpp @@ -8,12 +8,487 @@ #pragma once -// for definitions of Calibrator, MeasurementSelector +#include "Acts/EventData/BoundTrackParameters.hpp" +#include "Acts/EventData/SubspaceHelpers.hpp" +#include "Acts/EventData/TrackStateType.hpp" +#include "Acts/EventData/Types.hpp" +#include "Acts/Surfaces/Surface.hpp" #include "Acts/TrackFinding/CombinatorialKalmanFilterExtensions.hpp" #include "Acts/TrackFitting/KalmanFitter.hpp" +#include "Acts/Utilities/Enumerate.hpp" +#include "Acts/Utilities/Logger.hpp" namespace Acts { +namespace Experimental { + +// TODO surface token / detector element token +// TODO state object +// TODO cuts object +template +class TrackStateCreatorBase { + public: + /// Type alias for result of track states creation operation + using TrackStatesResult = Result>; + /// Type alias for bound state tuple containing parameters, jacobian and path + /// length + using BoundState = std::tuple; + + template + TrackStatesResult createTrackStates( + const GeometryContext& gctx, const CalibrationContext& cctx, + const Surface& surface, const BoundState& boundState, + const TrackIndexType prevTip, + std::vector& trackStateCandidates, + track_state_container_backend_t& trajectory, const Logger& logger) const { + static_cast(trackStateCandidates); + + auto state = derived().makeState(gctx, cctx, surface, boundState, prevTip, + trackStateCandidates, trajectory, logger); + + TrackStatesResult result = TrackStatesResult::success({}); + + ACTS_VERBOSE("Perform measurement selection for surface " + << surface.geometryId()); + + const auto measurementRange = derived().measurementRange(state); + ACTS_VERBOSE("Found " << measurementRange.size() + << " measurements on surface " + << surface.geometryId()); + if (measurementRange.begin() == measurementRange.end()) { + ACTS_VERBOSE("No measurements on surface " + << surface.geometryId() << ". No track states created."); + return result; + } + + auto selectedMeasurements = + derived().selectMeasurements(state, measurementRange); + + ACTS_VERBOSE("Selected " << selectedMeasurements.size() + << " measurements for surface " + << surface.geometryId()); + + derived().trimSelectedMeasurements(state, selectedMeasurements); + + ACTS_VERBOSE("Trimmed to " << selectedMeasurements.size() + << " measurements for surface " + << surface.geometryId()); + + derived().createTrackStatesImpl(state, measurementRange, + selectedMeasurements); + + return result; + } + + protected: + template + struct State { + using TrackStateProxy = + typename track_state_container_backend_t::TrackStateProxy; + + State(const GeometryContext& gctx_, const CalibrationContext& cctx_, + const Surface& surface_, const BoundState& boundState_, + const TrackIndexType prevTip_, + std::vector& trackStateCandidates_, + track_state_container_backend_t& trajectory_, const Logger& logger_) + : gctx(&gctx_), + cctx(&cctx_), + surface(&surface_), + boundState(&boundState_), + prevTip(prevTip_), + trackStateCandidates(&trackStateCandidates_), + trajectory(&trajectory_), + logger(&logger_) {} + + const GeometryContext* gctx{}; + const CalibrationContext* cctx{}; + const Surface* surface{}; + const BoundState* boundState{}; + TrackIndexType prevTip{}; + std::vector* trackStateCandidates{}; + track_state_container_backend_t* trajectory{}; + const Logger* logger{}; + + std::uint32_t maxNumSelectedMeasurements{}; + float maxChi2Compatible{}; + float maxChi2Outlier{}; + + std::optional firstTrackState; + }; + + template + auto makeState(const GeometryContext& gctx, const CalibrationContext& cctx, + const Surface& surface, const BoundState& boundState, + const TrackIndexType prevTip, + std::vector& trackStateCandidates, + track_state_container_backend_t& trajectory, + const Logger& logger) const { + State state(gctx, cctx, surface, boundState, prevTip, trackStateCandidates, + trajectory, logger); + + state.maxNumSelectedMeasurements = + derived().getMaxNumSelectedMeasurements(state); + state.maxChi2Compatible = derived().getMaxChi2Compatible(state); + state.maxChi2Outlier = derived().getMaxChi2Outlier(state); + + ACTS_VERBOSE("Max number of selected measurements: " + << state.maxNumSelectedMeasurements + << ", max chi2 compatible: " << state.maxChi2Compatible + << ", max chi2 outlier: " << state.maxChi2Outlier); + + return state; + } + + template + std::uint32_t getMaxNumSelectedMeasurements(state_t& state) const { + static_cast(state); + + return std::numeric_limits::max(); + } + + template + float getMaxChi2Compatible(state_t& state) const { + static_cast(state); + + return std::numeric_limits::max(); + } + + template + float getMaxChi2Outlier(state_t& state) const { + static_cast(state); + + return std::numeric_limits::max(); + } + + enum class MeasurementClassification : std::uint8_t { + Incompatible, + Outlier, + Compatible + }; + + struct MeasurementCandidate { + std::uint32_t index{std::numeric_limits::max()}; + float chi2{std::numeric_limits::max()}; + MeasurementClassification classification{ + MeasurementClassification::Incompatible}; + }; + + struct SelectedMeasurements { + MeasurementCandidate bestCandidate; + CkfTypes::BranchVector nonIncompatible; + + std::size_t size() const { return nonIncompatible.size(); } + bool empty() const { return nonIncompatible.empty(); } + + auto begin() const { return nonIncompatible.begin(); } + auto end() const { return nonIncompatible.end(); } + + void push_back(const MeasurementCandidate& candidate) { + if (candidate.chi2 < bestCandidate.chi2) { + bestCandidate = candidate; + } + if (candidate.classification != MeasurementClassification::Incompatible) { + nonIncompatible.push_back(candidate); + } + } + + void emplace_back(std::uint32_t index, float chi2, + MeasurementClassification classification) { + push_back({index, chi2, classification}); + } + + void sort(std::uint32_t maxNumSelectedMeasurements) { + if (maxNumSelectedMeasurements < nonIncompatible.size()) { + nonIncompatible.resize(maxNumSelectedMeasurements); + } + + std::ranges::sort( + nonIncompatible, {}, + [](const MeasurementCandidate& candidate) { return candidate.chi2; }); + } + + void resize(std::uint32_t maxNumSelectedMeasurements) { + if (maxNumSelectedMeasurements < nonIncompatible.size()) { + nonIncompatible.resize(maxNumSelectedMeasurements); + } + } + }; + + template + SelectedMeasurements selectMeasurements( + state_t& state, const measurement_range_t& range) const { + SelectedMeasurements result; + + for (const auto& [i, measurement] : + enumerate(range)) { + const float chi2 = derived().computeChi2(state, measurement); + const MeasurementClassification classification = + derived().classifyMeasurement(state, measurement, chi2); + result.emplace_back(i, chi2, classification); + } + + return result; + } + + template + void trimSelectedMeasurements( + state_t& state, SelectedMeasurements& selectedMeasurements) const { + selectedMeasurements.sort(state.maxNumSelectedMeasurements); + + selectedMeasurements.resize(state.maxNumSelectedMeasurements); + } + + template + float computeChi2(state_t& state, const measurement_t& measurement) const { + const auto& subspaceHelper = + derived().measurementSubspace(state, measurement); + return derived().computeChi2Impl(state, measurement, subspaceHelper); + } + + template + double computeChi2Impl(state_t& state, const measurement_t& measurement, + FixedBoundSubspaceHelper subspaceHelper) const { + const auto& [boundParams, jacobian, pathLength] = *state.boundState; + + const Vector predictedParameters = + subspaceHelper.projectVector(boundParams.parameters()); + const SquareMatrix predictedCovariance = + subspaceHelper.projectMatrix(*boundParams.covariance()); + + const Vector measuredParameters = + derived().measuredParameters(state, measurement); + const SquareMatrix measurementCovariance = + derived().measurementCovariance(state, measurement); + + const Vector residualParameters = + measuredParameters - predictedParameters; + const SquareMatrix residualCovariance = + predictedCovariance + measurementCovariance; + + const double chi2 = (residualParameters.transpose() * + residualCovariance.inverse() * residualParameters) + .eval()(0, 0); + return chi2; + } + + template + double computeChi2Impl(state_t& state, const measurement_t& measurement, + VariableBoundSubspaceHelper subspaceHelper) const { + return visit_measurement(subspaceHelper.size(), [&](auto N) -> double { + constexpr std::size_t kDim = decltype(N)::value; + const FixedBoundSubspaceHelper fixedSubspaceHelper(subspaceHelper); + return derived().computeChi2Impl(state, measurement, fixedSubspaceHelper); + }); + } + + template + MeasurementClassification classifyMeasurement( + state_t& state, const measurement_t& measurement, + const float chi2) const { + static_cast(measurement); + + if (chi2 < state.maxChi2Compatible) { + return MeasurementClassification::Compatible; + } + if (chi2 < state.maxChi2Outlier) { + return MeasurementClassification::Outlier; + } + return MeasurementClassification::Incompatible; + } + + template + auto selectedMeasurementSourceLink( + state_t& state, const measurement_range_t& measurements, + const MeasurementCandidate& selectedMeasurement) const { + return derived().measurementSourceLink( + state, measurements[selectedMeasurement.index]); + } + + template + auto selectedMeasurementSubspace( + state_t& state, const measurement_range_t& measurements, + const MeasurementCandidate& selectedMeasurement) const { + return derived().measurementSubspace( + state, measurements[selectedMeasurement.index]); + } + + template + auto selectedMeasurementParameters( + state_t& state, const measurement_range_t& measurements, + const MeasurementCandidate& selectedMeasurement) const { + return derived().measuredParameters( + state, measurements[selectedMeasurement.index]); + } + + template + auto selectedMeasurementCovariance( + state_t& state, const measurement_range_t& measurements, + const MeasurementCandidate& selectedMeasurement) const { + return derived().measurementCovariance( + state, measurements[selectedMeasurement.index]); + } + + template + void createTrackStatesImpl( + state_t& state, const measurement_range_t& measurements, + const selected_measurement_range_t& selectedMeasurements) const { + const Logger& logger = *state.logger; + + if (derived().hasHole(state, selectedMeasurements)) { + ACTS_VERBOSE("No compatible measurements on surface " + << state.surface->geometryId() + << ". No track states created."); + derived().createHoleState(state); + } + + for (const auto& selectedMeasurement : selectedMeasurements) { + derived().createMeasurementState(state, measurements, + selectedMeasurement); + } + } + + template + bool hasHole(state_t& state, + const SelectedMeasurements& selectedMeasurements) const { + static_cast(state); + + return selectedMeasurements.empty() && + (selectedMeasurements.bestCandidate.classification == + MeasurementClassification::Incompatible); + } + + template + TrackStateType determineTrackStateType(state_t& state) const { + TrackStateType result; + + result.setHasParameters(); + + result.setHasMeasurement(); + + if (state.surface->surfaceMaterial() != nullptr) { + result.setHasMaterial(); + } + + return result; + } + + template + TrackStateType determineTrackStateType( + state_t& state, const measurement_range_t& measurements, + const MeasurementCandidate& selectedMeasurement) const { + static_cast(measurements); + + TrackStateType result = derived().determineTrackStateType(state); + + if (selectedMeasurement.classification == + MeasurementClassification::Outlier) { + result.setIsOutlier(); + } + + return result; + } + + template + auto createTrackState(state_t& state, + const TrackStateType trackStateType) const { + TrackStatePropMask mask = TrackStatePropMask::None; + if (!state.firstTrackState.has_value()) { + mask |= TrackStatePropMask::Predicted | TrackStatePropMask::Jacobian; + } + + auto trackState = state.trajectory->makeTrackState(mask, state.prevTip); + + trackState.setReferenceSurface(state.surface->getSharedPtr()); + + trackState.typeFlags() = trackStateType; + + if (!state.firstTrackState.has_value()) { + const auto& [boundParams, jacobian, pathLength] = *state.boundState; + trackState.predicted() = boundParams.parameters(); + trackState.predictedCovariance() = *boundParams.covariance(); + trackState.jacobian() = jacobian; + trackState.pathLength() = pathLength; + } else { + trackState.shareFrom(*state.firstTrackState, + TrackStatePropMask::Predicted); + trackState.shareFrom(*state.firstTrackState, + TrackStatePropMask::Jacobian); + } + + if (!state.firstTrackState.has_value()) { + state.firstTrackState = trackState; + } + return trackState; + } + + template + auto createHoleState(state_t& state) const { + TrackStateType trackStateType = derived().determineTrackStateType(state); + trackStateType.setIsHole(); + + auto trackState = derived().createTrackState(state, trackStateType); + + return trackState; + } + + template + auto createMeasurementState( + state_t& state, const measurement_range_t& measurements, + const selected_measurement_t& selectedMeasurement) const { + const TrackStateType trackStateType = derived().determineTrackStateType( + state, measurements, selectedMeasurement); + + auto trackState = derived().createTrackState(state, trackStateType); + + trackState.addComponents(TrackStatePropMask::Calibrated); + + derived().postCalibrateTrackState(state, measurements, selectedMeasurement, + trackState); + + return trackState; + } + + template + void postCalibrateTrackState( + state_t& state, const measurement_range_t& measurements, + const selected_measurement_t& selectedMeasurement, + typename state_t::TrackStateProxy& trackState) const { + trackState.setUncalibratedSourceLink( + derived().selectedMeasurementSourceLink(state, measurements, + selectedMeasurement)); + + const auto& subspaceHelper = derived().selectedMeasurementSubspace( + state, measurements, selectedMeasurement); + + visit_measurement(subspaceHelper.size(), [&](auto N) -> void { + constexpr std::size_t kDim = decltype(N)::value; + const FixedBoundSubspaceHelper fixedSubspaceHelper(subspaceHelper); + const Vector measuredParameters = + derived().selectedMeasurementParameters(state, measurements, + selectedMeasurement); + const SquareMatrix measurementCovariance = + derived().selectedMeasurementCovariance(state, measurements, + selectedMeasurement); + trackState.allocateCalibrated(measuredParameters, measurementCovariance); + }); + + trackState.setProjectorSubspaceIndices(subspaceHelper.indices()); + } + + private: + const derived_t& derived() const { + return static_cast(*this); + } +}; + +} // namespace Experimental + /// @brief Create track states for selected measurements associated to a surface. /// /// - First get a source link range covering relevant measurements associated to @@ -35,8 +510,7 @@ namespace Acts { template struct TrackStateCreator { /// Type alias for result of track states creation operation - using TrackStatesResult = - Acts::Result>; + using TrackStatesResult = Result>; /// Type alias for track state container backend from track container using TrackStateContainerBackend = typename track_container_t::TrackStateContainerBackend; @@ -84,7 +558,6 @@ struct TrackStateCreator { MeasurementSelector measurementSelector{ DelegateFuncTag{}}; - public: /// @brief extend the trajectory onto the given surface. /// /// @param gctx The geometry context to be used for this task @@ -146,8 +619,7 @@ struct TrackStateCreator { TrackStateContainerBackend& trajectory, const Logger& logger) const { using PM = TrackStatePropMask; - using ResultTrackStateList = - Acts::Result>; + using ResultTrackStateList = Result>; ResultTrackStateList resultTrackStateList{ CkfTypes::BranchVector()}; const auto& [boundParams, jacobian, pathLength] = boundState; @@ -237,8 +709,7 @@ struct TrackStateCreator { const Logger& logger) const { using PM = TrackStatePropMask; - using ResultTrackStateList = - Acts::Result>; + using ResultTrackStateList = Result>; ResultTrackStateList resultTrackStateList{ CkfTypes::BranchVector()}; CkfTypes::BranchVector& trackStateList = diff --git a/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp b/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp index 4b80a458b70..3277efa7c7f 100644 --- a/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp +++ b/Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp @@ -11,9 +11,11 @@ #include "Acts/Definitions/Algebra.hpp" #include "Acts/Definitions/Direction.hpp" #include "Acts/Definitions/TrackParametrization.hpp" +#include "Acts/EventData/BoundTrackParameters.hpp" #include "Acts/EventData/MultiTrajectory.hpp" #include "Acts/EventData/ProxyAccessor.hpp" #include "Acts/EventData/SourceLink.hpp" +#include "Acts/EventData/SubspaceHelpers.hpp" #include "Acts/EventData/TrackContainer.hpp" #include "Acts/EventData/VectorMultiTrajectory.hpp" #include "Acts/EventData/VectorTrackContainer.hpp" @@ -67,48 +69,24 @@ namespace ActsExamples { namespace { -class MeasurementSelector { - public: - using Traj = Acts::VectorMultiTrajectory; - - explicit MeasurementSelector(Acts::MeasurementSelector selector) - : m_selector(std::move(selector)) {} - - void setSeed(const std::optional& seed) { m_seed = seed; } - - Acts::Result::iterator, - std::vector::iterator>> - select(std::vector& candidates, bool& isOutlier, - const Acts::Logger& logger) const { - if (m_seed.has_value()) { - std::vector newCandidates; - - for (const auto& candidate : candidates) { - if (isSeedCandidate(candidate)) { - newCandidates.push_back(candidate); - } - } +struct TrackStateCreator final + : public Acts::Experimental::TrackStateCreatorBase { + using Base = Acts::Experimental::TrackStateCreatorBase; + using TrackStateContainer = Acts::VectorMultiTrajectory; + using TrackStateProxy = typename TrackStateContainer::TrackStateProxy; + using State = + typename Base::template State; - if (!newCandidates.empty()) { - candidates = std::move(newCandidates); - } - } - - return m_selector.select(candidates, isOutlier, - logger); - } + const MeasurementContainer* measurements = nullptr; + std::optional seed; - private: - Acts::MeasurementSelector m_selector; - std::optional m_seed; - - bool isSeedCandidate(const Traj::TrackStateProxy& candidate) const { + bool isSeedCandidate(const TrackStateProxy& candidate) const { assert(candidate.hasUncalibratedSourceLink()); - assert(m_seed.has_value()); + assert(seed.has_value()); const Acts::SourceLink& sourceLink = candidate.getUncalibratedSourceLink(); - for (const ConstSpacePointProxy sp : m_seed->spacePoints()) { + for (const ConstSpacePointProxy sp : seed->spacePoints()) { for (const Acts::SourceLink& sl : sp.sourceLinks()) { if (sourceLink.get() == sl.get()) { return true; @@ -118,6 +96,59 @@ class MeasurementSelector { return false; } + + // TODO + // void select(std::vector& candidates, bool& isOutlier, + // const Acts::Logger& logger) const { + // if (seed.has_value()) { + // std::vector newCandidates; + + // for (const auto& candidate : candidates) { + // if (isSeedCandidate(candidate)) { + // newCandidates.push_back(candidate); + // } + // } + + // if (!newCandidates.empty()) { + // candidates = std::move(newCandidates); + // } + // } + // } + + auto measurementRange(State& state) const { + const auto rangePair = + measurements->orderedIndices().equal_range(state.surface->geometryId()); + + return std::ranges::subrange(rangePair.first, rangePair.second); + } + + Acts::SourceLink measurementSourceLink( + State& state, const IndexSourceLink& measurement) const { + static_cast(state); + + return Acts::SourceLink{measurement}; + } + + Acts::VariableBoundSubspaceHelper measurementSubspace( + State& state, const IndexSourceLink& measurement) const { + static_cast(state); + + return measurements->at(measurement.index()).subspaceHelper(); + } + + auto measuredParameters(State& state, + const IndexSourceLink& measurement) const { + static_cast(state); + + return measurements->at(measurement.index()).parameters(); + } + + auto measurementCovariance(State& state, + const IndexSourceLink& measurement) const { + static_cast(state); + + return measurements->at(measurement.index()).covariance(); + } }; /// Source link indices of the bottom, middle, top measurements. @@ -321,29 +352,18 @@ ProcessCode TrackFindingAlgorithm::execute(const AlgorithmContext& ctx) const { using Extensions = Acts::CombinatorialKalmanFilterExtensions; BranchStopper branchStopper(m_cfg); - MeasurementSelector measSel{ - Acts::MeasurementSelector(m_cfg.measurementSelectorCfg)}; - - IndexSourceLinkAccessor slAccessor; - slAccessor.container = &measurements.orderedIndices(); - - using TrackStateCreatorType = - Acts::TrackStateCreator; - TrackStateCreatorType trackStateCreator; - trackStateCreator.sourceLinkAccessor - .template connect<&IndexSourceLinkAccessor::range>(&slAccessor); - trackStateCreator.calibrator - .template connect<&MeasurementCalibratorAdapter::calibrate>(&calibrator); - trackStateCreator.measurementSelector - .template connect<&MeasurementSelector::select>(&measSel); + + TrackStateCreator trackStateCreator; + trackStateCreator.measurements = &measurements; Extensions extensions; extensions.updater.connect<&Acts::GainMatrixUpdater::operator()< typename TrackContainer::TrackStateContainerBackend>>(&kfUpdater); extensions.branchStopper.connect<&BranchStopper::operator()>(&branchStopper); extensions.createTrackStates - .template connect<&TrackStateCreatorType ::createTrackStates>( + .template connect<&TrackStateCreator::createTrackStates< + TrackStateProxy, + typename TrackContainer::TrackStateContainerBackend>>( &trackStateCreator); Acts::PropagatorPlainOptions firstPropOptions(ctx.geoContext, @@ -469,7 +489,7 @@ ProcessCode TrackFindingAlgorithm::execute(const AlgorithmContext& ctx) const { } if (m_cfg.stayOnSeed) { - measSel.setSeed(seed); + trackStateCreator.seed = seed; } }