Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration/ci/cache-key.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ REPO_ROOT="$( cd "${SCRIPT_DIR}/../.." && pwd )"
cd "$REPO_ROOT"

files=(
rust-toolchain.toml
Cargo.lock
Cargo.toml
.cargo/config.toml
Expand Down
6 changes: 3 additions & 3 deletions integration/go/go_pgx/sharded_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ func TestShardedTwoPc(t *testing.T) {
assert.NoError(t, err)
}

// +4 is for schema sync
// +5 is for schema sync
assertShowField(t, "SHOW STATS", "total_xact_2pc_count", 200, "pgdog_2pc", "pgdog_sharded", 0, "primary")
assertShowField(t, "SHOW STATS", "total_xact_2pc_count", 200, "pgdog_2pc", "pgdog_sharded", 1, "primary")
assertShowField(t, "SHOW STATS", "total_xact_count", 401+4, "pgdog_2pc", "pgdog_sharded", 0, "primary") // PREPARE, COMMIT for each transaction + TRUNCATE
assertShowField(t, "SHOW STATS", "total_xact_count", 401+4, "pgdog_2pc", "pgdog_sharded", 1, "primary")
assertShowField(t, "SHOW STATS", "total_xact_count", 401+5, "pgdog_2pc", "pgdog_sharded", 0, "primary") // PREPARE, COMMIT for each transaction + TRUNCATE
assertShowField(t, "SHOW STATS", "total_xact_count", 401+5, "pgdog_2pc", "pgdog_sharded", 1, "primary")

for i := range 200 {
rows, err := conn.Query(
Expand Down
11 changes: 11 additions & 0 deletions integration/python/test_session_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,14 @@ async def test_no_search_path_session_mode():

async with conn.transaction():
await conn.execute("SELECT * FROM shard_0.py_test_no_search_path_session_mode LIMIT 1")

@pytest.mark.asyncio
async def test_unrecognized_aggregate_function_works_on_schema_based_sharding():
conn = await async_session_conn("shard_0")
await conn.execute("DROP AGGREGATE IF EXISTS pgdog_sum_python(int4)")
await conn.execute("CREATE AGGREGATE pgdog_sum_python (int4) (sfunc = int4_sum, stype = bigint)")
await conn.execute("DROP TABLE IF EXISTS unrecognized_agg_test_python")
await conn.execute("CREATE TABLE unrecognized_agg_test_python (lol int4)")

async with conn.transaction():
await conn.execute("SELECT pgdog_sum_python(lol) FROM unrecognized_agg_test_python")
1 change: 1 addition & 0 deletions integration/rust/tests/integration/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ pub mod timestamp_sorting;
pub mod tls_enforced;
pub mod tls_reload;
pub mod transaction_state;
pub mod unrecognized_aggregate;
46 changes: 46 additions & 0 deletions integration/rust/tests/integration/unrecognized_aggregate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use rust::setup::{admin_sqlx, connections_sqlx};
use sqlx::Executor;
use std::assert_matches;

async fn define_custom_aggregate_fn() {
for connection in connections_sqlx().await {
connection
.execute("DROP AGGREGATE IF EXISTS pgdog_sum (int4)")
.await
.unwrap();
connection
.execute("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)")
.await
.unwrap();
connection
.execute("DROP TABLE IF EXISTS unrecognized_agg_test")
.await
.unwrap();
connection
.execute("CREATE TABLE unrecognized_agg_test (lol int4, customer_id bigint)")
.await
.unwrap();
}
admin_sqlx().await.execute("RELOAD").await.unwrap();
}

#[tokio::test]
async fn unrecognized_aggregate_function_errors_only_on_cross_shard_queries() {
define_custom_aggregate_fn().await;
let mut connections = connections_sqlx().await;
let sharded = connections.pop().unwrap();
let unsharded = connections.pop().unwrap();

let unsharded_query = unsharded
.fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test")
.await;
assert_matches!(unsharded_query, Ok(_));

let sharded_query = sharded
.fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test")
.await;
let err = sharded_query
.err()
.expect("unrecognized aggregate executed successfully");
assert!(err.to_string().contains("pgdog_sum is not yet supported"));
}
4 changes: 2 additions & 2 deletions mise.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[tools]
"rust" = { version = "1.91.0" }
"rust" = { version = "1.96.0" }
"cargo:cargo-nextest" = "latest"
"cargo:cargo-watch" = "latest"
"cargo:cargo-watch" = "latest"
8 changes: 7 additions & 1 deletion pgdog-stats/src/schema.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, hash::Hash, ops::Deref, sync::Arc};
use std::{
collections::{HashMap, HashSet},
hash::Hash,
ops::Deref,
sync::Arc,
};

/// Schema name -> Table name -> Relation
pub type Relations = HashMap<String, HashMap<String, Relation>>;
Expand Down Expand Up @@ -134,6 +139,7 @@ impl Relation {
pub struct SchemaInner {
pub search_path: Vec<String>,
pub relations: Relations,
pub aggregate_functions: HashSet<String>,
}

impl Hash for SchemaInner {
Expand Down
10 changes: 9 additions & 1 deletion pgdog/src/backend/pool/connection/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ impl<'a> Accumulator<'a> {
}
}
}
AggregateFunction::Unrecognized(..) => return Ok(false),
}

Ok(true)
Expand Down Expand Up @@ -512,6 +513,13 @@ impl<'a> Aggregates<'a> {
false
})
}
AggregateFunction::Unrecognized(fname) => {
unsupported.get_or_insert(UnsupportedAggregate {
function: fname.clone(),
reason: format!("{fname} is not yet supported"),
});
false
}
_ => true,
}
});
Expand Down Expand Up @@ -803,7 +811,7 @@ mod test {
}

fn parse(stmt: &str) -> Aggregate {
Aggregate::parse(&select(stmt))
Aggregate::parse(&select(stmt), &Default::default())
}

#[test]
Expand Down
39 changes: 39 additions & 0 deletions pgdog/src/backend/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,18 @@ impl Schema {
.map(|p| p.trim().replace("\"", ""))
.collect();

let aggregate_functions = server
.fetch_all::<String>(
"SELECT DISTINCT proname FROM pg_proc INNER JOIN pg_aggregate ON oid = aggfnoid",
)
.await?
.into_iter()
.collect();

let inner = SchemaInner {
search_path,
relations,
aggregate_functions,
};

Ok(Self {
Expand All @@ -79,6 +88,15 @@ impl Schema {
pub(crate) fn from_parts(
search_path: Vec<String>,
relations: HashMap<(String, String), Relation>,
) -> Self {
Self::from_parts_with_agg(search_path, relations, Vec::new())
}

#[cfg(test)]
pub(crate) fn from_parts_with_agg(
search_path: Vec<String>,
relations: HashMap<(String, String), Relation>,
aggregate_functions: Vec<String>,
) -> Self {
let mut nested: StatsRelations = HashMap::new();
for ((schema, name), relation) in relations {
Expand All @@ -91,6 +109,7 @@ impl Schema {
inner: StatsSchema::new(SchemaInner {
search_path,
relations: nested,
aggregate_functions: aggregate_functions.into_iter().collect(),
}),
}
}
Expand Down Expand Up @@ -555,4 +574,24 @@ mod test {
);
}
}

#[tokio::test]
async fn test_loading_aggregate_functions() {
let mut server = test_server().await;
server.execute_checked("BEGIN").await.unwrap();

let schema = Schema::load(&mut server).await.unwrap();
assert!(!schema
.aggregate_functions
.contains(&String::from("pgdog_sum")));

server
.execute_checked("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)")
.await
.unwrap();
let schema = Schema::load(&mut server).await.unwrap();
assert!(schema
.aggregate_functions
.contains(&String::from("pgdog_sum")));
}
}
49 changes: 44 additions & 5 deletions pgdog/src/frontend/router/parser/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use pg_query::protobuf::Integer;
use pg_query::protobuf::{a_const::Val, Node, SelectStmt, String as PgQueryString};
use pg_query::NodeEnum;

use crate::frontend::router::parser::{ExpressionRegistry, Function};
use super::{ExpressionRegistry, Function};
use crate::backend::schema::Schema;

#[derive(Debug, Clone, PartialEq)]
pub struct AggregateTarget {
Expand Down Expand Up @@ -41,10 +42,11 @@ pub enum AggregateFunction {
StddevSamp,
VarPop,
VarSamp,
Unrecognized(String),
}

impl AggregateFunction {
pub fn as_str(&self) -> &'static str {
pub fn as_str(&self) -> &str {
match self {
AggregateFunction::Count => "count",
AggregateFunction::Max => "max",
Expand All @@ -55,6 +57,7 @@ impl AggregateFunction {
AggregateFunction::StddevSamp => "stddev_samp",
AggregateFunction::VarPop => "var_pop",
AggregateFunction::VarSamp => "var_samp",
AggregateFunction::Unrecognized(s) => &*s,
}
}
}
Expand Down Expand Up @@ -121,7 +124,7 @@ fn columns_match(group_by_names: &[&String], select_names: &[&String]) -> bool {

impl Aggregate {
/// Figure out what aggregates are present and which ones PgDog supports.
pub fn parse(stmt: &SelectStmt) -> Self {
pub fn parse(stmt: &SelectStmt, schema: &Schema) -> Self {
let mut targets = vec![];
let mut registry = ExpressionRegistry::new();
let group_by = stmt
Expand Down Expand Up @@ -168,7 +171,13 @@ impl Aggregate {
"stddev_pop" => Some(AggregateFunction::StddevPop),
"variance" | "var_samp" => Some(AggregateFunction::VarSamp),
"var_pop" => Some(AggregateFunction::VarPop),
_ => None,
fname => {
if schema.aggregate_functions.contains(fname) {
Comment thread
sgrif marked this conversation as resolved.
Some(AggregateFunction::Unrecognized(fname.to_owned()))
} else {
None
}
}
};

if let Some(function) = function {
Expand Down Expand Up @@ -259,7 +268,7 @@ mod test {
}

fn parse(stmt: &str) -> Aggregate {
Aggregate::parse(&select(stmt))
Aggregate::parse(&select(stmt), &Default::default())
}

#[test]
Expand Down Expand Up @@ -455,4 +464,34 @@ mod test {
assert_eq!(aggr.group_by(), &[0]);
assert_eq!(aggr.targets().len(), 1);
}

#[test]
fn test_unrecognized_aggregate_function_errors() {
let schema_with_agg = Schema::from_parts_with_agg(
Vec::new(),
Default::default(),
vec![String::from("mysum")],
);
let schema_without_agg = Default::default();
let query = select("SELECT mysum(lol) FROM example");

// A random function that isn't listed as aggregate in the schema
// doesn't require special support on our end, so we should be fine.
let aggregate = Aggregate::parse(&query, &schema_without_agg);
assert_eq!(aggregate.targets, Vec::new());

// If we see an aggregate function we don't recognize, we can't
// process the query correctly, since we need to combine the
// results from each shard.
let aggregate = Aggregate::parse(&query, &schema_with_agg);
let funcs = aggregate
.targets
.into_iter()
.map(|t| t.function)
.collect::<Vec<_>>();
assert_eq!(
funcs,
vec![AggregateFunction::Unrecognized("mysum".to_owned())]
);
}
}
2 changes: 1 addition & 1 deletion pgdog/src/frontend/router/parser/query/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl QueryParser {
}

let shard = Self::converge(&shards, ConvergeAlgorithm::default());
let aggregates = Aggregate::parse(stmt);
let aggregates = Aggregate::parse(stmt, &context.router_context.schema);
let limit = LimitClause::new(stmt, context.router_context.bind).limit_offset()?;
let distinct = Distinct::new(stmt).distinct()?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ mod tests {
fn rewrite(sql: &str) -> (ParseResult, RewriteOutput) {
let mut parsed = pg_query::parse(sql).unwrap().protobuf;
let stmt = select(&mut parsed);
let aggregate = Aggregate::parse(stmt);
let aggregate = Aggregate::parse(stmt, &Default::default());
let output = AggregatesRewrite.rewrite_select(stmt, &aggregate);
(parsed, output)
}
Expand All @@ -302,7 +302,7 @@ mod tests {
assert!(!helper.distinct);
assert!(matches!(helper.kind, HelperKind::Count));

let aggregate = Aggregate::parse(select(&mut ast));
let aggregate = Aggregate::parse(select(&mut ast), &Default::default());
assert_eq!(aggregate.targets().len(), 2);
assert!(aggregate
.targets()
Expand All @@ -321,7 +321,7 @@ mod tests {
assert!(!helper.distinct);
assert!(matches!(helper.kind, HelperKind::Count));

let aggregate = Aggregate::parse(select(&mut ast));
let aggregate = Aggregate::parse(select(&mut ast), &Default::default());
assert_eq!(aggregate.targets().len(), 3);
assert!(
aggregate
Expand Down Expand Up @@ -349,7 +349,7 @@ mod tests {
assert_eq!(helper_discount.helper_column, 3);
assert!(matches!(helper_discount.kind, HelperKind::Count));

let aggregate = Aggregate::parse(select(&mut ast));
let aggregate = Aggregate::parse(select(&mut ast), &Default::default());
assert_eq!(aggregate.targets().len(), 4);
assert_eq!(
aggregate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod engine;
pub mod plan;

use super::{Error, RewritePlan, StatementRewrite};
use crate::backend::schema::Schema;
use crate::frontend::router::parser::aggregate::Aggregate;
use pg_query::NodeEnum;

Expand All @@ -10,7 +11,11 @@ pub use plan::{AggregateRewritePlan, HelperKind, HelperMapping, RewriteOutput};

impl StatementRewrite<'_> {
/// Add missing COUNT(*) and other helps when using aggregates.
pub(super) fn rewrite_aggregates(&mut self, plan: &mut RewritePlan) -> Result<(), Error> {
pub(super) fn rewrite_aggregates(
&mut self,
plan: &mut RewritePlan,
schema: &Schema,
) -> Result<(), Error> {
if self.schema.shards == 1 {
return Ok(());
}
Expand All @@ -27,7 +32,7 @@ impl StatementRewrite<'_> {
return Ok(());
};

let aggregate = Aggregate::parse(select);
let aggregate = Aggregate::parse(select, schema);
if aggregate.is_empty() {
return Ok(());
}
Expand Down
Loading
Loading