Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 155 additions & 2 deletions src/iceberg/parquet/parquet_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@

#include "iceberg/parquet/parquet_writer.h"

#include <cmath>
#include <cstdint>
#include <memory>
#include <optional>
#include <string_view>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include <arrow/array.h>
#include <arrow/c/bridge.h>
#include <arrow/record_batch.h>
#include <arrow/util/compression.h>
Expand All @@ -36,7 +42,9 @@
#include "iceberg/arrow/arrow_status_internal.h"
#include "iceberg/parquet/parquet_metrics_internal.h"
#include "iceberg/schema_internal.h"
#include "iceberg/type.h"
#include "iceberg/util/macros.h"
#include "iceberg/util/visit_type.h"

namespace iceberg::parquet {

Expand Down Expand Up @@ -74,6 +82,144 @@ Status CheckCompressionAvailable(std::string_view compression_name,
return {};
}

/// \brief Accumulate write-side metrics for one Arrow float/double array.
///
/// Adds the array length to the field's value count, counts rows that are null
/// (either in the array itself or according to `valid_rows`), counts NaN values,
/// and widens the lower/upper bounds using the remaining values.
///
/// \tparam ArrowArrayType Concrete Arrow array class `arrow_array` is cast to.
/// \tparam ValueType `float` selects Arrow FLOAT / Literal::Float; any other type
///         selects Arrow DOUBLE / Literal::Double.
/// \param field_id Key of the entry in `metrics` to update (created if absent).
/// \param arrow_array Array to scan; its type id must match `ValueType`.
/// \param valid_rows Optional per-row mask where 0 marks a row to count as null;
///        nullptr means only the array's own nulls are considered.
/// \param metrics Map of per-field metrics updated in place.
/// \return An error status if the precondition on the Arrow type id fails,
///         otherwise a default-constructed (OK) status.
template <typename ArrowArrayType, typename ValueType>
Status UpdateFloatingFieldMetrics(int32_t field_id, const ::arrow::Array& arrow_array,
                                  const std::vector<uint8_t>* valid_rows,
                                  std::unordered_map<int32_t, FieldMetrics>& metrics) {
  constexpr bool kIsFloat = std::is_same_v<ValueType, float>;
  constexpr auto expected_type_id =
      kIsFloat ? ::arrow::Type::FLOAT : ::arrow::Type::DOUBLE;
  ICEBERG_PRECHECK(arrow_array.type_id() == expected_type_id,
                   "Expected Arrow floating-point array for field metrics collection");
  const auto& typed_array = static_cast<const ArrowArrayType&>(arrow_array);

  // operator[] creates the entry the first time this field id is seen.
  FieldMetrics& entry = metrics[field_id];
  entry.field_id = field_id;

  // Any negative counter is restarted at zero before accumulating
  // (presumably negative is the "unset" sentinel of FieldMetrics; confirm).
  const auto reset_if_negative = [](auto& counter) {
    if (counter < 0) {
      counter = 0;
    }
  };
  reset_if_negative(entry.value_count);
  reset_if_negative(entry.null_value_count);
  reset_if_negative(entry.nan_value_count);

  // Wrap a raw value in the literal kind matching ValueType.
  const auto to_literal = [](ValueType raw) {
    if constexpr (kIsFloat) {
      return Literal::Float(raw);
    } else {
      return Literal::Double(raw);
    }
  };

  const int64_t num_rows = typed_array.length();
  entry.value_count += num_rows;

  for (int64_t row = 0; row < num_rows; ++row) {
    // A row masked out by the caller is counted as null even if the array's
    // own validity says otherwise.
    const bool masked_out = valid_rows != nullptr && (*valid_rows)[row] == 0;
    if (masked_out || typed_array.IsNull(row)) {
      ++entry.null_value_count;
      continue;
    }

    const ValueType value = typed_array.Value(row);
    if (std::isnan(value)) {
      // NaN is counted separately and never participates in the bounds.
      ++entry.nan_value_count;
      continue;
    }

    auto literal = to_literal(value);

    auto& lower = entry.lower_bound;
    if (!lower.has_value() || literal < lower.value()) {
      lower = literal;
    }

    auto& upper = entry.upper_bound;
    if (!upper.has_value() || literal > upper.value()) {
      upper = std::move(literal);
    }
  }

  return {};
}

/// \brief Build the effective per-row validity mask for `array` under `parent`.
///
/// A row is effectively valid only when it is valid in `parent` (if given) and
/// not null in `array` itself.
///
/// \param array Arrow array whose own nulls are folded into the mask.
/// \param parent Effective validity of the enclosing level (1 = valid, 0 = null),
///        indexed by row; nullptr when there is no enclosing mask.
/// \return std::nullopt when `array` has no nulls of its own, meaning the caller
///         should keep using `parent` (possibly nullptr) unchanged; otherwise a
///         mask with `array.length()` entries combining `parent` and the array's
///         own nulls.
///
/// NOTE(review): a reviewer suggested (non-blocking) making this cheaper still by
/// carrying Arrow validity bitmaps (buffer plus offset) instead of materializing a
/// std::vector<uint8_t> and calling array.IsNull(i) per row at each struct level.
/// Any such change must preserve the parent-validity behavior: child arrays of a
/// null StructArray do not necessarily have the parent's nulls reflected in their
/// own null bitmap, so the effective validity is
/// parent_validity AND array.null_bitmap(). When the child has null_count() == 0
/// the parent bitmap can simply be reused.
std::optional<std::vector<uint8_t>> BuildValidRows(const ::arrow::Array& array,
                                                   const std::vector<uint8_t>* parent) {
  // An array without nulls cannot invalidate any additional row, so the parent
  // mask (if any) already is the effective mask. Skip the allocation and the
  // row-by-row copy the mask would otherwise require.
  if (array.null_count() == 0) {
    return std::nullopt;
  }

  const int64_t num_rows = array.length();
  std::vector<uint8_t> valid_rows(num_rows, 1);
  for (int64_t i = 0; i < num_rows; ++i) {
    if ((parent != nullptr && (*parent)[i] == 0) || array.IsNull(i)) {
      valid_rows[i] = 0;
    }
  }
  return valid_rows;
}

/// \brief Visitor that walks an Arrow array alongside the matching Iceberg type
/// and accumulates write-side metrics for float and double fields.
///
/// Struct fields are descended into; list and map fields are accepted but
/// contribute nothing. While descending, the collector tracks the id of the
/// field being visited and the row-validity mask in effect for it.
class FieldMetricsCollector {
 public:
  /// \param metrics Per-field metrics map updated in place (held by reference).
  /// \param metrics_config Consulted to skip fields whose mode is kNone.
  /// \param schema Used to resolve a field id to its column name.
  FieldMetricsCollector(std::unordered_map<int32_t, FieldMetrics>& metrics,
                        const MetricsConfig& metrics_config, const Schema& schema)
      : metrics_(metrics), metrics_config_(metrics_config), schema_(schema) {}

  /// \brief Visit every child of a struct, pairing Iceberg fields with Arrow
  /// children by position.
  Status VisitStruct(const StructType& type, const ::arrow::Array& array) {
    ICEBERG_PRECHECK(array.type_id() == ::arrow::Type::STRUCT,
                     "Expected Arrow struct array for Iceberg struct metrics collection");
    const auto& struct_array = static_cast<const ::arrow::StructArray&>(array);
    ICEBERG_PRECHECK(
        struct_array.num_fields() == type.fields().size(),
        "Arrow struct field count does not match Iceberg struct field count");

    const int child_count = struct_array.num_fields();
    for (int i = 0; i < child_count; ++i) {
      ICEBERG_RETURN_UNEXPECTED(VisitField(type.fields()[i], *struct_array.field(i)));
    }
    return {};
  }

  /// \brief Lists are not inspected; always succeeds.
  Status VisitList(const ListType& /*type*/, const ::arrow::Array& /*array*/) {
    return {};
  }

  /// \brief Maps are not inspected; always succeeds.
  Status VisitMap(const MapType& /*type*/, const ::arrow::Array& /*array*/) {
    return {};
  }

  /// \brief Update metrics for float/double primitives; other primitives are
  /// ignored.
  Status VisitPrimitive(const PrimitiveType& type, const ::arrow::Array& array) {
    const auto kind = type.type_id();
    if (kind == TypeId::kFloat) {
      return UpdateFloatingFieldMetrics<::arrow::FloatArray, float>(
          field_id_, array, valid_rows_, metrics_);
    }
    if (kind == TypeId::kDouble) {
      return UpdateFloatingFieldMetrics<::arrow::DoubleArray, double>(
          field_id_, array, valid_rows_, metrics_);
    }
    return {};
  }

 private:
  /// \brief Visit one field: honor the metrics config, install the field's
  /// validity mask and id, dispatch on the field type, then restore the mask.
  Status VisitField(const SchemaField& field, const ::arrow::Array& array) {
    // Fields configured with mode kNone are skipped entirely.
    ICEBERG_ASSIGN_OR_RAISE(auto column_name,
                            schema_.FindColumnNameById(field.field_id()));
    const bool metrics_disabled =
        column_name.has_value() &&
        metrics_config_.ColumnMode(column_name.value()).kind == MetricsMode::Kind::kNone;
    if (metrics_disabled) {
      return {};
    }

    // The mask built here lives on this stack frame; valid_rows_ only points at
    // it for the duration of the nested visit below.
    const std::vector<uint8_t>* const enclosing_mask = valid_rows_;
    auto own_mask = BuildValidRows(array, enclosing_mask);
    valid_rows_ = own_mask.has_value() ? &own_mask.value() : enclosing_mask;
    field_id_ = field.field_id();

    auto status = VisitTypeCategory(*field.type(), this, array);

    // Restore the enclosing mask before own_mask goes out of scope.
    valid_rows_ = enclosing_mask;
    return status;
  }

  std::unordered_map<int32_t, FieldMetrics>& metrics_;
  const MetricsConfig& metrics_config_;
  const Schema& schema_;
  // Validity mask in effect for the field being visited; nullptr when none.
  const std::vector<uint8_t>* valid_rows_ = nullptr;
  // Id of the field most recently entered by VisitField.
  int32_t field_id_ = -1;
};

Result<std::optional<int32_t>> ParseCodecLevel(const WriterProperties& properties) {
auto level_str = properties.Get(WriterProperties::kParquetCompressionLevel);
if (level_str.empty()) {
Expand Down Expand Up @@ -136,6 +282,12 @@ class ParquetWriter::Impl {
ICEBERG_ARROW_ASSIGN_OR_RETURN(auto batch,
::arrow::ImportRecordBatch(array, arrow_schema_));

ICEBERG_ARROW_ASSIGN_OR_RETURN(auto struct_array, batch->ToStructArray());
FieldMetricsCollector field_metrics_collector(field_metrics_, *metrics_config_,
*schema_);
ICEBERG_RETURN_UNEXPECTED(
field_metrics_collector.VisitStruct(*schema_, *struct_array));

ICEBERG_ARROW_RETURN_NOT_OK(writer_->WriteRecordBatch(*batch));

return {};
Expand Down Expand Up @@ -179,9 +331,8 @@ class ParquetWriter::Impl {
ICEBERG_PRECHECK(writer_ == nullptr, "Cannot return metrics for unclosed writer");
ICEBERG_PRECHECK(metadata_ != nullptr,
"Cannot return metrics because Parquet metadata is not available");
// TODO(WZhuo): collect write-side FieldMetrics to support NaN value counts.
return ParquetMetrics::GetMetrics(*schema_, *parquet_schema_, *metrics_config_,
*metadata_);
*metadata_, field_metrics_);
}

private:
Expand All @@ -205,6 +356,8 @@ class ParquetWriter::Impl {
int64_t total_bytes_{0};
// Row group start offsets in the Parquet file.
std::vector<int64_t> split_offsets_;
// Write-side metrics for fields whose Parquet footer metrics are incomplete.
std::unordered_map<int32_t, FieldMetrics> field_metrics_;
};

ParquetWriter::~ParquetWriter() = default;
Expand Down
20 changes: 12 additions & 8 deletions src/iceberg/test/metrics_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,9 @@ void MetricsTestBase::MetricsForNaNColumns() {

ASSERT_TRUE(metrics.row_count.has_value()) << "row_count should be set";
EXPECT_EQ(*metrics.row_count, 2);
AssertCounts(1, 2, 0, metrics);
AssertCounts(2, 2, 0, metrics);
auto expected_nan_count = ReportsNanCounts() ? std::optional<int64_t>(2) : std::nullopt;
AssertCounts(1, 2, 0, expected_nan_count, metrics);
AssertCounts(2, 2, 0, expected_nan_count, metrics);

// When all values are NaN, bounds should not be set
AssertBounds<float>(1, float32(), std::nullopt, std::nullopt, metrics);
Expand Down Expand Up @@ -579,8 +580,9 @@ void MetricsTestBase::ColumnBoundsWithNaNValueAtFront() {

ASSERT_TRUE(metrics.row_count.has_value()) << "row_count should be set";
EXPECT_EQ(*metrics.row_count, 3);
AssertCounts(1, 3, 0, metrics);
AssertCounts(2, 3, 0, metrics);
auto expected_nan_count = ReportsNanCounts() ? std::optional<int64_t>(1) : std::nullopt;
AssertCounts(1, 3, 0, expected_nan_count, metrics);
AssertCounts(2, 3, 0, expected_nan_count, metrics);

// Bounds should be computed from non-NaN values
if (metrics.lower_bounds.contains(1)) {
Expand Down Expand Up @@ -619,8 +621,9 @@ void MetricsTestBase::ColumnBoundsWithNaNValueInMiddle() {

ASSERT_TRUE(metrics.row_count.has_value()) << "row_count should be set";
EXPECT_EQ(*metrics.row_count, 3);
AssertCounts(1, 3, 0, metrics);
AssertCounts(2, 3, 0, metrics);
auto expected_nan_count = ReportsNanCounts() ? std::optional<int64_t>(1) : std::nullopt;
AssertCounts(1, 3, 0, expected_nan_count, metrics);
AssertCounts(2, 3, 0, expected_nan_count, metrics);

if (metrics.lower_bounds.contains(1)) {
AssertBounds<float>(1, float32(), 1.2F, 5.6F, metrics);
Expand Down Expand Up @@ -658,8 +661,9 @@ void MetricsTestBase::ColumnBoundsWithNaNValueAtEnd() {

ASSERT_TRUE(metrics.row_count.has_value()) << "row_count should be set";
EXPECT_EQ(*metrics.row_count, 3);
AssertCounts(1, 3, 0, metrics);
AssertCounts(2, 3, 0, metrics);
auto expected_nan_count = ReportsNanCounts() ? std::optional<int64_t>(1) : std::nullopt;
AssertCounts(1, 3, 0, expected_nan_count, metrics);
AssertCounts(2, 3, 0, expected_nan_count, metrics);

if (metrics.lower_bounds.contains(1)) {
AssertBounds<float>(1, float32(), 1.2F, 5.6F, metrics);
Expand Down
3 changes: 3 additions & 0 deletions src/iceberg/test/metrics_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class MetricsTestBase {
/// \brief Whether the format supports small row groups for testing
virtual bool SupportsSmallRowGroups() const { return false; }

/// \brief Whether the format reports NaN counts for floating-point fields
///
/// Defaults to false. The shared NaN test cases use this to decide whether a
/// concrete NaN count or no NaN count (std::nullopt) is expected.
virtual bool ReportsNanCounts() const { return false; }

// Helper methods for assertions
void AssertCounts(int field_id, std::optional<int64_t> expected_value_count,
std::optional<int64_t> expected_null_count, const Metrics& metrics);
Expand Down
1 change: 1 addition & 0 deletions src/iceberg/test/parquet_metrics_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class ParquetMetricsTest : public MetricsTestBase, public ::testing::Test {
}

bool SupportsSmallRowGroups() const override { return true; }
// The Parquet writer collects NaN counts on the write side, so the shared NaN
// tests expect concrete NaN counts for this format.
bool ReportsNanCounts() const override { return true; }

private:
std::string temp_parquet_file_;
Expand Down
Loading