|
| 1 | +use std::any::Any; |
| 2 | + |
| 3 | +use datafusion::arrow::array::BooleanArray; |
| 4 | +use datafusion::arrow::datatypes::DataType; |
| 5 | +use datafusion::common::{plan_err, Result}; |
| 6 | +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; |
| 7 | +use jiter::Peek; |
| 8 | + |
| 9 | +use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; |
| 10 | +use crate::common_macros::make_udf_function; |
| 11 | + |
| 12 | +make_udf_function!( |
| 13 | + JsonStrContains, |
| 14 | + json_str_contains, |
| 15 | + json_data path needle, |
| 16 | + r#"Checks if a JSON string value at the specified path contains the given substring"# |
| 17 | +); |
| 18 | + |
| 19 | +#[derive(Debug, PartialEq, Eq, Hash)] |
| 20 | +pub(super) struct JsonStrContains { |
| 21 | + signature: Signature, |
| 22 | + aliases: [String; 1], |
| 23 | +} |
| 24 | + |
| 25 | +impl Default for JsonStrContains { |
| 26 | + fn default() -> Self { |
| 27 | + Self { |
| 28 | + signature: Signature::any(3, Volatility::Immutable), |
| 29 | + aliases: ["json_str_contains".to_string()], |
| 30 | + } |
| 31 | + } |
| 32 | +} |
| 33 | + |
| 34 | +impl ScalarUDFImpl for JsonStrContains { |
| 35 | + fn as_any(&self) -> &dyn Any { |
| 36 | + self |
| 37 | + } |
| 38 | + |
| 39 | + fn name(&self) -> &str { |
| 40 | + self.aliases[0].as_str() |
| 41 | + } |
| 42 | + |
| 43 | + fn signature(&self) -> &Signature { |
| 44 | + &self.signature |
| 45 | + } |
| 46 | + |
| 47 | + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
| 48 | + if arg_types.len() != 3 { |
| 49 | + plan_err!("'json_str_contains' function requires exactly three arguments: json_data, path, and needle.") |
| 50 | + } else { |
| 51 | + return_type_check(arg_types, self.name(), DataType::Boolean).map(|_| DataType::Boolean) |
| 52 | + } |
| 53 | + } |
| 54 | + |
| 55 | + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { |
| 56 | + invoke::<BooleanArray>(&args.args, jiter_json_str_contains) |
| 57 | + } |
| 58 | + |
| 59 | + fn aliases(&self) -> &[String] { |
| 60 | + &self.aliases |
| 61 | + } |
| 62 | +} |
| 63 | + |
| 64 | +fn jiter_json_str_contains(json_data: Option<&str>, path: &[JsonPath]) -> Result<bool, GetError> { |
| 65 | + if path.len() != 2 { |
| 66 | + return get_err!(); |
| 67 | + } |
| 68 | + |
| 69 | + let (path_str, needle) = match path { |
| 70 | + [JsonPath::Key(path_str), JsonPath::Key(needle_str)] => (path_str, needle_str), |
| 71 | + _ => return get_err!(), |
| 72 | + }; |
| 73 | + |
| 74 | + let parsed_path = path_str.split('.').map(|s| JsonPath::Key(s)).collect::<Vec<_>>(); |
| 75 | + |
| 76 | + let Some((mut jiter, Peek::String)) = jiter_json_find(json_data, &parsed_path) else { |
| 77 | + return Ok(false); |
| 78 | + }; |
| 79 | + |
| 80 | + let str_value = jiter.known_str()?; |
| 81 | + |
| 82 | + Ok(str_value.contains(needle)) |
| 83 | +} |
| 84 | + |
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::StringArray;
    use datafusion::common::ScalarValue;
    use datafusion::logical_expr::ColumnarValue;
    use std::sync::Arc;

    /// Runs the kernel over a one-row JSON column with scalar `path` and `needle`
    /// arguments and returns the single boolean it produces.
    fn eval_single(json: &str, path: &str, needle: &str) -> bool {
        let utf8 = |s: &str| ColumnarValue::Scalar(ScalarValue::Utf8(Some(s.to_string())));
        let json_array = Arc::new(StringArray::from(vec![json]));
        let args = vec![ColumnarValue::Array(json_array), utf8(path), utf8(needle)];

        let result = invoke::<BooleanArray>(&args, jiter_json_str_contains).unwrap();
        let ColumnarValue::Array(arr) = result else {
            panic!("Expected array result");
        };

        let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(bool_arr.len(), 1);
        bool_arr.value(0)
    }

    #[test]
    fn test_json_str_contains_simple() {
        let json_data = r#"{"name": "Norm Macdonald", "title": "Software Engineer"}"#;
        assert!(eval_single(json_data, "name", "Norm"));
    }

    #[test]
    fn test_json_str_contains_not_found() {
        let json_data = r#"{"name": "Norm Macdonald"}"#;
        assert!(!eval_single(json_data, "name", "Dave"));
    }

    #[test]
    fn test_json_str_contains_nested() {
        let json_data = r#"{"user": {"profile": {"name": "Norm Macdonald"}}}"#;
        assert!(eval_single(json_data, "user.profile.name", "Macdonald"));
    }
}
0 commit comments