// Patch summary (free-text preamble placed before the first `diff --git` header).
//
// Purpose: make the truncate helpers reject non-positive widths with an
// InvalidArgument error instead of proceeding with the truncation.
//
// src/iceberg/util/truncate_util.cc
//   * Adds `#include "iceberg/util/macros.h"`.
//   * Adds two overloads of ValidateTruncateWidth:
//       - ValidateTruncateWidth(int32_t): returns
//         InvalidArgument("Width must be positive, got {}", width) when width <= 0.
//       - ValidateTruncateWidth(size_t): returns
//         InvalidArgument("Width must be positive, got 0") when width == 0.
//     Both return a default-constructed Status (`return {};`) otherwise.
//   * Inserts `ICEBERG_RETURN_UNEXPECTED(ValidateTruncateWidth(...))` as the first
//     statement of TruncateUtils::TruncateUTF8Max (argument `L`),
//     TruncateUtils::TruncateLiteral and TruncateUtils::TruncateLiteralMax
//     (argument `width`).
//     NOTE(review): in TruncateLiteral / TruncateLiteralMax the new check sits
//     *before* the `literal.IsNull()` early return, so a null literal combined
//     with an invalid width presumably now yields the error rather than the
//     null literal -- confirm this ordering is intended (macro body not shown here).
//
// src/iceberg/test/truncate_util_test.cc
//   * TruncateLiteralRejectsInvalidWidth: for widths {0, -1}, calls
//     TruncateLiteral on Int, Long, Decimal, String and Binary literals and
//     expects ErrorKind::kInvalidArgument with message "Width must be positive".
//   * TruncateLiteralMaxRejectsInvalidWidth: same expectation for
//     TruncateLiteralMax on String and Binary literals, widths {0, -1}.
//   * TruncateUTF8MaxRejectsZeroWidth: same expectation for
//     TruncateUTF8Max("iceberg", 0).
//     NOTE(review): the `expect_invalid_width` lambda is defined identically in
//     two tests; a shared file-local helper would remove the duplication.
//
// NOTE(review): this copy of the patch looks whitespace-collapsed, and template
// argument lists appear to have been stripped (e.g. `Result> TruncateLiteralMaxImpl(`,
// `std::optional DecodeUtf8CodePoint`, `std::vector data{1, 2, 3}`). Regenerate
// the diff from the repository before attempting to apply it -- TODO confirm.
diff --git a/src/iceberg/test/truncate_util_test.cc b/src/iceberg/test/truncate_util_test.cc index f39e10cae..d0299ca24 100644 --- a/src/iceberg/test/truncate_util_test.cc +++ b/src/iceberg/test/truncate_util_test.cc @@ -51,6 +51,26 @@ TEST(TruncateUtilTest, TruncateLiteral) { Literal::Binary(std::vector(expected.begin(), expected.end()))); } +TEST(TruncateUtilTest, TruncateLiteralRejectsInvalidWidth) { + std::vector data{1, 2, 3}; + + auto expect_invalid_width = [](const auto& result) { + EXPECT_THAT(result, IsError(ErrorKind::kInvalidArgument)); + EXPECT_THAT(result, HasErrorMessage("Width must be positive")); + }; + + for (int32_t width : {0, -1}) { + SCOPED_TRACE(width); + expect_invalid_width(TruncateUtils::TruncateLiteral(Literal::Int(1), width)); + expect_invalid_width(TruncateUtils::TruncateLiteral(Literal::Long(1), width)); + expect_invalid_width( + TruncateUtils::TruncateLiteral(Literal::Decimal(1065, 4, 2), width)); + expect_invalid_width( + TruncateUtils::TruncateLiteral(Literal::String("iceberg"), width)); + expect_invalid_width(TruncateUtils::TruncateLiteral(Literal::Binary(data), width)); + } +} + TEST(TruncateUtilTest, TruncateBinaryMax) { std::vector test1{1, 1, 2}; std::vector test2{1, 1, 0xFF, 2}; @@ -190,4 +210,27 @@ TEST(TruncateUtilTest, TruncateStringMax) { EXPECT_EQ(result9_2, Literal::String(test9_2_expected)); } +TEST(TruncateUtilTest, TruncateLiteralMaxRejectsInvalidWidth) { + std::vector data{1, 2, 3}; + + auto expect_invalid_width = [](const auto& result) { + EXPECT_THAT(result, IsError(ErrorKind::kInvalidArgument)); + EXPECT_THAT(result, HasErrorMessage("Width must be positive")); + }; + + for (int32_t width : {0, -1}) { + SCOPED_TRACE(width); + expect_invalid_width( + TruncateUtils::TruncateLiteralMax(Literal::String("iceberg"), width)); + expect_invalid_width( + TruncateUtils::TruncateLiteralMax(Literal::Binary(data), width)); + } +} + +TEST(TruncateUtilTest, TruncateUTF8MaxRejectsZeroWidth) { + auto result = 
TruncateUtils::TruncateUTF8Max("iceberg", 0); + EXPECT_THAT(result, IsError(ErrorKind::kInvalidArgument)); + EXPECT_THAT(result, HasErrorMessage("Width must be positive")); +} + } // namespace iceberg diff --git a/src/iceberg/util/truncate_util.cc b/src/iceberg/util/truncate_util.cc index 1778000f9..a7622578d 100644 --- a/src/iceberg/util/truncate_util.cc +++ b/src/iceberg/util/truncate_util.cc @@ -26,6 +26,7 @@ #include "iceberg/expression/literal.h" #include "iceberg/type.h" #include "iceberg/util/checked_cast.h" +#include "iceberg/util/macros.h" namespace iceberg { @@ -34,6 +35,20 @@ constexpr uint32_t kUtf8MaxCodePoint = 0x10FFFF; constexpr uint32_t kUtf8MinSurrogate = 0xD800; constexpr uint32_t kUtf8MaxSurrogate = 0xDFFF; +Status ValidateTruncateWidth(int32_t width) { + if (width <= 0) { + return InvalidArgument("Width must be positive, got {}", width); + } + return {}; +} + +Status ValidateTruncateWidth(size_t width) { + if (width == 0) { + return InvalidArgument("Width must be positive, got 0"); + } + return {}; +} + std::optional DecodeUtf8CodePoint(std::string_view source) { if (source.empty()) { return std::nullopt; @@ -205,6 +220,8 @@ Result> TruncateLiteralMaxImpl( Result> TruncateUtils::TruncateUTF8Max( const std::string& source, size_t L) { + ICEBERG_RETURN_UNEXPECTED(ValidateTruncateWidth(L)); + std::string truncated = TruncateUTF8(source, L); if (truncated == source) { return truncated; @@ -253,6 +270,8 @@ Decimal TruncateUtils::TruncateDecimal(const Decimal& decimal, int32_t width) { return TruncateLiteralImpl(literal, width); Result TruncateUtils::TruncateLiteral(const Literal& literal, int32_t width) { + ICEBERG_RETURN_UNEXPECTED(ValidateTruncateWidth(width)); + if (literal.IsNull()) [[unlikely]] { // Return null as is return literal; @@ -280,6 +299,8 @@ Result TruncateUtils::TruncateLiteral(const Literal& literal, int32_t w Result> TruncateUtils::TruncateLiteralMax(const Literal& literal, int32_t width) { + 
ICEBERG_RETURN_UNEXPECTED(ValidateTruncateWidth(width)); + if (literal.IsNull()) [[unlikely]] { // Return null as is return literal;