Skip to content

Commit 0558ef1

Browse files
committed
[CALCITE-7439] Qualify GROUP BY keys for DISTINCT over joins
1 parent c628e68 commit 0558ef1

File tree

2 files changed

+198
-3
lines changed

2 files changed

+198
-3
lines changed

core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -858,12 +858,21 @@ private List<SqlNode> generateGroupList(Builder builder,
858858
+ aggregate.getGroupSet() + ", just possibly a different order";
859859

860860
final List<SqlNode> groupKeys = new ArrayList<>();
861+
final Join aggregateJoinInput =
862+
aggregate.getInput() instanceof Join ? (Join) aggregate.getInput() : null;
863+
final SqlJoin fromJoin =
864+
builder.select.getFrom() instanceof SqlJoin ? (SqlJoin) builder.select.getFrom() : null;
865+
final int leftFieldCount = aggregateJoinInput == null
866+
? -1
867+
: aggregateJoinInput.getLeft().getRowType().getFieldCount();
861868
for (int key : groupList) {
862-
final SqlNode field = builder.context.field(key);
869+
SqlNode field = builder.context.field(key);
870+
field = maybeQualifyJoinKey(field, key, fromJoin, leftFieldCount);
863871
groupKeys.add(field);
864872
}
865873
for (int key : sortedGroupList) {
866-
final SqlNode field = builder.context.field(key);
874+
SqlNode field =
875+
maybeQualifyJoinKey(builder.context.field(key), key, fromJoin, leftFieldCount);
867876
addSelect(selectList, field, aggregate.getRowType());
868877
}
869878
switch (aggregate.getGroupType()) {
@@ -880,7 +889,7 @@ private List<SqlNode> generateGroupList(Builder builder,
880889
final List<Integer> rollupBits = Aggregate.Group.getRollup(aggregate.groupSets);
881890
final List<SqlNode> rollupKeys = rollupBits
882891
.stream()
883-
.map(bit -> builder.context.field(bit))
892+
.map(bit -> maybeQualifyJoinKey(builder.context.field(bit), bit, fromJoin, leftFieldCount))
884893
.collect(Collectors.toList());
885894
return ImmutableList.of(
886895
SqlStdOperatorTable.ROLLUP.createCall(SqlParserPos.ZERO, rollupKeys));
@@ -905,6 +914,56 @@ private List<SqlNode> generateGroupList(Builder builder,
905914
}
906915
}
907916

917+
private SqlNode maybeQualifyJoinKey(SqlNode field, int key,
918+
@Nullable SqlJoin fromJoin, int leftFieldCount) {
919+
if (!(field instanceof SqlIdentifier)
920+
|| ((SqlIdentifier) field).names.size() != 1
921+
|| fromJoin == null) {
922+
return field;
923+
}
924+
925+
final String fieldName = ((SqlIdentifier) field).getSimple();
926+
final String leftAlias = SqlValidatorUtil.alias(fromJoin.getLeft());
927+
final String rightAlias = SqlValidatorUtil.alias(fromJoin.getRight());
928+
if (leftFieldCount < 0) {
929+
if (key != 0) {
930+
return field;
931+
}
932+
switch (fromJoin.getJoinType()) {
933+
case RIGHT:
934+
return qualifyJoinField(rightAlias, fieldName, field);
935+
case FULL:
936+
if (leftAlias != null && rightAlias != null) {
937+
return SqlStdOperatorTable.COALESCE.createCall(POS,
938+
new SqlIdentifier(ImmutableList.of(leftAlias, fieldName), POS),
939+
new SqlIdentifier(ImmutableList.of(rightAlias, fieldName), POS));
940+
}
941+
return qualifyJoinField(leftAlias != null ? leftAlias : rightAlias, fieldName, field);
942+
case LEFT:
943+
case LEFT_SEMI_JOIN:
944+
case LEFT_ANTI_JOIN:
945+
case INNER:
946+
case CROSS:
947+
case COMMA:
948+
case ASOF:
949+
case LEFT_ASOF:
950+
default:
951+
return qualifyJoinField(leftAlias, fieldName, field);
952+
}
953+
} else {
954+
final SqlNode side = key < leftFieldCount ? fromJoin.getLeft() : fromJoin.getRight();
955+
final String sideAlias = SqlValidatorUtil.alias(side);
956+
return qualifyJoinField(sideAlias, fieldName, field);
957+
}
958+
}
959+
960+
private static SqlNode qualifyJoinField(@Nullable String alias, String fieldName, SqlNode fallback) {
961+
if (alias == null) {
962+
return fallback;
963+
}
964+
return new SqlIdentifier(ImmutableList.of(alias, fieldName), POS);
965+
}
966+
908967
private static SqlNode groupItem(List<SqlNode> groupKeys,
909968
ImmutableBitSet groupSet, ImmutableBitSet wholeGroupSet) {
910969
final List<SqlNode> nodes = groupSet.asList().stream()

core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,12 @@
125125
import static org.apache.calcite.test.Matchers.isLinux;
126126

127127
import static org.hamcrest.CoreMatchers.is;
128+
import static org.hamcrest.CoreMatchers.not;
128129
import static org.hamcrest.CoreMatchers.notNullValue;
129130
import static org.hamcrest.MatcherAssert.assertThat;
131+
import static org.hamcrest.Matchers.containsString;
130132
import static org.hamcrest.Matchers.hasToString;
133+
import static org.hamcrest.Matchers.matchesPattern;
131134
import static org.junit.jupiter.api.Assertions.assertFalse;
132135
import static org.junit.jupiter.api.Assertions.assertTrue;
133136

@@ -11887,6 +11890,139 @@ public Sql schema(CalciteAssert.SchemaSpec schemaSpec) {
1188711890
sql(sql).schema(CalciteAssert.SchemaSpec.JDBC_SCOTT).ok(expected);
1188811891
}
1188911892

11893+
/** Test case for
11894+
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
11895+
* RelToSqlConverter emits ambiguous GROUP BY after LEFT JOIN USING with
11896+
* semi-join rewrite.</a>. */
11897+
@Test void testPostgresqlRoundTripDistinctLeftJoinInSubqueryWithSemiJoinRules() {
11898+
final String query = "WITH product_keys AS (\n"
11899+
+ " SELECT p.\"product_id\",\n"
11900+
+ " (SELECT MAX(p3.\"product_id\")\n"
11901+
+ " FROM \"foodmart\".\"product\" p3\n"
11902+
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
11903+
+ " FROM \"foodmart\".\"product\" p\n"
11904+
+ ")\n"
11905+
+ "SELECT DISTINCT pk.\"product_id\"\n"
11906+
+ "FROM product_keys pk\n"
11907+
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
11908+
+ "WHERE pk.\"product_id\" IN (\n"
11909+
+ " SELECT p4.\"product_id\"\n"
11910+
+ " FROM \"foodmart\".\"product\" p4\n"
11911+
+ ")";
11912+
11913+
final RuleSet rules =
11914+
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
11915+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
11916+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
11917+
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
11918+
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
11919+
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
11920+
CoreRules.PROJECT_TO_SEMI_JOIN);
11921+
11922+
final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
11923+
assertThat(generated, containsString("GROUP BY \"t2\".\"product_id\""));
11924+
}
11925+
11926+
/** Test case for
11927+
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
11928+
* RelToSqlConverter should not emit ambiguous GROUP BY after RIGHT JOIN USING
11929+
* with semi-join rewrite.</a>. */
11930+
@Test void testPostgresqlRoundTripDistinctRightJoinInSubqueryWithSemiJoinRules() {
11931+
final String query = "WITH product_keys AS (\n"
11932+
+ " SELECT p.\"product_id\",\n"
11933+
+ " (SELECT MAX(p3.\"product_id\")\n"
11934+
+ " FROM \"foodmart\".\"product\" p3\n"
11935+
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
11936+
+ " FROM \"foodmart\".\"product\" p\n"
11937+
+ ")\n"
11938+
+ "SELECT DISTINCT pk.\"product_id\"\n"
11939+
+ "FROM product_keys pk\n"
11940+
+ "RIGHT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
11941+
+ "WHERE pk.\"product_id\" IN (\n"
11942+
+ " SELECT p4.\"product_id\"\n"
11943+
+ " FROM \"foodmart\".\"product\" p4\n"
11944+
+ ")";
11945+
11946+
final RuleSet rules =
11947+
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
11948+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
11949+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
11950+
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
11951+
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
11952+
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
11953+
CoreRules.PROJECT_TO_SEMI_JOIN);
11954+
11955+
final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
11956+
assertThat(generated, containsString("GROUP BY "));
11957+
assertThat(generated, containsString(".\"product_id\""));
11958+
assertThat(generated,
11959+
matchesPattern("(?s).*GROUP BY\\s+\"[^\"]+\"\\.\"product_id\".*"));
11960+
assertThat(generated, not(containsString("GROUP BY \"product_id\"")));
11961+
}
11962+
11963+
/** Test case for
11964+
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
11965+
* RelToSqlConverter should not emit ambiguous GROUP BY after FULL JOIN USING
11966+
* with semi-join rewrite.</a>. */
11967+
@Test void testPostgresqlRoundTripDistinctFullJoinInSubqueryWithSemiJoinRules() {
11968+
final String query = "WITH product_keys AS (\n"
11969+
+ " SELECT p.\"product_id\",\n"
11970+
+ " (SELECT MAX(p3.\"product_id\")\n"
11971+
+ " FROM \"foodmart\".\"product\" p3\n"
11972+
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
11973+
+ " FROM \"foodmart\".\"product\" p\n"
11974+
+ ")\n"
11975+
+ "SELECT DISTINCT pk.\"product_id\"\n"
11976+
+ "FROM product_keys pk\n"
11977+
+ "FULL JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
11978+
+ "WHERE pk.\"product_id\" IN (\n"
11979+
+ " SELECT p4.\"product_id\"\n"
11980+
+ " FROM \"foodmart\".\"product\" p4\n"
11981+
+ ")";
11982+
11983+
final RuleSet rules =
11984+
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
11985+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
11986+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
11987+
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
11988+
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
11989+
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
11990+
CoreRules.PROJECT_TO_SEMI_JOIN);
11991+
11992+
final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
11993+
assertThat(generated, containsString("GROUP BY "));
11994+
assertThat(generated, containsString("GROUP BY COALESCE("));
11995+
assertThat(generated,
11996+
matchesPattern("(?s).*GROUP BY\\s+COALESCE\\(\"[^\"]+\"\\.\"product_id\",\\s*"
11997+
+ "\"[^\"]+\"\\.\"product_id\"\\).*"));
11998+
assertThat(generated, not(containsString("GROUP BY \"product_id\"")));
11999+
}
12000+
12001+
@Test void testPostgresqlRoundTripRollupJoinUsingQualifiesGroupKey() {
12002+
final String query = "SELECT \"product_id\", COUNT(*)\n"
12003+
+ "FROM \"foodmart\".\"product\" p1\n"
12004+
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
12005+
+ "GROUP BY ROLLUP(\"product_id\")";
12006+
12007+
final String generated = sql(query).withPostgresql().exec();
12008+
assertThat(generated,
12009+
matchesPattern("(?s).*GROUP BY\\s+ROLLUP\\(\"[^\"]+\"\\.\"product_id\"\\).*"));
12010+
assertThat(generated, not(containsString("GROUP BY ROLLUP(\"product_id\")")));
12011+
}
12012+
12013+
@Test void testPostgresqlRoundTripSingletonCubeJoinUsingQualifiesGroupKey() {
12014+
final String query = "SELECT \"product_id\", COUNT(*)\n"
12015+
+ "FROM \"foodmart\".\"product\" p1\n"
12016+
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
12017+
+ "GROUP BY CUBE(\"product_id\")";
12018+
12019+
final String generated = sql(query).withPostgresql().exec();
12020+
assertThat(generated,
12021+
matchesPattern("(?s).*GROUP BY\\s+(?:CUBE|ROLLUP)\\(\"[^\"]+\"\\.\"product_id\"\\).*"));
12022+
assertThat(generated, not(containsString("GROUP BY CUBE(\"product_id\")")));
12023+
assertThat(generated, not(containsString("GROUP BY ROLLUP(\"product_id\")")));
12024+
}
12025+
1189012026
@Test void testNotBetween() {
1189112027
Sql f = fixture().withConvertletTable(new SqlRexConvertletTable() {
1189212028
@Override public @Nullable SqlRexConvertlet get(SqlCall call) {

0 commit comments

Comments
 (0)