diff --git a/integration/ci/cache-key.sh b/integration/ci/cache-key.sh index 0140a721e..0f2289235 100755 --- a/integration/ci/cache-key.sh +++ b/integration/ci/cache-key.sh @@ -13,6 +13,7 @@ REPO_ROOT="$( cd "${SCRIPT_DIR}/../.." && pwd )" cd "$REPO_ROOT" files=( + rust-toolchain.toml Cargo.lock Cargo.toml .cargo/config.toml diff --git a/integration/go/go_pgx/sharded_test.go b/integration/go/go_pgx/sharded_test.go index 86d4a6d31..28bd48b1f 100644 --- a/integration/go/go_pgx/sharded_test.go +++ b/integration/go/go_pgx/sharded_test.go @@ -152,11 +152,11 @@ func TestShardedTwoPc(t *testing.T) { assert.NoError(t, err) } - // +4 is for schema sync + // +5 is for schema sync assertShowField(t, "SHOW STATS", "total_xact_2pc_count", 200, "pgdog_2pc", "pgdog_sharded", 0, "primary") assertShowField(t, "SHOW STATS", "total_xact_2pc_count", 200, "pgdog_2pc", "pgdog_sharded", 1, "primary") - assertShowField(t, "SHOW STATS", "total_xact_count", 401+4, "pgdog_2pc", "pgdog_sharded", 0, "primary") // PREPARE, COMMIT for each transaction + TRUNCATE - assertShowField(t, "SHOW STATS", "total_xact_count", 401+4, "pgdog_2pc", "pgdog_sharded", 1, "primary") + assertShowField(t, "SHOW STATS", "total_xact_count", 401+5, "pgdog_2pc", "pgdog_sharded", 0, "primary") // PREPARE, COMMIT for each transaction + TRUNCATE + assertShowField(t, "SHOW STATS", "total_xact_count", 401+5, "pgdog_2pc", "pgdog_sharded", 1, "primary") for i := range 200 { rows, err := conn.Query( diff --git a/integration/python/test_session_mode.py b/integration/python/test_session_mode.py index 5d4f344cf..85cd473a5 100644 --- a/integration/python/test_session_mode.py +++ b/integration/python/test_session_mode.py @@ -356,3 +356,14 @@ async def test_no_search_path_session_mode(): async with conn.transaction(): await conn.execute("SELECT * FROM shard_0.py_test_no_search_path_session_mode LIMIT 1") + +@pytest.mark.asyncio +async def test_unrecognized_aggregate_function_works_on_schema_based_sharding(): + conn = await async_session_conn("shard_0") + await conn.execute("DROP AGGREGATE IF EXISTS pgdog_sum_python(int4)") + await conn.execute("CREATE AGGREGATE pgdog_sum_python (int4) (sfunc = int4_sum, stype = bigint)") + await conn.execute("DROP TABLE IF EXISTS unrecognized_agg_test_python") + await conn.execute("CREATE TABLE unrecognized_agg_test_python (lol int4)") + + async with conn.transaction(): + await conn.execute("SELECT pgdog_sum_python(lol) FROM unrecognized_agg_test_python") diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs index 985c9ddb1..c075a5db4 100644 --- a/integration/rust/tests/integration/mod.rs +++ b/integration/rust/tests/integration/mod.rs @@ -36,3 +36,4 @@ pub mod timestamp_sorting; pub mod tls_enforced; pub mod tls_reload; pub mod transaction_state; +pub mod unrecognized_aggregate; diff --git a/integration/rust/tests/integration/unrecognized_aggregate.rs b/integration/rust/tests/integration/unrecognized_aggregate.rs new file mode 100644 index 000000000..18e4e15ce --- /dev/null +++ b/integration/rust/tests/integration/unrecognized_aggregate.rs @@ -0,0 +1,46 @@ +use rust::setup::{admin_sqlx, connections_sqlx}; +use sqlx::Executor; +use std::assert_matches; + +async fn define_custom_aggregate_fn() { + for connection in connections_sqlx().await { + connection + .execute("DROP AGGREGATE IF EXISTS pgdog_sum (int4)") + .await + .unwrap(); + connection + .execute("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)") + .await + .unwrap(); + connection + .execute("DROP TABLE IF EXISTS unrecognized_agg_test") + .await + .unwrap(); + connection + .execute("CREATE TABLE unrecognized_agg_test (lol int4, customer_id bigint)") + .await + .unwrap(); + } + admin_sqlx().await.execute("RELOAD").await.unwrap(); +} + +#[tokio::test] +async fn unrecognized_aggregate_function_errors_only_on_cross_shard_queries() { + define_custom_aggregate_fn().await; + let mut connections = connections_sqlx().await; + let sharded = connections.pop().unwrap(); + let unsharded = connections.pop().unwrap(); + + let unsharded_query = unsharded + .fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test") + .await; + assert_matches!(unsharded_query, Ok(_)); + + let sharded_query = sharded + .fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test") + .await; + let err = sharded_query + .err() + .expect("unrecognized aggregate executed successfully"); + assert!(err.to_string().contains("pgdog_sum is not yet supported")); +} diff --git a/mise.toml b/mise.toml index fe216b4a4..bc3afd39d 100644 --- a/mise.toml +++ b/mise.toml @@ -1,4 +1,4 @@ [tools] -"rust" = { version = "1.91.0" } +"rust" = { version = "1.96.0" } "cargo:cargo-nextest" = "latest" -"cargo:cargo-watch" = "latest" \ No newline at end of file +"cargo:cargo-watch" = "latest" diff --git a/pgdog-stats/src/schema.rs b/pgdog-stats/src/schema.rs index 1ceb88278..ee1844919 100644 --- a/pgdog-stats/src/schema.rs +++ b/pgdog-stats/src/schema.rs @@ -1,6 +1,11 @@ use indexmap::IndexMap; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, hash::Hash, ops::Deref, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + hash::Hash, + ops::Deref, + sync::Arc, +}; /// Schema name -> Table name -> Relation pub type Relations = HashMap>; @@ -134,6 +139,7 @@ impl Relation { pub struct SchemaInner { pub search_path: Vec, pub relations: Relations, + pub aggregate_functions: HashSet, } impl Hash for SchemaInner { diff --git a/pgdog/src/backend/pool/connection/aggregate.rs b/pgdog/src/backend/pool/connection/aggregate.rs index 0d1421a68..73912e81a 100644 --- a/pgdog/src/backend/pool/connection/aggregate.rs +++ b/pgdog/src/backend/pool/connection/aggregate.rs @@ -176,6 +176,7 @@ impl<'a> Accumulator<'a> { } } } + AggregateFunction::Unrecognized(..) => return Ok(false), } Ok(true) @@ -512,6 +513,13 @@ impl<'a> Aggregates<'a> { false }) } + AggregateFunction::Unrecognized(fname) => { + unsupported.get_or_insert(UnsupportedAggregate { + function: fname.clone(), + reason: format!("{fname} is not yet supported"), + }); + false + } _ => true, } }); @@ -803,7 +811,7 @@ mod test { } fn parse(stmt: &str) -> Aggregate { - Aggregate::parse(&select(stmt)) + Aggregate::parse(&select(stmt), &Default::default()) } #[test] diff --git a/pgdog/src/backend/schema/mod.rs b/pgdog/src/backend/schema/mod.rs index b836fcdf8..cda7da25c 100644 --- a/pgdog/src/backend/schema/mod.rs +++ b/pgdog/src/backend/schema/mod.rs @@ -60,9 +60,18 @@ impl Schema { .map(|p| p.trim().replace("\"", "")) .collect(); + let aggregate_functions = server + .fetch_all::( + "SELECT DISTINCT proname FROM pg_proc INNER JOIN pg_aggregate ON oid = aggfnoid", + ) + .await? + .into_iter() + .collect(); + let inner = SchemaInner { search_path, relations, + aggregate_functions, }; Ok(Self { @@ -79,6 +88,15 @@ impl Schema { pub(crate) fn from_parts( search_path: Vec, relations: HashMap<(String, String), Relation>, + ) -> Self { + Self::from_parts_with_agg(search_path, relations, Vec::new()) + } + + #[cfg(test)] + pub(crate) fn from_parts_with_agg( + search_path: Vec, + relations: HashMap<(String, String), Relation>, + aggregate_functions: Vec, ) -> Self { let mut nested: StatsRelations = HashMap::new(); for ((schema, name), relation) in relations { @@ -91,6 +109,7 @@ impl Schema { inner: StatsSchema::new(SchemaInner { search_path, relations: nested, + aggregate_functions: aggregate_functions.into_iter().collect(), }), } } @@ -555,4 +574,24 @@ mod test { ); } } + + #[tokio::test] + async fn test_loading_aggregate_functions() { + let mut server = test_server().await; + server.execute_checked("BEGIN").await.unwrap(); + + let schema = Schema::load(&mut server).await.unwrap(); + assert!(!schema + .aggregate_functions + .contains(&String::from("pgdog_sum"))); + + server + .execute_checked("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)") + .await + .unwrap(); + let schema = Schema::load(&mut server).await.unwrap(); + assert!(schema + .aggregate_functions + .contains(&String::from("pgdog_sum"))); + } } diff --git a/pgdog/src/frontend/router/parser/aggregate.rs b/pgdog/src/frontend/router/parser/aggregate.rs index fd88a4410..46167e066 100644 --- a/pgdog/src/frontend/router/parser/aggregate.rs +++ b/pgdog/src/frontend/router/parser/aggregate.rs @@ -2,7 +2,8 @@ use pg_query::protobuf::Integer; use pg_query::protobuf::{a_const::Val, Node, SelectStmt, String as PgQueryString}; use pg_query::NodeEnum; -use crate::frontend::router::parser::{ExpressionRegistry, Function}; +use super::{ExpressionRegistry, Function}; +use crate::backend::schema::Schema; #[derive(Debug, Clone, PartialEq)] pub struct AggregateTarget { @@ -41,10 +42,11 @@ pub enum AggregateFunction { StddevSamp, VarPop, VarSamp, + Unrecognized(String), } impl AggregateFunction { - pub fn as_str(&self) -> &'static str { + pub fn as_str(&self) -> &str { match self { AggregateFunction::Count => "count", AggregateFunction::Max => "max", @@ -55,6 +57,7 @@ impl AggregateFunction { AggregateFunction::StddevSamp => "stddev_samp", AggregateFunction::VarPop => "var_pop", AggregateFunction::VarSamp => "var_samp", + AggregateFunction::Unrecognized(s) => &*s, } } } @@ -121,7 +124,7 @@ fn columns_match(group_by_names: &[&String], select_names: &[&String]) -> bool { impl Aggregate { /// Figure out what aggregates are present and which ones PgDog supports. - pub fn parse(stmt: &SelectStmt) -> Self { + pub fn parse(stmt: &SelectStmt, schema: &Schema) -> Self { let mut targets = vec![]; let mut registry = ExpressionRegistry::new(); let group_by = stmt @@ -168,7 +171,13 @@ impl Aggregate { "stddev_pop" => Some(AggregateFunction::StddevPop), "variance" | "var_samp" => Some(AggregateFunction::VarSamp), "var_pop" => Some(AggregateFunction::VarPop), - _ => None, + fname => { + if schema.aggregate_functions.contains(fname) { + Some(AggregateFunction::Unrecognized(fname.to_owned())) + } else { + None + } + } }; if let Some(function) = function { @@ -259,7 +268,7 @@ mod test { } fn parse(stmt: &str) -> Aggregate { - Aggregate::parse(&select(stmt)) + Aggregate::parse(&select(stmt), &Default::default()) } #[test] @@ -455,4 +464,34 @@ mod test { assert_eq!(aggr.group_by(), &[0]); assert_eq!(aggr.targets().len(), 1); } + + #[test] + fn test_unrecognized_aggregate_function_errors() { + let schema_with_agg = Schema::from_parts_with_agg( + Vec::new(), + Default::default(), + vec![String::from("mysum")], + ); + let schema_without_agg = Default::default(); + let query = select("SELECT mysum(lol) FROM example"); + + // A random function that isn't listed as aggregate in the schema + // doesn't require special support on our end, so we should be fine. + let aggregate = Aggregate::parse(&query, &schema_without_agg); + assert_eq!(aggregate.targets, Vec::new()); + + // If we see an aggregate function we don't recognize, we can't + // process the query correctly, since we need to combine the + // results from each shard. + let aggregate = Aggregate::parse(&query, &schema_with_agg); + let funcs = aggregate + .targets + .into_iter() + .map(|t| t.function) + .collect::>(); + assert_eq!( + funcs, + vec![AggregateFunction::Unrecognized("mysum".to_owned())] + ); + } } diff --git a/pgdog/src/frontend/router/parser/query/select.rs b/pgdog/src/frontend/router/parser/query/select.rs index bcfc38370..b1ce45d15 100644 --- a/pgdog/src/frontend/router/parser/query/select.rs +++ b/pgdog/src/frontend/router/parser/query/select.rs @@ -138,7 +138,7 @@ impl QueryParser { } let shard = Self::converge(&shards, ConvergeAlgorithm::default()); - let aggregates = Aggregate::parse(stmt); + let aggregates = Aggregate::parse(stmt, &context.router_context.schema); let limit = LimitClause::new(stmt, context.router_context.bind).limit_offset()?; let distinct = Distinct::new(stmt).distinct()?; diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs index 82e804c15..70b4a30c2 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs @@ -277,7 +277,7 @@ mod tests { fn rewrite(sql: &str) -> (ParseResult, RewriteOutput) { let mut parsed = pg_query::parse(sql).unwrap().protobuf; let stmt = select(&mut parsed); - let aggregate = Aggregate::parse(stmt); + let aggregate = Aggregate::parse(stmt, &Default::default()); let output = AggregatesRewrite.rewrite_select(stmt, &aggregate); (parsed, output) } @@ -302,7 +302,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()); assert_eq!(aggregate.targets().len(), 2); assert!(aggregate .targets() @@ -321,7 +321,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()); assert_eq!(aggregate.targets().len(), 3); assert!( aggregate @@ -349,7 +349,7 @@ mod tests { assert_eq!(helper_discount.helper_column, 3); assert!(matches!(helper_discount.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()); assert_eq!(aggregate.targets().len(), 4); assert_eq!( aggregate diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs index c3c73b644..3d25d9df3 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs @@ -2,6 +2,7 @@ pub mod engine; pub mod plan; use super::{Error, RewritePlan, StatementRewrite}; +use crate::backend::schema::Schema; use crate::frontend::router::parser::aggregate::Aggregate; use pg_query::NodeEnum; @@ -10,7 +11,11 @@ pub use plan::{AggregateRewritePlan, HelperKind, HelperMapping, RewriteOutput}; impl StatementRewrite<'_> { /// Add missing COUNT(*) and other helps when using aggregates. - pub(super) fn rewrite_aggregates(&mut self, plan: &mut RewritePlan) -> Result<(), Error> { + pub(super) fn rewrite_aggregates( + &mut self, + plan: &mut RewritePlan, + schema: &Schema, + ) -> Result<(), Error> { if self.schema.shards == 1 { return Ok(()); } @@ -27,7 +32,7 @@ impl StatementRewrite<'_> { return Ok(()); }; - let aggregate = Aggregate::parse(select); + let aggregate = Aggregate::parse(select, schema); if aggregate.is_empty() { return Ok(()); } diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs index b03fab315..45eab440a 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs @@ -140,7 +140,7 @@ impl<'a> StatementRewrite<'a> { } })?; - self.rewrite_aggregates(&mut plan)?; + self.rewrite_aggregates(&mut plan, self.db_schema)?; self.limit_offset(&mut plan)?; if self.rewritten { diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 000000000..f2b26a5df --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +channel = "1.96" +components = ["rustfmt", "clippy"] +profile = "default"