From 983ca75aede9026334e70af933c269af68f15869 Mon Sep 17 00:00:00 2001 From: Zhuo Wang Date: Wed, 10 Jun 2026 11:29:09 +0800 Subject: [PATCH] feat: Collect Parquet NaN metrics during writes --- src/iceberg/parquet/parquet_writer.cc | 157 ++++++++++++++++++++++- src/iceberg/test/metrics_test_base.cc | 20 +-- src/iceberg/test/metrics_test_base.h | 3 + src/iceberg/test/parquet_metrics_test.cc | 1 + 4 files changed, 171 insertions(+), 10 deletions(-) diff --git a/src/iceberg/parquet/parquet_writer.cc b/src/iceberg/parquet/parquet_writer.cc index da794cc3e..c50fb26b1 100644 --- a/src/iceberg/parquet/parquet_writer.cc +++ b/src/iceberg/parquet/parquet_writer.cc @@ -19,10 +19,16 @@ #include "iceberg/parquet/parquet_writer.h" +#include +#include #include +#include #include +#include +#include #include +#include #include #include #include @@ -36,7 +42,9 @@ #include "iceberg/arrow/arrow_status_internal.h" #include "iceberg/parquet/parquet_metrics_internal.h" #include "iceberg/schema_internal.h" +#include "iceberg/type.h" #include "iceberg/util/macros.h" +#include "iceberg/util/visit_type.h" namespace iceberg::parquet { @@ -74,6 +82,144 @@ Status CheckCompressionAvailable(std::string_view compression_name, return {}; } +template +Status UpdateFloatingFieldMetrics(int32_t field_id, const ::arrow::Array& arrow_array, + const std::vector* valid_rows, + std::unordered_map& metrics) { + constexpr auto expected_type_id = + std::is_same_v ? ::arrow::Type::FLOAT : ::arrow::Type::DOUBLE; + ICEBERG_PRECHECK(arrow_array.type_id() == expected_type_id, + "Expected Arrow floating-point array for field metrics collection"); + const auto& array = static_cast(arrow_array); + auto& field_metrics = metrics[field_id]; + field_metrics.field_id = field_id; + if (field_metrics.value_count < 0) { + field_metrics.value_count = 0; + } + if (field_metrics.null_value_count < 0) { + field_metrics.null_value_count = 0; + } + if (field_metrics.nan_value_count < 0) { + field_metrics.nan_value_count = 0; + } + + field_metrics.value_count += array.length(); + + for (int64_t i = 0; i < array.length(); ++i) { + if ((valid_rows != nullptr && (*valid_rows)[i] == 0) || array.IsNull(i)) { + ++field_metrics.null_value_count; + continue; + } + + ValueType value = array.Value(i); + if (std::isnan(value)) { + ++field_metrics.nan_value_count; + continue; + } + + auto literal = [&]() { + if constexpr (std::is_same_v) { + return Literal::Float(value); + } else { + return Literal::Double(value); + } + }(); + if (!field_metrics.lower_bound.has_value() || + literal < field_metrics.lower_bound.value()) { + field_metrics.lower_bound = literal; + } + if (!field_metrics.upper_bound.has_value() || + literal > field_metrics.upper_bound.value()) { + field_metrics.upper_bound = std::move(literal); + } + } + + return {}; +} + +std::optional> BuildValidRows(const ::arrow::Array& array, + const std::vector* parent) { + if (parent == nullptr && array.null_count() == 0) { + return std::nullopt; + } + + std::vector valid_rows(array.length(), 1); + for (int64_t i = 0; i < array.length(); ++i) { + if ((parent != nullptr && (*parent)[i] == 0) || array.IsNull(i)) { + valid_rows[i] = 0; + } + } + return valid_rows; +} + +class FieldMetricsCollector { + public: + FieldMetricsCollector(std::unordered_map& metrics, + const MetricsConfig& metrics_config, const Schema& schema) + : metrics_(metrics), metrics_config_(metrics_config), schema_(schema) {} + + Status VisitStruct(const StructType& type, const ::arrow::Array& array) { + ICEBERG_PRECHECK(array.type_id() == ::arrow::Type::STRUCT, + "Expected Arrow struct array for Iceberg struct metrics collection"); + const auto& struct_array = static_cast(array); + ICEBERG_PRECHECK( + struct_array.num_fields() == type.fields().size(), + "Arrow struct field count does not match Iceberg struct field count"); + + for (int i = 0; i < struct_array.num_fields(); ++i) { + ICEBERG_RETURN_UNEXPECTED(VisitField(type.fields()[i], *struct_array.field(i))); + } + return {}; + } + + Status VisitList(const ListType& /*type*/, const ::arrow::Array& /*array*/) { + return {}; + } + + Status VisitMap(const MapType& /*type*/, const ::arrow::Array& /*array*/) { return {}; } + + Status VisitPrimitive(const PrimitiveType& type, const ::arrow::Array& array) { + switch (type.type_id()) { + case TypeId::kFloat: + return UpdateFloatingFieldMetrics<::arrow::FloatArray, float>( + field_id_, array, valid_rows_, metrics_); + case TypeId::kDouble: + return UpdateFloatingFieldMetrics<::arrow::DoubleArray, double>( + field_id_, array, valid_rows_, metrics_); + default: + return {}; + } + } + + private: + Status VisitField(const SchemaField& field, const ::arrow::Array& array) { + // Skip metrics collection for fields whose mode is kNone in MetricsConfig. + ICEBERG_ASSIGN_OR_RAISE(auto column_name, + schema_.FindColumnNameById(field.field_id())); + if (column_name.has_value() && metrics_config_.ColumnMode(column_name.value()).kind == + MetricsMode::Kind::kNone) { + return {}; + } + + auto previous_valid_rows = valid_rows_; + auto field_valid_rows = BuildValidRows(array, previous_valid_rows); + if (field_valid_rows.has_value()) { + valid_rows_ = &field_valid_rows.value(); + } + + field_id_ = field.field_id(); + auto status = VisitTypeCategory(*field.type(), this, array); + valid_rows_ = previous_valid_rows; + return status; + } + + std::unordered_map& metrics_; + const MetricsConfig& metrics_config_; + const Schema& schema_; + const std::vector* valid_rows_ = nullptr; + int32_t field_id_ = -1; +}; + Result> ParseCodecLevel(const WriterProperties& properties) { auto level_str = properties.Get(WriterProperties::kParquetCompressionLevel); if (level_str.empty()) { @@ -136,6 +282,12 @@ class ParquetWriter::Impl { ICEBERG_ARROW_ASSIGN_OR_RETURN(auto batch, ::arrow::ImportRecordBatch(array, arrow_schema_)); + ICEBERG_ARROW_ASSIGN_OR_RETURN(auto struct_array, batch->ToStructArray()); + FieldMetricsCollector field_metrics_collector(field_metrics_, *metrics_config_, + *schema_); + ICEBERG_RETURN_UNEXPECTED( + field_metrics_collector.VisitStruct(*schema_, *struct_array)); + ICEBERG_ARROW_RETURN_NOT_OK(writer_->WriteRecordBatch(*batch)); return {}; @@ -179,9 +331,8 @@ class ParquetWriter::Impl { ICEBERG_PRECHECK(writer_ == nullptr, "Cannot return metrics for unclosed writer"); ICEBERG_PRECHECK(metadata_ != nullptr, "Cannot return metrics because Parquet metadata is not available"); - // TODO(WZhuo): collect write-side FieldMetrics to support NaN value counts. return ParquetMetrics::GetMetrics(*schema_, *parquet_schema_, *metrics_config_, - *metadata_); + *metadata_, field_metrics_); } private: @@ -205,6 +356,8 @@ class ParquetWriter::Impl { int64_t total_bytes_{0}; // Row group start offsets in the Parquet file. std::vector split_offsets_; + // Write-side metrics for fields whose Parquet footer metrics are incomplete. + std::unordered_map field_metrics_; }; ParquetWriter::~ParquetWriter() = default; diff --git a/src/iceberg/test/metrics_test_base.cc b/src/iceberg/test/metrics_test_base.cc index 7913a7207..92e89e8e4 100644 --- a/src/iceberg/test/metrics_test_base.cc +++ b/src/iceberg/test/metrics_test_base.cc @@ -541,8 +541,9 @@ void MetricsTestBase::MetricsForNaNColumns() { ASSERT_TRUE(metrics.row_count.has_value()) << "row_count should be set"; EXPECT_EQ(*metrics.row_count, 2); - AssertCounts(1, 2, 0, metrics); - AssertCounts(2, 2, 0, metrics); + auto expected_nan_count = ReportsNanCounts() ? std::optional(2) : std::nullopt; + AssertCounts(1, 2, 0, expected_nan_count, metrics); + AssertCounts(2, 2, 0, expected_nan_count, metrics); // When all values are NaN, bounds should not be set AssertBounds(1, float32(), std::nullopt, std::nullopt, metrics); @@ -579,8 +580,9 @@ void MetricsTestBase::ColumnBoundsWithNaNValueAtFront() { ASSERT_TRUE(metrics.row_count.has_value()) << "row_count should be set"; EXPECT_EQ(*metrics.row_count, 3); - AssertCounts(1, 3, 0, metrics); - AssertCounts(2, 3, 0, metrics); + auto expected_nan_count = ReportsNanCounts() ? std::optional(1) : std::nullopt; + AssertCounts(1, 3, 0, expected_nan_count, metrics); + AssertCounts(2, 3, 0, expected_nan_count, metrics); // Bounds should be computed from non-NaN values if (metrics.lower_bounds.contains(1)) { @@ -619,8 +621,9 @@ void MetricsTestBase::ColumnBoundsWithNaNValueInMiddle() { ASSERT_TRUE(metrics.row_count.has_value()) << "row_count should be set"; EXPECT_EQ(*metrics.row_count, 3); - AssertCounts(1, 3, 0, metrics); - AssertCounts(2, 3, 0, metrics); + auto expected_nan_count = ReportsNanCounts() ? std::optional(1) : std::nullopt; + AssertCounts(1, 3, 0, expected_nan_count, metrics); + AssertCounts(2, 3, 0, expected_nan_count, metrics); if (metrics.lower_bounds.contains(1)) { AssertBounds(1, float32(), 1.2F, 5.6F, metrics); @@ -658,8 +661,9 @@ void MetricsTestBase::ColumnBoundsWithNaNValueAtEnd() { ASSERT_TRUE(metrics.row_count.has_value()) << "row_count should be set"; EXPECT_EQ(*metrics.row_count, 3); - AssertCounts(1, 3, 0, metrics); - AssertCounts(2, 3, 0, metrics); + auto expected_nan_count = ReportsNanCounts() ? std::optional(1) : std::nullopt; + AssertCounts(1, 3, 0, expected_nan_count, metrics); + AssertCounts(2, 3, 0, expected_nan_count, metrics); if (metrics.lower_bounds.contains(1)) { AssertBounds(1, float32(), 1.2F, 5.6F, metrics); diff --git a/src/iceberg/test/metrics_test_base.h b/src/iceberg/test/metrics_test_base.h index 07b3b62f3..530ef3b5e 100644 --- a/src/iceberg/test/metrics_test_base.h +++ b/src/iceberg/test/metrics_test_base.h @@ -60,6 +60,9 @@ class MetricsTestBase { /// \brief Whether the format supports small row groups for testing virtual bool SupportsSmallRowGroups() const { return false; } + /// \brief Whether the format reports NaN counts for floating-point fields + virtual bool ReportsNanCounts() const { return false; } + // Helper methods for assertions void AssertCounts(int field_id, std::optional expected_value_count, std::optional expected_null_count, const Metrics& metrics); diff --git a/src/iceberg/test/parquet_metrics_test.cc b/src/iceberg/test/parquet_metrics_test.cc index 8286efe08..6afe6f9c8 100644 --- a/src/iceberg/test/parquet_metrics_test.cc +++ b/src/iceberg/test/parquet_metrics_test.cc @@ -106,6 +106,7 @@ class ParquetMetricsTest : public MetricsTestBase, public ::testing::Test { } bool SupportsSmallRowGroups() const override { return true; } + bool ReportsNanCounts() const override { return true; } private: std::string temp_parquet_file_;