From 4866601318d1dadd14c507853336c96495a75922 Mon Sep 17 00:00:00 2001 From: Sean Griffin Date: Thu, 28 May 2026 12:28:33 -0600 Subject: [PATCH 1/2] Error when encountering an unsupported aggregate function When handling aggregate functions in cross-shard queries, we can only return a correct result for functions that we have explicit support for. No matter what the aggregate function is doing, we need to know how to combine the results from the separate shards. Prior to this change, we would just silently do the wrong thing, treating this as any other non-aggregate expression and just returning a union of the rows from each query. We now explicitly check if an aggregate function with that name exists and error if we don't recognize it. This is future-proofed against new functions in later PostgreSQL versions, as well as user-defined aggregates (which we can likely never support). This will produce a false positive if a function exists and is defined as an aggregate for some argument types but not others. But frankly, anyone doing that is asking for trouble. In theory we could potentially not error on direct-to-shard queries, but ultimately handling that proved to be a massive rabbit hole. It's something I'd like to do in the future after more refactoring, but for the time being it seems reasonable to assume that if `SELECT mysum(...)` fails, `SELECT mysum(...) 
WHERE shardkey = 1` would also fail --- integration/ci/cache-key.sh | 1 + integration/go/go_pgx/sharded_test.go | 6 +-- integration/rust/tests/integration/mod.rs | 1 + .../integration/unrecognized_aggregate.rs | 46 +++++++++++++++++ mise.toml | 4 +- pgdog-stats/src/schema.rs | 8 ++- .../src/backend/pool/connection/aggregate.rs | 10 +++- pgdog/src/backend/schema/mod.rs | 39 +++++++++++++++ pgdog/src/frontend/router/parser/aggregate.rs | 49 +++++++++++++++++-- .../frontend/router/parser/query/select.rs | 2 +- .../rewrite/statement/aggregate/engine.rs | 8 +-- .../parser/rewrite/statement/aggregate/mod.rs | 9 +++- .../router/parser/rewrite/statement/mod.rs | 2 +- rust-toolchain.toml | 4 ++ 14 files changed, 169 insertions(+), 20 deletions(-) create mode 100644 integration/rust/tests/integration/unrecognized_aggregate.rs create mode 100644 rust-toolchain.toml diff --git a/integration/ci/cache-key.sh b/integration/ci/cache-key.sh index 0140a721e..0f2289235 100755 --- a/integration/ci/cache-key.sh +++ b/integration/ci/cache-key.sh @@ -13,6 +13,7 @@ REPO_ROOT="$( cd "${SCRIPT_DIR}/../.." 
&& pwd )" cd "$REPO_ROOT" files=( + rust-toolchain.toml Cargo.lock Cargo.toml .cargo/config.toml diff --git a/integration/go/go_pgx/sharded_test.go b/integration/go/go_pgx/sharded_test.go index 86d4a6d31..28bd48b1f 100644 --- a/integration/go/go_pgx/sharded_test.go +++ b/integration/go/go_pgx/sharded_test.go @@ -152,11 +152,11 @@ func TestShardedTwoPc(t *testing.T) { assert.NoError(t, err) } - // +4 is for schema sync + // +5 is for schema sync assertShowField(t, "SHOW STATS", "total_xact_2pc_count", 200, "pgdog_2pc", "pgdog_sharded", 0, "primary") assertShowField(t, "SHOW STATS", "total_xact_2pc_count", 200, "pgdog_2pc", "pgdog_sharded", 1, "primary") - assertShowField(t, "SHOW STATS", "total_xact_count", 401+4, "pgdog_2pc", "pgdog_sharded", 0, "primary") // PREPARE, COMMIT for each transaction + TRUNCATE - assertShowField(t, "SHOW STATS", "total_xact_count", 401+4, "pgdog_2pc", "pgdog_sharded", 1, "primary") + assertShowField(t, "SHOW STATS", "total_xact_count", 401+5, "pgdog_2pc", "pgdog_sharded", 0, "primary") // PREPARE, COMMIT for each transaction + TRUNCATE + assertShowField(t, "SHOW STATS", "total_xact_count", 401+5, "pgdog_2pc", "pgdog_sharded", 1, "primary") for i := range 200 { rows, err := conn.Query( diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs index 985c9ddb1..c075a5db4 100644 --- a/integration/rust/tests/integration/mod.rs +++ b/integration/rust/tests/integration/mod.rs @@ -36,3 +36,4 @@ pub mod timestamp_sorting; pub mod tls_enforced; pub mod tls_reload; pub mod transaction_state; +pub mod unrecognized_aggregate; diff --git a/integration/rust/tests/integration/unrecognized_aggregate.rs b/integration/rust/tests/integration/unrecognized_aggregate.rs new file mode 100644 index 000000000..18e4e15ce --- /dev/null +++ b/integration/rust/tests/integration/unrecognized_aggregate.rs @@ -0,0 +1,46 @@ +use rust::setup::{admin_sqlx, connections_sqlx}; +use sqlx::Executor; +use std::assert_matches; + +async 
fn define_custom_aggregate_fn() { + for connection in connections_sqlx().await { + connection + .execute("DROP AGGREGATE IF EXISTS pgdog_sum (int4)") + .await + .unwrap(); + connection + .execute("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)") + .await + .unwrap(); + connection + .execute("DROP TABLE IF EXISTS unrecognized_agg_test") + .await + .unwrap(); + connection + .execute("CREATE TABLE unrecognized_agg_test (lol int4, customer_id bigint)") + .await + .unwrap(); + } + admin_sqlx().await.execute("RELOAD").await.unwrap(); +} + +#[tokio::test] +async fn unrecognized_aggregate_function_errors_only_on_cross_shard_queries() { + define_custom_aggregate_fn().await; + let mut connections = connections_sqlx().await; + let sharded = connections.pop().unwrap(); + let unsharded = connections.pop().unwrap(); + + let unsharded_query = unsharded + .fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test") + .await; + assert_matches!(unsharded_query, Ok(_)); + + let sharded_query = sharded + .fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test") + .await; + let err = sharded_query + .err() + .expect("unrecognized aggregate executed successfully"); + assert!(err.to_string().contains("pgdog_sum is not yet supported")); +} diff --git a/mise.toml b/mise.toml index fe216b4a4..bc3afd39d 100644 --- a/mise.toml +++ b/mise.toml @@ -1,4 +1,4 @@ [tools] -"rust" = { version = "1.91.0" } +"rust" = { version = "1.96.0" } "cargo:cargo-nextest" = "latest" -"cargo:cargo-watch" = "latest" \ No newline at end of file +"cargo:cargo-watch" = "latest" diff --git a/pgdog-stats/src/schema.rs b/pgdog-stats/src/schema.rs index 1ceb88278..ee1844919 100644 --- a/pgdog-stats/src/schema.rs +++ b/pgdog-stats/src/schema.rs @@ -1,6 +1,11 @@ use indexmap::IndexMap; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, hash::Hash, ops::Deref, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + hash::Hash, + ops::Deref, + sync::Arc, +}; /// Schema 
name -> Table name -> Relation pub type Relations = HashMap>; @@ -134,6 +139,7 @@ impl Relation { pub struct SchemaInner { pub search_path: Vec, pub relations: Relations, + pub aggregate_functions: HashSet, } impl Hash for SchemaInner { diff --git a/pgdog/src/backend/pool/connection/aggregate.rs b/pgdog/src/backend/pool/connection/aggregate.rs index 0d1421a68..73912e81a 100644 --- a/pgdog/src/backend/pool/connection/aggregate.rs +++ b/pgdog/src/backend/pool/connection/aggregate.rs @@ -176,6 +176,7 @@ impl<'a> Accumulator<'a> { } } } + AggregateFunction::Unrecognized(..) => return Ok(false), } Ok(true) @@ -512,6 +513,13 @@ impl<'a> Aggregates<'a> { false }) } + AggregateFunction::Unrecognized(fname) => { + unsupported.get_or_insert(UnsupportedAggregate { + function: fname.clone(), + reason: format!("{fname} is not yet supported"), + }); + false + } _ => true, } }); @@ -803,7 +811,7 @@ mod test { } fn parse(stmt: &str) -> Aggregate { - Aggregate::parse(&select(stmt)) + Aggregate::parse(&select(stmt), &Default::default()) } #[test] diff --git a/pgdog/src/backend/schema/mod.rs b/pgdog/src/backend/schema/mod.rs index b836fcdf8..cda7da25c 100644 --- a/pgdog/src/backend/schema/mod.rs +++ b/pgdog/src/backend/schema/mod.rs @@ -60,9 +60,18 @@ impl Schema { .map(|p| p.trim().replace("\"", "")) .collect(); + let aggregate_functions = server + .fetch_all::( + "SELECT DISTINCT proname FROM pg_proc INNER JOIN pg_aggregate ON oid = aggfnoid", + ) + .await? 
+ .into_iter() + .collect(); + let inner = SchemaInner { search_path, relations, + aggregate_functions, }; Ok(Self { @@ -79,6 +88,15 @@ impl Schema { pub(crate) fn from_parts( search_path: Vec, relations: HashMap<(String, String), Relation>, + ) -> Self { + Self::from_parts_with_agg(search_path, relations, Vec::new()) + } + + #[cfg(test)] + pub(crate) fn from_parts_with_agg( + search_path: Vec, + relations: HashMap<(String, String), Relation>, + aggregate_functions: Vec, ) -> Self { let mut nested: StatsRelations = HashMap::new(); for ((schema, name), relation) in relations { @@ -91,6 +109,7 @@ impl Schema { inner: StatsSchema::new(SchemaInner { search_path, relations: nested, + aggregate_functions: aggregate_functions.into_iter().collect(), }), } } @@ -555,4 +574,24 @@ mod test { ); } } + + #[tokio::test] + async fn test_loading_aggregate_functions() { + let mut server = test_server().await; + server.execute_checked("BEGIN").await.unwrap(); + + let schema = Schema::load(&mut server).await.unwrap(); + assert!(!schema + .aggregate_functions + .contains(&String::from("pgdog_sum"))); + + server + .execute_checked("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)") + .await + .unwrap(); + let schema = Schema::load(&mut server).await.unwrap(); + assert!(schema + .aggregate_functions + .contains(&String::from("pgdog_sum"))); + } } diff --git a/pgdog/src/frontend/router/parser/aggregate.rs b/pgdog/src/frontend/router/parser/aggregate.rs index fd88a4410..46167e066 100644 --- a/pgdog/src/frontend/router/parser/aggregate.rs +++ b/pgdog/src/frontend/router/parser/aggregate.rs @@ -2,7 +2,8 @@ use pg_query::protobuf::Integer; use pg_query::protobuf::{a_const::Val, Node, SelectStmt, String as PgQueryString}; use pg_query::NodeEnum; -use crate::frontend::router::parser::{ExpressionRegistry, Function}; +use super::{ExpressionRegistry, Function}; +use crate::backend::schema::Schema; #[derive(Debug, Clone, PartialEq)] pub struct AggregateTarget { @@ -41,10 +42,11 
@@ pub enum AggregateFunction { StddevSamp, VarPop, VarSamp, + Unrecognized(String), } impl AggregateFunction { - pub fn as_str(&self) -> &'static str { + pub fn as_str(&self) -> &str { match self { AggregateFunction::Count => "count", AggregateFunction::Max => "max", @@ -55,6 +57,7 @@ impl AggregateFunction { AggregateFunction::StddevSamp => "stddev_samp", AggregateFunction::VarPop => "var_pop", AggregateFunction::VarSamp => "var_samp", + AggregateFunction::Unrecognized(s) => &*s, } } } @@ -121,7 +124,7 @@ fn columns_match(group_by_names: &[&String], select_names: &[&String]) -> bool { impl Aggregate { /// Figure out what aggregates are present and which ones PgDog supports. - pub fn parse(stmt: &SelectStmt) -> Self { + pub fn parse(stmt: &SelectStmt, schema: &Schema) -> Self { let mut targets = vec![]; let mut registry = ExpressionRegistry::new(); let group_by = stmt @@ -168,7 +171,13 @@ impl Aggregate { "stddev_pop" => Some(AggregateFunction::StddevPop), "variance" | "var_samp" => Some(AggregateFunction::VarSamp), "var_pop" => Some(AggregateFunction::VarPop), - _ => None, + fname => { + if schema.aggregate_functions.contains(fname) { + Some(AggregateFunction::Unrecognized(fname.to_owned())) + } else { + None + } + } }; if let Some(function) = function { @@ -259,7 +268,7 @@ mod test { } fn parse(stmt: &str) -> Aggregate { - Aggregate::parse(&select(stmt)) + Aggregate::parse(&select(stmt), &Default::default()) } #[test] @@ -455,4 +464,34 @@ mod test { assert_eq!(aggr.group_by(), &[0]); assert_eq!(aggr.targets().len(), 1); } + + #[test] + fn test_unrecognized_aggregate_function_errors() { + let schema_with_agg = Schema::from_parts_with_agg( + Vec::new(), + Default::default(), + vec![String::from("mysum")], + ); + let schema_without_agg = Default::default(); + let query = select("SELECT mysum(lol) FROM example"); + + // A random function that isn't listed as aggregate in the schema + // doesn't require special support on our end, so we should be fine. 
+ let aggregate = Aggregate::parse(&query, &schema_without_agg); + assert_eq!(aggregate.targets, Vec::new()); + + // If we see an aggregate function we don't recognize, we can't + // process the query correctly, since we need to combine the + // results from each shard. + let aggregate = Aggregate::parse(&query, &schema_with_agg); + let funcs = aggregate + .targets + .into_iter() + .map(|t| t.function) + .collect::>(); + assert_eq!( + funcs, + vec![AggregateFunction::Unrecognized("mysum".to_owned())] + ); + } } diff --git a/pgdog/src/frontend/router/parser/query/select.rs b/pgdog/src/frontend/router/parser/query/select.rs index bcfc38370..b1ce45d15 100644 --- a/pgdog/src/frontend/router/parser/query/select.rs +++ b/pgdog/src/frontend/router/parser/query/select.rs @@ -138,7 +138,7 @@ impl QueryParser { } let shard = Self::converge(&shards, ConvergeAlgorithm::default()); - let aggregates = Aggregate::parse(stmt); + let aggregates = Aggregate::parse(stmt, &context.router_context.schema); let limit = LimitClause::new(stmt, context.router_context.bind).limit_offset()?; let distinct = Distinct::new(stmt).distinct()?; diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs index 82e804c15..70b4a30c2 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs @@ -277,7 +277,7 @@ mod tests { fn rewrite(sql: &str) -> (ParseResult, RewriteOutput) { let mut parsed = pg_query::parse(sql).unwrap().protobuf; let stmt = select(&mut parsed); - let aggregate = Aggregate::parse(stmt); + let aggregate = Aggregate::parse(stmt, &Default::default()); let output = AggregatesRewrite.rewrite_select(stmt, &aggregate); (parsed, output) } @@ -302,7 +302,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut 
ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()); assert_eq!(aggregate.targets().len(), 2); assert!(aggregate .targets() @@ -321,7 +321,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()); assert_eq!(aggregate.targets().len(), 3); assert!( aggregate @@ -349,7 +349,7 @@ mod tests { assert_eq!(helper_discount.helper_column, 3); assert!(matches!(helper_discount.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()); assert_eq!(aggregate.targets().len(), 4); assert_eq!( aggregate diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs index c3c73b644..3d25d9df3 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs @@ -2,6 +2,7 @@ pub mod engine; pub mod plan; use super::{Error, RewritePlan, StatementRewrite}; +use crate::backend::schema::Schema; use crate::frontend::router::parser::aggregate::Aggregate; use pg_query::NodeEnum; @@ -10,7 +11,11 @@ pub use plan::{AggregateRewritePlan, HelperKind, HelperMapping, RewriteOutput}; impl StatementRewrite<'_> { /// Add missing COUNT(*) and other helps when using aggregates. 
- pub(super) fn rewrite_aggregates(&mut self, plan: &mut RewritePlan) -> Result<(), Error> { + pub(super) fn rewrite_aggregates( + &mut self, + plan: &mut RewritePlan, + schema: &Schema, + ) -> Result<(), Error> { if self.schema.shards == 1 { return Ok(()); } @@ -27,7 +32,7 @@ impl StatementRewrite<'_> { return Ok(()); }; - let aggregate = Aggregate::parse(select); + let aggregate = Aggregate::parse(select, schema); if aggregate.is_empty() { return Ok(()); } diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs index b03fab315..45eab440a 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs @@ -140,7 +140,7 @@ impl<'a> StatementRewrite<'a> { } })?; - self.rewrite_aggregates(&mut plan)?; + self.rewrite_aggregates(&mut plan, self.db_schema)?; self.limit_offset(&mut plan)?; if self.rewritten { diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 000000000..f2b26a5df --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +channel = "1.96" +components = ["rustfmt", "clippy"] +profile = "default" From e12abe4dfb8fe2653c063e1f1d85bf38bbfe179c Mon Sep 17 00:00:00 2001 From: Sean Griffin Date: Fri, 5 Jun 2026 11:02:31 -0600 Subject: [PATCH 2/2] add python sharded session test --- integration/python/test_session_mode.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/integration/python/test_session_mode.py b/integration/python/test_session_mode.py index 5d4f344cf..85cd473a5 100644 --- a/integration/python/test_session_mode.py +++ b/integration/python/test_session_mode.py @@ -356,3 +356,14 @@ async def test_no_search_path_session_mode(): async with conn.transaction(): await conn.execute("SELECT * FROM shard_0.py_test_no_search_path_session_mode LIMIT 1") + +@pytest.mark.asyncio +async def test_unrecognized_aggregate_function_works_on_schema_based_sharding(): + 
conn = await async_session_conn("shard_0") + await conn.execute("DROP AGGREGATE IF EXISTS pgdog_sum_python(int4)") + await conn.execute("CREATE AGGREGATE pgdog_sum_python (int4) (sfunc = int4_sum, stype = bigint)") + await conn.execute("DROP TABLE IF EXISTS unrecognized_agg_test_python") + await conn.execute("CREATE TABLE unrecognized_agg_test_python (lol int4)") + + async with conn.transaction(): + await conn.execute("SELECT pgdog_sum_python(lol) FROM unrecognized_agg_test_python")