Skip to content

Commit 4dcd551

Browse files
committed
[Feature](func) Support function QUANTILE_STATE_TO/FROM_BASE64
1 parent 15194aa commit 4dcd551

File tree

9 files changed

+624
-0
lines changed

9 files changed

+624
-0
lines changed

be/src/vec/functions/function_quantile_state.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "common/compiler_util.h" // IWYU pragma: keep
3333
#include "common/status.h"
3434
#include "util/quantile_state.h"
35+
#include "util/url_coding.h"
3536
#include "vec/aggregate_functions/aggregate_function.h"
3637
#include "vec/columns/column.h"
3738
#include "vec/columns/column_complex.h"
@@ -49,9 +50,11 @@
4950
#include "vec/data_types/data_type_nullable.h"
5051
#include "vec/data_types/data_type_number.h"
5152
#include "vec/data_types/data_type_quantilestate.h" // IWYU pragma: keep
53+
#include "vec/data_types/data_type_string.h"
5254
#include "vec/functions/function.h"
5355
#include "vec/functions/function_const.h"
5456
#include "vec/functions/function_helpers.h"
57+
#include "vec/functions/function_totype.h"
5558
#include "vec/functions/simple_function_factory.h"
5659
#include "vec/utils/util.hpp"
5760

@@ -218,10 +221,134 @@ class FunctionQuantileStatePercent : public IFunction {
218221
}
219222
};
220223

224+
class FunctionQuantileStateFromBase64 : public IFunction {
225+
public:
226+
static constexpr auto name = "quantile_state_from_base64";
227+
String get_name() const override { return name; }
228+
229+
static FunctionPtr create() { return std::make_shared<FunctionQuantileStateFromBase64>(); }
230+
231+
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
232+
return std::make_shared<DataTypeNullable>(std::make_shared<DataTypeQuantileState>());
233+
}
234+
235+
size_t get_number_of_arguments() const override { return 1; }
236+
237+
bool use_default_implementation_for_nulls() const override { return true; }
238+
239+
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
240+
size_t result, size_t input_rows_count) const override {
241+
auto res_null_map = ColumnUInt8::create(input_rows_count, 0);
242+
auto res_data_column = ColumnQuantileState::create();
243+
auto& null_map = res_null_map->get_data();
244+
auto& res = res_data_column->get_data();
245+
246+
auto& argument_column = block.get_by_position(arguments[0]).column;
247+
const auto& str_column = static_cast<const ColumnString&>(*argument_column);
248+
const ColumnString::Chars& data = str_column.get_chars();
249+
const ColumnString::Offsets& offsets = str_column.get_offsets();
250+
251+
res.reserve(input_rows_count);
252+
253+
std::string decode_buff;
254+
int last_decode_buff_len = 0;
255+
int curr_decode_buff_len = 0;
256+
for (size_t i = 0; i < input_rows_count; ++i) {
257+
const char* src_str = reinterpret_cast<const char*>(&data[offsets[i - 1]]);
258+
int64_t src_size = offsets[i] - offsets[i - 1];
259+
260+
if (src_size == 0 || 0 != src_size % 4) {
261+
res.emplace_back();
262+
null_map[i] = 1;
263+
continue;
264+
}
265+
266+
curr_decode_buff_len = src_size + 3;
267+
if (curr_decode_buff_len > last_decode_buff_len) {
268+
decode_buff.resize(curr_decode_buff_len);
269+
last_decode_buff_len = curr_decode_buff_len;
270+
}
271+
auto outlen = base64_decode(src_str, src_size, decode_buff.data());
272+
if (outlen < 0) {
273+
res.emplace_back();
274+
null_map[i] = 1;
275+
} else {
276+
doris::Slice decoded_slice(decode_buff.data(), outlen);
277+
doris::QuantileState quantile_state;
278+
if (!quantile_state.deserialize(decoded_slice)) {
279+
return Status::RuntimeError(fmt::format(
280+
"quantile_state_from_base64 decode failed: base64: {}", src_str));
281+
} else {
282+
res.emplace_back(std::move(quantile_state));
283+
}
284+
}
285+
}
286+
287+
block.get_by_position(result).column =
288+
ColumnNullable::create(std::move(res_data_column), std::move(res_null_map));
289+
return Status::OK();
290+
}
291+
};
292+
293+
struct NameQuantileStateToBase64 {
294+
static constexpr auto name = "quantile_state_to_base64";
295+
};
296+
297+
struct QuantileStateToBase64 {
298+
using ReturnType = DataTypeString;
299+
static constexpr auto TYPE_INDEX = TypeIndex::QuantileState;
300+
using Type = DataTypeQuantileState::FieldType;
301+
using ReturnColumnType = ColumnString;
302+
using Chars = ColumnString::Chars;
303+
using Offsets = ColumnString::Offsets;
304+
305+
static Status vector(const std::vector<QuantileState>& data, Chars& chars, Offsets& offsets) {
306+
size_t size = data.size();
307+
offsets.resize(size);
308+
size_t output_char_size = 0;
309+
for (size_t i = 0; i < size; ++i) {
310+
auto& quantile_state_val = const_cast<QuantileState&>(data[i]);
311+
auto ser_size = quantile_state_val.get_serialized_size();
312+
output_char_size += (int)(4.0 * ceil((double)ser_size / 3.0));
313+
}
314+
ColumnString::check_chars_length(output_char_size, size);
315+
chars.resize(output_char_size);
316+
auto* chars_data = chars.data();
317+
318+
size_t cur_ser_size = 0;
319+
size_t last_ser_size = 0;
320+
std::string ser_buff;
321+
size_t encoded_offset = 0;
322+
for (size_t i = 0; i < size; ++i) {
323+
auto& quantile_state_val = const_cast<QuantileState&>(data[i]);
324+
325+
cur_ser_size = quantile_state_val.get_serialized_size();
326+
if (cur_ser_size > last_ser_size) {
327+
last_ser_size = cur_ser_size;
328+
ser_buff.resize(cur_ser_size);
329+
}
330+
size_t real_size =
331+
quantile_state_val.serialize(reinterpret_cast<uint8_t*>(ser_buff.data()));
332+
auto outlen = base64_encode((const unsigned char*)ser_buff.data(), real_size,
333+
chars_data + encoded_offset);
334+
DCHECK(outlen > 0);
335+
336+
encoded_offset += outlen;
337+
offsets[i] = encoded_offset;
338+
}
339+
return Status::OK();
340+
}
341+
};
342+
343+
using FunctionQuantileStateToBase64 =
344+
FunctionUnaryToType<QuantileStateToBase64, NameQuantileStateToBase64>;
345+
221346
void register_function_quantile_state(SimpleFunctionFactory& factory) {
222347
factory.register_function<FunctionConst<QuantileStateEmpty, false>>();
223348
factory.register_function<FunctionQuantileStatePercent>();
224349
factory.register_function<FunctionToQuantileState>();
350+
factory.register_function<FunctionQuantileStateFromBase64>();
351+
factory.register_function<FunctionQuantileStateToBase64>();
225352
}
226353

227354
} // namespace doris::vectorized
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
#include <gtest/gtest.h>
18+
19+
#include <string>
20+
21+
#include "function_test_util.h"
22+
#include "util/quantile_state.h"
23+
#include "util/url_coding.h"
24+
#include "vec/core/types.h"
25+
#include "vec/data_types/data_type_quantilestate.h"
26+
#include "vec/data_types/data_type_string.h"
27+
28+
namespace doris::vectorized {
29+
30+
TEST(function_quantile_state_test, function_quantile_state_to_base64) {
31+
std::string func_name = "quantile_state_to_base64";
32+
InputTypeSet input_types = {TypeIndex::QuantileState};
33+
34+
QuantileState empty_quantile_state;
35+
36+
QuantileState single_quantile_state;
37+
single_quantile_state.add_value(1.0);
38+
39+
QuantileState multi_quantile_state;
40+
multi_quantile_state.add_value(1.0);
41+
multi_quantile_state.add_value(2.0);
42+
multi_quantile_state.add_value(3.0);
43+
multi_quantile_state.add_value(4.0);
44+
multi_quantile_state.add_value(5.0);
45+
46+
QuantileState explicit_quantile_state;
47+
for (int i = 0; i < 100; i++) {
48+
explicit_quantile_state.add_value(static_cast<double>(i));
49+
}
50+
51+
QuantileState tdigest_quantile_state;
52+
for (int i = 0; i < 3000; i++) {
53+
tdigest_quantile_state.add_value(static_cast<double>(i));
54+
}
55+
56+
uint8_t buf[65536];
57+
unsigned char encoded_buf[131072];
58+
59+
std::string empty_base64;
60+
{
61+
size_t len = empty_quantile_state.serialize(buf);
62+
size_t encoded_len = base64_encode(buf, len, encoded_buf);
63+
empty_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
64+
}
65+
66+
std::string single_base64;
67+
{
68+
size_t len = single_quantile_state.serialize(buf);
69+
size_t encoded_len = base64_encode(buf, len, encoded_buf);
70+
single_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
71+
}
72+
73+
std::string multi_base64;
74+
{
75+
size_t len = multi_quantile_state.serialize(buf);
76+
size_t encoded_len = base64_encode(buf, len, encoded_buf);
77+
multi_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
78+
}
79+
80+
std::string explicit_base64;
81+
{
82+
size_t len = explicit_quantile_state.serialize(buf);
83+
size_t encoded_len = base64_encode(buf, len, encoded_buf);
84+
explicit_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
85+
}
86+
87+
std::string tdigest_base64;
88+
{
89+
size_t len = tdigest_quantile_state.serialize(buf);
90+
size_t encoded_len = base64_encode(buf, len, encoded_buf);
91+
tdigest_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
92+
}
93+
94+
{
95+
DataSet data_set = {{{&empty_quantile_state}, empty_base64},
96+
{{&single_quantile_state}, single_base64},
97+
{{&multi_quantile_state}, multi_base64},
98+
{{&explicit_quantile_state}, explicit_base64},
99+
{{&tdigest_quantile_state}, tdigest_base64}};
100+
101+
static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set));
102+
}
103+
}
104+
105+
TEST(function_quantile_state_test, function_quantile_state_from_base64) {
106+
std::string func_name = "quantile_state_from_base64";
107+
InputTypeSet input_types = {TypeIndex::String};
108+
109+
// Create quantile states for comparison
110+
QuantileState empty_quantile_state;
111+
112+
QuantileState single_quantile_state;
113+
single_quantile_state.add_value(1.0);
114+
115+
QuantileState multi_quantile_state;
116+
multi_quantile_state.add_value(1.0);
117+
multi_quantile_state.add_value(2.0);
118+
multi_quantile_state.add_value(3.0);
119+
multi_quantile_state.add_value(4.0);
120+
multi_quantile_state.add_value(5.0);
121+
122+
uint8_t buf[65536];
123+
unsigned char encoded_buf[131072];
124+
std::string empty_base64;
125+
std::string single_base64;
126+
std::string multi_base64;
127+
128+
{
129+
size_t len = empty_quantile_state.serialize(buf);
130+
size_t encoded_len = base64_encode(buf, len, encoded_buf);
131+
empty_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
132+
}
133+
134+
{
135+
size_t len = single_quantile_state.serialize(buf);
136+
size_t encoded_len = base64_encode(buf, len, encoded_buf);
137+
single_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
138+
}
139+
140+
{
141+
size_t len = multi_quantile_state.serialize(buf);
142+
size_t encoded_len = base64_encode(buf, len, encoded_buf);
143+
multi_base64 = std::string(reinterpret_cast<char*>(encoded_buf), encoded_len);
144+
}
145+
146+
{
147+
char decoded_buf[65536];
148+
int decoded_len = base64_decode(empty_base64.c_str(), empty_base64.length(), decoded_buf);
149+
EXPECT_GT(decoded_len, 0);
150+
151+
QuantileState decoded_empty;
152+
doris::Slice slice(decoded_buf, decoded_len);
153+
EXPECT_TRUE(decoded_empty.deserialize(slice));
154+
155+
EXPECT_TRUE(std::isnan(empty_quantile_state.get_value_by_percentile(0.5)));
156+
EXPECT_TRUE(std::isnan(decoded_empty.get_value_by_percentile(0.5)));
157+
}
158+
159+
{
160+
char decoded_buf[65536];
161+
int decoded_len = base64_decode(single_base64.c_str(), single_base64.length(), decoded_buf);
162+
EXPECT_GT(decoded_len, 0);
163+
164+
QuantileState decoded_single;
165+
doris::Slice slice(decoded_buf, decoded_len);
166+
EXPECT_TRUE(decoded_single.deserialize(slice));
167+
168+
EXPECT_NEAR(single_quantile_state.get_value_by_percentile(0.5),
169+
decoded_single.get_value_by_percentile(0.5), 0.01);
170+
}
171+
172+
{
173+
char decoded_buf[65536];
174+
int decoded_len = base64_decode(multi_base64.c_str(), multi_base64.length(), decoded_buf);
175+
EXPECT_GT(decoded_len, 0);
176+
177+
QuantileState decoded_multi;
178+
doris::Slice slice(decoded_buf, decoded_len);
179+
EXPECT_TRUE(decoded_multi.deserialize(slice));
180+
181+
EXPECT_NEAR(multi_quantile_state.get_value_by_percentile(0.5),
182+
decoded_multi.get_value_by_percentile(0.5), 0.01);
183+
EXPECT_NEAR(multi_quantile_state.get_value_by_percentile(0.9),
184+
decoded_multi.get_value_by_percentile(0.9), 0.01);
185+
}
186+
}
187+
188+
TEST(function_quantile_state_test, function_quantile_state_roundtrip) {
189+
QuantileState original;
190+
for (int i = 0; i < 50; i++) {
191+
original.add_value(static_cast<double>(i * 2));
192+
}
193+
194+
uint8_t ser_buf[65536];
195+
size_t ser_len = original.serialize(ser_buf);
196+
197+
unsigned char encoded_buf[131072];
198+
size_t encoded_len = base64_encode(ser_buf, ser_len, encoded_buf);
199+
std::string base64_str(reinterpret_cast<char*>(encoded_buf), encoded_len);
200+
201+
char decoded_buf[65536];
202+
int decoded_len = base64_decode(base64_str.c_str(), base64_str.length(), decoded_buf);
203+
EXPECT_GT(decoded_len, 0);
204+
205+
QuantileState recovered;
206+
doris::Slice slice(decoded_buf, decoded_len);
207+
EXPECT_TRUE(recovered.deserialize(slice));
208+
209+
EXPECT_NEAR(original.get_value_by_percentile(0.5), recovered.get_value_by_percentile(0.5),
210+
0.01);
211+
EXPECT_NEAR(original.get_value_by_percentile(0.9), recovered.get_value_by_percentile(0.9),
212+
0.01);
213+
}
214+
215+
} // namespace doris::vectorized

0 commit comments

Comments
 (0)