diff --git a/integration/rust/tests/integration/timestamp_sorting.rs b/integration/rust/tests/integration/timestamp_sorting.rs index 591ea1cb1..446e6a730 100644 --- a/integration/rust/tests/integration/timestamp_sorting.rs +++ b/integration/rust/tests/integration/timestamp_sorting.rs @@ -120,18 +120,17 @@ async fn test_timestamp_sorting_across_shards() { .unwrap(); let rows = sharded_conn - .fetch_all( - "SELECT id, name, updated_at FROM timestamp_test ORDER BY updated_at DESC NULLS LAST", - ) + .fetch_all("SELECT id, name, updated_at FROM timestamp_test ORDER BY updated_at ASC") .await .unwrap(); - let last_rows: Vec> = rows + let last_rows: Vec>> = rows .iter() .rev() .take(2) - .map(|row| row.try_get(2).ok()) - .collect(); + .map(|row| row.try_get(2)) + .collect::, _>>() + .unwrap(); assert!( last_rows.iter().any(|v| v.is_none()), diff --git a/pgdog-postgres-types/src/array.rs b/pgdog-postgres-types/src/array.rs index 402ecb1bd..b7e865d21 100644 --- a/pgdog-postgres-types/src/array.rs +++ b/pgdog-postgres-types/src/array.rs @@ -3,7 +3,7 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; use super::{Error, Format}; use crate::{DataType, Datum}; -#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Hash)] pub struct Array { elements: Vec>, element_oid: i32, diff --git a/pgdog-postgres-types/src/datum.rs b/pgdog-postgres-types/src/datum.rs index 8de553da7..4fc6a9567 100644 --- a/pgdog-postgres-types/src/datum.rs +++ b/pgdog-postgres-types/src/datum.rs @@ -1,3 +1,4 @@ +use std::cmp::{Ordering, PartialOrd}; use std::ops::Add; use bytes::Bytes; @@ -9,7 +10,10 @@ use crate::{ TimestampTz, ToDataRowColumn, }; -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +/// Represents a single piece of data in expression position. Trait +/// implementations for Rust operators match the semantics of that +/// operator/opclass in expression position in PG +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Datum { /// BIGINT. Bigint(i64), @@ -47,6 +51,48 @@ pub enum Datum { Boolean(bool), } +impl PartialOrd for Datum { + fn partial_cmp(&self, other: &Datum) -> Option { + use Datum::*; + + match (self, other) { + (Bigint(a), Bigint(b)) => a.partial_cmp(b), + (Bigint(_), _) | (_, Bigint(_)) => None, + (Integer(a), Integer(b)) => a.partial_cmp(b), + (Integer(_), _) | (_, Integer(_)) => None, + (SmallInt(a), SmallInt(b)) => a.partial_cmp(b), + (SmallInt(_), _) | (_, SmallInt(_)) => None, + (Interval(a), Interval(b)) => a.partial_cmp(b), + (Interval(_), _) | (_, Interval(_)) => None, + (Text(a), Text(b)) => a.partial_cmp(b), + (Text(_), _) | (_, Text(_)) => None, + (Timestamp(a), Timestamp(b)) => a.partial_cmp(b), + (Timestamp(_), _) | (_, Timestamp(_)) => None, + (TimestampTz(a), TimestampTz(b)) => a.partial_cmp(b), + (TimestampTz(_), _) | (_, TimestampTz(_)) => None, + (Uuid(a), Uuid(b)) => a.partial_cmp(b), + (Uuid(_), _) | (_, Uuid(_)) => None, + (Numeric(a), Numeric(b)) => a.partial_cmp(b), + (Numeric(_), _) | (_, Numeric(_)) => None, + (Float(a), Float(b)) => a.partial_cmp(b), + (Float(_), _) | (_, Float(_)) => None, + (Double(a), Double(b)) => a.partial_cmp(b), + (Double(_), _) | (_, Double(_)) => None, + (Vector(a), Vector(b)) => a.partial_cmp(b), + (Vector(_), _) | (_, Vector(_)) => None, + (Oid(a), Oid(b)) => a.partial_cmp(b), + (Oid(_), _) | (_, Oid(_)) => None, + (Array(a), Array(b)) => a.partial_cmp(b), + (Array(_), _) | (_, Array(_)) => None, + (Unknown(a), Unknown(b)) => a.partial_cmp(b), + (Unknown(_), _) | (_, Unknown(_)) => None, + (Boolean(a), Boolean(b)) => a.partial_cmp(b), + (Boolean(_), _) | (_, Boolean(_)) => None, + (Null, _) => None, + } + } +} + impl ToDataRowColumn for Datum { fn to_data_row_column(&self) -> Data { use Datum::*; diff --git a/pgdog/src/backend/pool/connection/aggregate.rs b/pgdog/src/backend/pool/connection/aggregate.rs index 1e070104f..d1a42e2c2 100644 --- a/pgdog/src/backend/pool/connection/aggregate.rs +++ b/pgdog/src/backend/pool/connection/aggregate.rs @@ -21,7 +21,7 @@ use rust_decimal::Decimal; use super::Error; /// GROUP BY -#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug)] +#[derive(Hash, PartialEq, Eq, Debug)] struct Grouping { columns: Vec<(usize, Datum)>, } diff --git a/pgdog/src/backend/pool/connection/buffer.rs b/pgdog/src/backend/pool/connection/buffer.rs index 20bf1a373..dcb074bc9 100644 --- a/pgdog/src/backend/pool/connection/buffer.rs +++ b/pgdog/src/backend/pool/connection/buffer.rs @@ -16,6 +16,8 @@ use crate::{ }, }; +use pgdog_postgres_types::Datum; + use super::Aggregates; /// Sort and aggregate rows received from multiple shards. @@ -80,48 +82,52 @@ impl Buffer { // Sort rows. let order_by = move |a: &DataRow, b: &DataRow| -> Ordering { - for col in cols.iter() { - let index = col.index(); - let asc = col.asc(); - let index = if let Some(index) = index { - index - } else { - continue; - }; - let left = a.get_column(index, decoder); - let right = b.get_column(index, decoder); - - let ordering = match (left, right) { - (Ok(Some(left)), Ok(Some(right))) => { - // Handle the special vector case. - if let OrderBy::AscVectorL2(_, vector) = col { - let left: Option = left.value.try_into().ok(); - let right: Option = right.value.try_into().ok(); - - if let (Some(left), Some(right)) = (left, right) { - let left = left.distance_l2(vector); - let right = right.distance_l2(vector); - - left.partial_cmp(&right) + cols.iter() + .filter_map(|col| { + let index = col.index(); + let asc = col.asc(); + let Some(index) = index else { + return None; + }; + let left = a.get_column(index, decoder); + let right = b.get_column(index, decoder); + + match (left, right) { + (Ok(Some(left)), Ok(Some(right))) => { + // Handle the special vector case. + if let OrderBy::AscVectorL2(_, vector) = col { + let left: Option = left.value.try_into().ok(); + let right: Option = right.value.try_into().ok(); + + if let (Some(left), Some(right)) = (left, right) { + let left = left.distance_l2(vector); + let right = right.distance_l2(vector); + + left.partial_cmp(&right) + } else { + Some(Ordering::Equal) + } } else { - Some(Ordering::Equal) + // FIXME(sage): We don't handle ASC NULLS FIRST or + // DESC NULLS LAST we should either error or add + // support rather than silently do the wrong sorting + match (&left.value, &right.value, asc) { + (Datum::Null, Datum::Null, _) => Some(Ordering::Equal), + (Datum::Null, _, true) => Some(Ordering::Greater), + (_, Datum::Null, true) => Some(Ordering::Less), + (Datum::Null, _, false) => Some(Ordering::Less), + (_, Datum::Null, false) => Some(Ordering::Greater), + (a, b, true) => a.partial_cmp(b), + (a, b, false) => b.partial_cmp(a), + } } - } else if asc { - left.value.partial_cmp(&right.value) - } else { - right.value.partial_cmp(&left.value) } - } - - _ => Some(Ordering::Equal), - }; - if ordering != Some(Ordering::Equal) { - return ordering.unwrap_or(Ordering::Equal); - } - } - - Ordering::Equal + _ => Some(Ordering::Equal), + } + }) + .reduce(Ordering::then) + .unwrap_or(Ordering::Equal) }; self.buffer.make_contiguous().sort_by(order_by);