Skip to content

Commit 1b4540c

Browse files
add json_str_contains:
1 parent 78c5abb commit 1b4540c

3 files changed

Lines changed: 229 additions & 0 deletions

File tree

src/json_str_contains.rs

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
use std::any::Any;
2+
3+
use datafusion::arrow::array::BooleanArray;
4+
use datafusion::arrow::datatypes::DataType;
5+
use datafusion::common::{plan_err, Result};
6+
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
7+
use jiter::Peek;
8+
9+
use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
10+
use crate::common_macros::make_udf_function;
11+
12+
// Invokes the crate's `make_udf_function!` helper for `JsonStrContains`, passing the
// implementing type, the function name, the argument names and the documentation string.
// NOTE(review): presumably this generates the `json_str_contains` and `json_str_contains_udf`
// items that `lib.rs` re-exports and registers — confirm against the macro definition.
make_udf_function!(
13+
JsonStrContains,
14+
json_str_contains,
15+
json_data path needle,
16+
r#"Checks if a JSON string value at the specified path contains the given substring"#
17+
);
18+
19+
/// Scalar UDF implementation behind the `json_str_contains` function.
#[derive(Debug, PartialEq, Eq, Hash)]
20+
pub(super) struct JsonStrContains {
21+
// Argument signature handed back from `ScalarUDFImpl::signature`.
signature: Signature,
22+
// Registered names; the first (and only) entry is also returned as the function name.
aliases: [String; 1],
23+
}
24+
25+
impl Default for JsonStrContains {
26+
fn default() -> Self {
27+
Self {
28+
signature: Signature::any(3, Volatility::Immutable),
29+
aliases: ["json_str_contains".to_string()],
30+
}
31+
}
32+
}
33+
34+
impl ScalarUDFImpl for JsonStrContains {
35+
fn as_any(&self) -> &dyn Any {
36+
self
37+
}
38+
39+
fn name(&self) -> &str {
40+
self.aliases[0].as_str()
41+
}
42+
43+
fn signature(&self) -> &Signature {
44+
&self.signature
45+
}
46+
47+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
48+
if arg_types.len() != 3 {
49+
plan_err!("'json_str_contains' function requires exactly three arguments: json_data, path, and needle.")
50+
} else {
51+
return_type_check(arg_types, self.name(), DataType::Boolean).map(|_| DataType::Boolean)
52+
}
53+
}
54+
55+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
56+
invoke::<BooleanArray>(&args.args, jiter_json_str_contains)
57+
}
58+
59+
fn aliases(&self) -> &[String] {
60+
&self.aliases
61+
}
62+
}
63+
64+
fn jiter_json_str_contains(json_data: Option<&str>, path: &[JsonPath]) -> Result<bool, GetError> {
65+
if path.len() != 2 {
66+
return get_err!();
67+
}
68+
69+
let (path_str, needle) = match path {
70+
[JsonPath::Key(path_str), JsonPath::Key(needle_str)] => (path_str, needle_str),
71+
_ => return get_err!(),
72+
};
73+
74+
let parsed_path = path_str.split('.').map(|s| JsonPath::Key(s)).collect::<Vec<_>>();
75+
76+
let Some((mut jiter, Peek::String)) = jiter_json_find(json_data, &parsed_path) else {
77+
return Ok(false);
78+
};
79+
80+
let str_value = jiter.known_str()?;
81+
82+
Ok(str_value.contains(needle))
83+
}
84+
85+
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::StringArray;
    use datafusion::common::ScalarValue;
    use datafusion::logical_expr::ColumnarValue;
    use std::sync::Arc;

    /// Runs `jiter_json_str_contains` through `invoke` on a one-row JSON column with
    /// scalar `path` and `needle` arguments, and returns the single boolean produced.
    ///
    /// Panics when the result is not a one-element `BooleanArray`.
    fn contains_single_row(json_data: &str, path: &str, needle: &str) -> bool {
        let args = vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![json_data]))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(path.to_string()))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(needle.to_string()))),
        ];

        let result = invoke::<BooleanArray>(&args, jiter_json_str_contains).unwrap();

        let ColumnarValue::Array(arr) = result else {
            panic!("Expected array result");
        };
        let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(bool_arr.len(), 1);
        bool_arr.value(0)
    }

    #[test]
    fn test_json_str_contains_simple() {
        let json_data = r#"{"name": "Norm Macdonald", "title": "Software Engineer"}"#;
        assert!(contains_single_row(json_data, "name", "Norm"));
    }

    #[test]
    fn test_json_str_contains_not_found() {
        let json_data = r#"{"name": "Norm Macdonald"}"#;
        assert!(!contains_single_row(json_data, "name", "Dave"));
    }

    #[test]
    fn test_json_str_contains_nested() {
        let json_data = r#"{"user": {"profile": {"name": "Norm Macdonald"}}}"#;
        assert!(contains_single_row(json_data, "user.profile.name", "Macdonald"));
    }
}

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mod json_get_json;
2020
mod json_get_str;
2121
mod json_length;
2222
mod json_object_keys;
23+
mod json_str_contains;
2324
mod rewrite;
2425

2526
pub use common_union::{JsonUnionEncoder, JsonUnionValue, JSON_UNION_DATA_TYPE};
@@ -37,6 +38,7 @@ pub mod functions {
3738
pub use crate::json_get_str::json_get_str;
3839
pub use crate::json_length::json_length;
3940
pub use crate::json_object_keys::json_object_keys;
41+
pub use crate::json_str_contains::json_str_contains;
4042
}
4143

4244
pub mod udfs {
@@ -52,6 +54,7 @@ pub mod udfs {
5254
pub use crate::json_get_str::json_get_str_udf;
5355
pub use crate::json_length::json_length_udf;
5456
pub use crate::json_object_keys::json_object_keys_udf;
57+
pub use crate::json_str_contains::json_str_contains_udf;
5558
}
5659

5760
/// Register all JSON UDFs, and [`rewrite::JsonFunctionRewriter`] with the provided [`FunctionRegistry`].
@@ -77,6 +80,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
7780
json_length::json_length_udf(),
7881
json_object_keys::json_object_keys_udf(),
7982
json_from_scalar::json_from_scalar_udf(),
83+
json_str_contains::json_str_contains_udf(),
8084
];
8185
functions.into_iter().try_for_each(|udf| {
8286
let existing_udf = registry.register_udf(udf)?;

tests/main.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2677,3 +2677,67 @@ async fn test_json_from_scalar_null_column() {
26772677
];
26782678
assert_batches_eq!(expected, &batches);
26792679
}
2680+
2681+
#[tokio::test]
2682+
async fn test_json_str_contains_simple() {
2683+
let sql = r#"select json_str_contains('{"name": "Norm Macdonald"}', 'name', 'Norm')"#;
2684+
let batches = run_query(sql).await.unwrap();
2685+
assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string()));
2686+
}
2687+
2688+
#[tokio::test]
2689+
async fn test_json_str_contains_not_found() {
2690+
let sql = r#"select json_str_contains('{"name": "Norm Macdonald"}', 'name', 'Norma')"#;
2691+
let batches = run_query(sql).await.unwrap();
2692+
assert_eq!(display_val(batches).await, (DataType::Boolean, "false".to_string()));
2693+
}
2694+
2695+
#[tokio::test]
2696+
async fn test_json_str_contains_nested() {
2697+
let sql = r#"select json_str_contains('{"user": {"profile": {"name": "Norm Macdonald"}}}', 'user.profile.name', 'Macdonald')"#;
2698+
let batches = run_query(sql).await.unwrap();
2699+
assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string()));
2700+
}
2701+
2702+
#[tokio::test]
2703+
async fn test_json_str_contains_case_sensitive() {
2704+
let sql = r#"select json_str_contains('{"title": "Software Engineer"}', 'title', 'engineer')"#;
2705+
let batches = run_query(sql).await.unwrap();
2706+
assert_eq!(display_val(batches).await, (DataType::Boolean, "false".to_string()));
2707+
2708+
let sql = r#"select json_str_contains('{"title": "Software Engineer"}', 'title', 'Engineer')"#;
2709+
let batches = run_query(sql).await.unwrap();
2710+
assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string()));
2711+
}
2712+
2713+
#[tokio::test]
2714+
async fn test_json_str_contains_partial_match() {
2715+
let sql = r#"select json_str_contains('{"email": "[email protected]"}', 'email', '@example')"#;
2716+
let batches = run_query(sql).await.unwrap();
2717+
assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string()));
2718+
}
2719+
2720+
#[tokio::test]
2721+
async fn test_json_str_contains_table() {
2722+
let expected = [
2723+
"+------------------+----------------------------------------------------------+",
2724+
"| name | json_str_contains(test.json_data,Utf8(\"foo\"),Utf8(\"ab\")) |",
2725+
"+------------------+----------------------------------------------------------+",
2726+
"| object_foo | true |",
2727+
"| object_foo_array | false |",
2728+
"| object_foo_obj | false |",
2729+
"| object_foo_null | false |",
2730+
"| object_bar | false |",
2731+
"| list_foo | false |",
2732+
"| invalid_json | false |",
2733+
"+------------------+----------------------------------------------------------+",
2734+
];
2735+
2736+
for_all_json_datatypes(async |dt| {
2737+
let batches = run_query_datatype("select name, json_str_contains(json_data, 'foo', 'ab') from test", dt)
2738+
.await
2739+
.unwrap();
2740+
assert_batches_eq!(expected, &batches);
2741+
})
2742+
.await;
2743+
}

0 commit comments

Comments
 (0)