Skip to content

Commit 25ea1db

Browse files
committed
[Enhancement](pyudf) Support empty arg pyudf && udtf
1 parent aad73a9 commit 25ea1db

6 files changed

Lines changed: 218 additions & 9 deletions

File tree

be/src/exprs/function/function_python_udf.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ Status PythonFunctionCall::execute_impl(FunctionContext* context, Block& block,
113113
return Status::InternalError("Python UDF client is null");
114114
}
115115

116-
int64_t input_rows = block.rows();
116+
int64_t input_rows = num_rows;
117117
uint32_t input_columns = block.columns();
118118
DCHECK(input_columns > 0 && result < input_columns &&
119119
_argument_types.size() == arguments.size());
@@ -142,8 +142,13 @@ Status PythonFunctionCall::execute_impl(FunctionContext* context, Block& block,
142142
std::shared_ptr<arrow::RecordBatch> input_batch;
143143
std::shared_ptr<arrow::RecordBatch> output_batch;
144144
cctz::time_zone _timezone_obj; // default UTC
145-
RETURN_IF_ERROR(convert_to_arrow_batch(input_block, schema, arrow::default_memory_pool(),
146-
&input_batch, _timezone_obj));
145+
if (arguments.empty()) {
146+
input_batch = arrow::RecordBatch::Make(schema, input_rows,
147+
std::vector<std::shared_ptr<arrow::Array>> {});
148+
} else {
149+
RETURN_IF_ERROR(convert_to_arrow_batch(input_block, schema, arrow::default_memory_pool(),
150+
&input_batch, _timezone_obj));
151+
}
147152
RETURN_IF_ERROR(client->evaluate(*input_batch, &output_batch));
148153
int64_t output_rows = output_batch->num_rows();
149154

be/src/exprs/table_function/python_udtf_function.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,19 @@ Status PythonUDTFFunction::process_init(Block* block, RuntimeState* state) {
132132
for (uint32_t i = 0; i < child_column_idxs.size(); ++i) {
133133
input_block.insert(block->get_by_position(child_column_idxs[i]));
134134
}
135+
int64_t input_rows = block->rows();
135136
std::shared_ptr<arrow::Schema> input_schema;
136137
std::shared_ptr<arrow::RecordBatch> input_batch;
137138
RETURN_IF_ERROR(get_arrow_schema_from_block(input_block, &input_schema,
138139
TimezoneUtils::default_time_zone));
139-
RETURN_IF_ERROR(convert_to_arrow_batch(input_block, input_schema, arrow::default_memory_pool(),
140-
&input_batch, _timezone_obj));
140+
if (child_column_idxs.empty()) {
141+
input_batch = arrow::RecordBatch::Make(input_schema, input_rows,
142+
std::vector<std::shared_ptr<arrow::Array>> {});
143+
} else {
144+
RETURN_IF_ERROR(convert_to_arrow_batch(input_block, input_schema,
145+
arrow::default_memory_pool(), &input_batch,
146+
_timezone_obj));
147+
}
141148

142149
// Step 3: Call Python UDTF to evaluate all rows at once (similar to Java UDTF's JNI call)
143150
// Python returns a ListArray where each element contains outputs for one input row

be/src/udf/python/python_udf_meta.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ namespace doris {
3232

3333
Status PythonUDFMeta::convert_types_to_schema(const DataTypes& types, const std::string& timezone,
3434
std::shared_ptr<arrow::Schema>* schema) {
35-
assert(!types.empty());
3635
arrow::SchemaBuilder builder;
3736
for (size_t i = 0; i < types.size(); ++i) {
3837
std::shared_ptr<arrow::DataType> arrow_type;
@@ -152,8 +151,9 @@ Status PythonUDFMeta::check() const {
152151
return Status::InvalidArgument("Python UDF runtime version is empty");
153152
}
154153

155-
if (input_types.empty()) {
156-
return Status::InvalidArgument("Python UDF input types is empty");
154+
if (input_types.empty() &&
155+
(client_type == PythonClientType::UDAF || type == PythonUDFLoadType::UNKNOWN)) {
156+
return Status::InvalidArgument("Python UDAF input types is empty");
157157
}
158158

159159
if (!return_type) {

be/test/udf/python/python_udf_meta_test.cpp

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,43 @@ TEST_F(PythonUDFMetaTest, CheckEmptyRuntimeVersion) {
109109
EXPECT_TRUE(status.to_string().find("runtime version is empty") != std::string::npos);
110110
}
111111

112-
TEST_F(PythonUDFMetaTest, CheckEmptyInputTypes) {
112+
TEST_F(PythonUDFMetaTest, CheckEmptyInputTypesAllowedForUdf) {
113113
PythonUDFMeta meta;
114114
meta.name = "test_udf";
115115
meta.symbol = "test_func";
116116
meta.runtime_version = "3.9.16";
117117
meta.input_types = {};
118118
meta.return_type = nullable_int32_;
119119
meta.type = PythonUDFLoadType::INLINE;
120+
meta.client_type = PythonClientType::UDF;
121+
122+
Status status = meta.check();
123+
EXPECT_TRUE(status.ok()) << status.to_string();
124+
}
125+
126+
TEST_F(PythonUDFMetaTest, CheckEmptyInputTypesAllowedForUdtf) {
127+
PythonUDFMeta meta;
128+
meta.name = "test_udtf";
129+
meta.symbol = "test_func";
130+
meta.runtime_version = "3.9.16";
131+
meta.input_types = {};
132+
meta.return_type = nullable_string_;
133+
meta.type = PythonUDFLoadType::INLINE;
134+
meta.client_type = PythonClientType::UDTF;
135+
136+
Status status = meta.check();
137+
EXPECT_TRUE(status.ok()) << status.to_string();
138+
}
139+
140+
TEST_F(PythonUDFMetaTest, CheckEmptyInputTypesRejectedForUdaf) {
141+
PythonUDFMeta meta;
142+
meta.name = "test_udaf";
143+
meta.symbol = "test_func";
144+
meta.runtime_version = "3.9.16";
145+
meta.input_types = {};
146+
meta.return_type = nullable_int32_;
147+
meta.type = PythonUDFLoadType::INLINE;
148+
meta.client_type = PythonClientType::UDAF;
120149

121150
Status status = meta.check();
122151
EXPECT_FALSE(status.ok());
@@ -401,6 +430,27 @@ TEST_F(PythonUDFMetaTest, SerializeToJsonMultipleInputTypes) {
401430
EXPECT_TRUE(doc.HasMember("input_types"));
402431
}
403432

433+
TEST_F(PythonUDFMetaTest, SerializeToJsonEmptyInputTypesForUdf) {
434+
PythonUDFMeta meta;
435+
meta.name = "zero_arg_udf";
436+
meta.symbol = "func";
437+
meta.runtime_version = "3.9.16";
438+
meta.input_types = {};
439+
meta.return_type = nullable_int32_;
440+
meta.type = PythonUDFLoadType::INLINE;
441+
meta.client_type = PythonClientType::UDF;
442+
443+
std::string json_str;
444+
Status status = meta.serialize_to_json(&json_str);
445+
EXPECT_TRUE(status.ok()) << status.to_string();
446+
447+
rapidjson::Document doc;
448+
doc.Parse(json_str.c_str());
449+
EXPECT_FALSE(doc.HasParseError());
450+
EXPECT_TRUE(doc.HasMember("input_types"));
451+
EXPECT_FALSE(std::string(doc["input_types"].GetString()).empty());
452+
}
453+
404454
// ============================================================================
405455
// PythonUDFMeta convert_types_to_schema() tests
406456
// ============================================================================
@@ -429,6 +479,17 @@ TEST_F(PythonUDFMetaTest, ConvertTypesToSchemaSingleType) {
429479
EXPECT_EQ(schema->num_fields(), 1);
430480
}
431481

482+
TEST_F(PythonUDFMetaTest, ConvertTypesToSchemaEmpty) {
483+
DataTypes types = {};
484+
std::shared_ptr<arrow::Schema> schema;
485+
486+
Status status = PythonUDFMeta::convert_types_to_schema(types, TimezoneUtils::default_time_zone,
487+
&schema);
488+
EXPECT_TRUE(status.ok()) << status.to_string();
489+
EXPECT_NE(schema, nullptr);
490+
EXPECT_EQ(schema->num_fields(), 0);
491+
}
492+
432493
// ============================================================================
433494
// PythonUDFMeta serialize_arrow_schema() tests
434495
// ============================================================================
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
suite("test_pythonudf_no_input") {
19+
def runtime_version = getPythonUdfRuntimeVersion()
20+
def table_name = "test_pythonudf_no_input_tbl"
21+
22+
try {
23+
sql """ DROP FUNCTION IF EXISTS py_const_no_input(); """
24+
sql """ DROP TABLE IF EXISTS ${table_name}; """
25+
26+
sql """
27+
CREATE FUNCTION py_const_no_input()
28+
RETURNS INT
29+
PROPERTIES (
30+
"type" = "PYTHON_UDF",
31+
"symbol" = "evaluate",
32+
"runtime_version" = "${runtime_version}"
33+
)
34+
AS \$\$
35+
def evaluate():
36+
return 7
37+
\$\$;
38+
"""
39+
40+
assert sql(""" SELECT py_const_no_input(); """)[0][0] == 7
41+
42+
sql """
43+
CREATE TABLE ${table_name} (
44+
id INT
45+
) ENGINE=OLAP
46+
DUPLICATE KEY(id)
47+
DISTRIBUTED BY HASH(id) BUCKETS 1
48+
PROPERTIES("replication_num" = "1");
49+
"""
50+
51+
sql """ INSERT INTO ${table_name} VALUES (1), (2), (3); """
52+
53+
def rows = sql("""
54+
SELECT id, py_const_no_input() AS v
55+
FROM ${table_name}
56+
ORDER BY id
57+
""")
58+
59+
assert rows.size() == 3 : "Expected 3 rows, got ${rows.size()}"
60+
assert rows.collect { it[0] as int } == [1, 2, 3]
61+
assert rows.every { (it[1] as int) == 7 }
62+
} finally {
63+
try_sql(""" DROP FUNCTION IF EXISTS py_const_no_input(); """)
64+
try_sql(""" DROP TABLE IF EXISTS ${table_name}; """)
65+
}
66+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
suite("test_pythonudtf_no_input") {
19+
def runtime_version = getPythonUdfRuntimeVersion()
20+
def table_name = "test_pythonudtf_no_input_tbl"
21+
22+
try {
23+
sql """ DROP FUNCTION IF EXISTS py_emit_no_input(); """
24+
sql """ DROP TABLE IF EXISTS ${table_name}; """
25+
26+
sql """
27+
CREATE TABLES FUNCTION py_emit_no_input()
28+
RETURNS ARRAY<STRING>
29+
PROPERTIES (
30+
"type" = "PYTHON_UDF",
31+
"symbol" = "emit_values",
32+
"runtime_version" = "${runtime_version}"
33+
)
34+
AS \$\$
35+
def emit_values():
36+
yield ('left',)
37+
yield ('right',)
38+
\$\$;
39+
"""
40+
41+
sql """
42+
CREATE TABLE ${table_name} (
43+
id INT
44+
) ENGINE=OLAP
45+
DUPLICATE KEY(id)
46+
DISTRIBUTED BY HASH(id) BUCKETS 1
47+
PROPERTIES("replication_num" = "1");
48+
"""
49+
50+
sql """ INSERT INTO ${table_name} VALUES (1), (2); """
51+
52+
def rows = sql("""
53+
SELECT id, value
54+
FROM ${table_name}
55+
LATERAL VIEW py_emit_no_input() tmp AS value
56+
ORDER BY id, value
57+
""")
58+
59+
assert rows.size() == 4 : "Expected 4 rows, got ${rows.size()}"
60+
assert rows.collect { [(it[0] as int), it[1].toString()] } == [
61+
[1, "left"],
62+
[1, "right"],
63+
[2, "left"],
64+
[2, "right"]
65+
]
66+
} finally {
67+
try_sql(""" DROP FUNCTION IF EXISTS py_emit_no_input(); """)
68+
try_sql(""" DROP TABLE IF EXISTS ${table_name}; """)
69+
}
70+
}

0 commit comments

Comments
 (0)