Skip to content

Commit 8805604

Browse files
committed
[CALCITE-7439] Qualify GROUP BY keys for DISTINCT over joins
1 parent c628e68 commit 8805604

File tree

2 files changed

+141
-2
lines changed

2 files changed

+141
-2
lines changed

core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -858,12 +858,21 @@ private List<SqlNode> generateGroupList(Builder builder,
858858
+ aggregate.getGroupSet() + ", just possibly a different order";
859859

860860
final List<SqlNode> groupKeys = new ArrayList<>();
861+
final Join aggregateJoinInput =
862+
aggregate.getInput() instanceof Join ? (Join) aggregate.getInput() : null;
863+
final SqlJoin fromJoin =
864+
builder.select.getFrom() instanceof SqlJoin ? (SqlJoin) builder.select.getFrom() : null;
865+
final int leftFieldCount = aggregateJoinInput == null
866+
? -1
867+
: aggregateJoinInput.getLeft().getRowType().getFieldCount();
861868
for (int key : groupList) {
862-
final SqlNode field = builder.context.field(key);
869+
SqlNode field = builder.context.field(key);
870+
field = maybeQualifyJoinKey(field, key, fromJoin, leftFieldCount);
863871
groupKeys.add(field);
864872
}
865873
for (int key : sortedGroupList) {
866-
final SqlNode field = builder.context.field(key);
874+
SqlNode field =
875+
maybeQualifyJoinKey(builder.context.field(key), key, fromJoin, leftFieldCount);
867876
addSelect(selectList, field, aggregate.getRowType());
868877
}
869878
switch (aggregate.getGroupType()) {
@@ -905,6 +914,31 @@ private List<SqlNode> generateGroupList(Builder builder,
905914
}
906915
}
907916

917+
private SqlNode maybeQualifyJoinKey(SqlNode field, int key,
918+
@Nullable SqlJoin fromJoin, int leftFieldCount) {
919+
if (!(field instanceof SqlIdentifier)
920+
|| ((SqlIdentifier) field).names.size() != 1
921+
|| fromJoin == null) {
922+
return field;
923+
}
924+
925+
final SqlNode side;
926+
if (leftFieldCount < 0) {
927+
if (key != 0) {
928+
return field;
929+
}
930+
side = fromJoin.getLeft();
931+
} else {
932+
side = key < leftFieldCount ? fromJoin.getLeft() : fromJoin.getRight();
933+
}
934+
final String sideAlias = SqlValidatorUtil.alias(side);
935+
if (sideAlias == null) {
936+
return field;
937+
}
938+
939+
return new SqlIdentifier(ImmutableList.of(sideAlias, ((SqlIdentifier) field).getSimple()), POS);
940+
}
941+
908942
private static SqlNode groupItem(List<SqlNode> groupKeys,
909943
ImmutableBitSet groupSet, ImmutableBitSet wholeGroupSet) {
910944
final List<SqlNode> nodes = groupSet.asList().stream()

core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@
127127
import static org.hamcrest.CoreMatchers.is;
128128
import static org.hamcrest.CoreMatchers.notNullValue;
129129
import static org.hamcrest.MatcherAssert.assertThat;
130+
import static org.hamcrest.Matchers.containsString;
130131
import static org.hamcrest.Matchers.hasToString;
132+
import static org.hamcrest.Matchers.not;
131133
import static org.junit.jupiter.api.Assertions.assertFalse;
132134
import static org.junit.jupiter.api.Assertions.assertTrue;
133135

@@ -11887,6 +11889,109 @@ public Sql schema(CalciteAssert.SchemaSpec schemaSpec) {
1188711889
sql(sql).schema(CalciteAssert.SchemaSpec.JDBC_SCOTT).ok(expected);
1188811890
}
1188911891

11892+
/** Test case for
11893+
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
11894+
* RelToSqlConverter emits ambiguous GROUP BY after LEFT JOIN USING with
11895+
* semi-join rewrite.</a>. */
11896+
@Test void testPostgresqlRoundTripDistinctLeftJoinInSubqueryWithSemiJoinRules() {
11897+
final String query = "WITH product_keys AS (\n"
11898+
+ " SELECT p.\"product_id\",\n"
11899+
+ " (SELECT MAX(p3.\"product_id\")\n"
11900+
+ " FROM \"foodmart\".\"product\" p3\n"
11901+
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
11902+
+ " FROM \"foodmart\".\"product\" p\n"
11903+
+ ")\n"
11904+
+ "SELECT DISTINCT pk.\"product_id\"\n"
11905+
+ "FROM product_keys pk\n"
11906+
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
11907+
+ "WHERE pk.\"product_id\" IN (\n"
11908+
+ " SELECT p4.\"product_id\"\n"
11909+
+ " FROM \"foodmart\".\"product\" p4\n"
11910+
+ ")";
11911+
11912+
final RuleSet rules =
11913+
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
11914+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
11915+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
11916+
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
11917+
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
11918+
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
11919+
CoreRules.PROJECT_TO_SEMI_JOIN);
11920+
11921+
final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
11922+
assertThat(generated, containsString("GROUP BY \"t2\".\"product_id\""));
11923+
}
11924+
11925+
/** Test case for
11926+
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
11927+
* RelToSqlConverter should not emit ambiguous GROUP BY after RIGHT JOIN USING
11928+
* with semi-join rewrite.</a>. */
11929+
@Test void testPostgresqlRoundTripDistinctRightJoinInSubqueryWithSemiJoinRules() {
11930+
final String query = "WITH product_keys AS (\n"
11931+
+ " SELECT p.\"product_id\",\n"
11932+
+ " (SELECT MAX(p3.\"product_id\")\n"
11933+
+ " FROM \"foodmart\".\"product\" p3\n"
11934+
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
11935+
+ " FROM \"foodmart\".\"product\" p\n"
11936+
+ ")\n"
11937+
+ "SELECT DISTINCT pk.\"product_id\"\n"
11938+
+ "FROM product_keys pk\n"
11939+
+ "RIGHT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
11940+
+ "WHERE pk.\"product_id\" IN (\n"
11941+
+ " SELECT p4.\"product_id\"\n"
11942+
+ " FROM \"foodmart\".\"product\" p4\n"
11943+
+ ")";
11944+
11945+
final RuleSet rules =
11946+
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
11947+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
11948+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
11949+
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
11950+
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
11951+
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
11952+
CoreRules.PROJECT_TO_SEMI_JOIN);
11953+
11954+
final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
11955+
assertThat(generated, containsString("GROUP BY "));
11956+
assertThat(generated, containsString(".\"product_id\""));
11957+
assertThat(generated, not(containsString("GROUP BY \"product_id\"")));
11958+
}
11959+
11960+
/** Test case for
11961+
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
11962+
* RelToSqlConverter should not emit ambiguous GROUP BY after FULL JOIN USING
11963+
* with semi-join rewrite.</a>. */
11964+
@Test void testPostgresqlRoundTripDistinctFullJoinInSubqueryWithSemiJoinRules() {
11965+
final String query = "WITH product_keys AS (\n"
11966+
+ " SELECT p.\"product_id\",\n"
11967+
+ " (SELECT MAX(p3.\"product_id\")\n"
11968+
+ " FROM \"foodmart\".\"product\" p3\n"
11969+
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
11970+
+ " FROM \"foodmart\".\"product\" p\n"
11971+
+ ")\n"
11972+
+ "SELECT DISTINCT pk.\"product_id\"\n"
11973+
+ "FROM product_keys pk\n"
11974+
+ "FULL JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
11975+
+ "WHERE pk.\"product_id\" IN (\n"
11976+
+ " SELECT p4.\"product_id\"\n"
11977+
+ " FROM \"foodmart\".\"product\" p4\n"
11978+
+ ")";
11979+
11980+
final RuleSet rules =
11981+
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
11982+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
11983+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
11984+
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
11985+
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
11986+
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
11987+
CoreRules.PROJECT_TO_SEMI_JOIN);
11988+
11989+
final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
11990+
assertThat(generated, containsString("GROUP BY "));
11991+
assertThat(generated, containsString(".\"product_id\""));
11992+
assertThat(generated, not(containsString("GROUP BY \"product_id\"")));
11993+
}
11994+
1189011995
@Test void testNotBetween() {
1189111996
Sql f = fixture().withConvertletTable(new SqlRexConvertletTable() {
1189211997
@Override public @Nullable SqlRexConvertlet get(SqlCall call) {

0 commit comments

Comments
 (0)