diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java index 0404fa7bc51..a88e9d8a3c1 100644 --- a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java +++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java @@ -115,6 +115,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.SortedSet; import java.util.TreeSet; @@ -858,12 +859,21 @@ private List generateGroupList(Builder builder, + aggregate.getGroupSet() + ", just possibly a different order"; final List groupKeys = new ArrayList<>(); + final Join aggregateJoinInput = + aggregate.getInput() instanceof Join ? (Join) aggregate.getInput() : null; + final SqlJoin fromJoin = + builder.select.getFrom() instanceof SqlJoin ? (SqlJoin) builder.select.getFrom() : null; + final int leftFieldCount = aggregateJoinInput == null + ? 
-1 + : aggregateJoinInput.getLeft().getRowType().getFieldCount(); for (int key : groupList) { - final SqlNode field = builder.context.field(key); + SqlNode field = builder.context.field(key); + field = maybeQualifyJoinKey(field, key, fromJoin, leftFieldCount); groupKeys.add(field); } for (int key : sortedGroupList) { - final SqlNode field = builder.context.field(key); + SqlNode field = + maybeQualifyJoinKey(builder.context.field(key), key, fromJoin, leftFieldCount); addSelect(selectList, field, aggregate.getRowType()); } switch (aggregate.getGroupType()) { @@ -880,7 +890,8 @@ private List generateGroupList(Builder builder, final List rollupBits = Aggregate.Group.getRollup(aggregate.groupSets); final List rollupKeys = rollupBits .stream() - .map(bit -> builder.context.field(bit)) + .map(bit -> + maybeQualifyJoinKey(builder.context.field(bit), bit, fromJoin, leftFieldCount)) .collect(Collectors.toList()); return ImmutableList.of( SqlStdOperatorTable.ROLLUP.createCall(SqlParserPos.ZERO, rollupKeys)); @@ -905,6 +916,115 @@ private List generateGroupList(Builder builder, } } + private SqlNode maybeQualifyJoinKey(SqlNode field, int key, + @Nullable SqlJoin fromJoin, int leftFieldCount) { + if (!isSimpleIdentifier(field) || fromJoin == null) { + return field; + } + + final String fieldName = ((SqlIdentifier) field).getSimple(); + if (leftFieldCount >= 0) { + final SqlNode side = key < leftFieldCount ? 
fromJoin.getLeft() : fromJoin.getRight(); + return qualifyJoinField(SqlValidatorUtil.alias(side), fieldName, field); + } + return maybeQualifyJoinKeyWithoutInputJoin(field, fromJoin, fieldName); + } + + private static boolean isSimpleIdentifier(SqlNode node) { + return node instanceof SqlIdentifier + && ((SqlIdentifier) node).names.size() == 1; + } + + private SqlNode maybeQualifyJoinKeyWithoutInputJoin(SqlNode field, + SqlJoin fromJoin, String fieldName) { + final String leftAlias = SqlValidatorUtil.alias(fromJoin.getLeft()); + final String rightAlias = SqlValidatorUtil.alias(fromJoin.getRight()); + if (!isMergedJoinKey(fromJoin, leftAlias, rightAlias, fieldName)) { + return field; + } + switch (fromJoin.getJoinType()) { + case RIGHT: + return qualifyJoinField(rightAlias, fieldName, field); + case FULL: + if (leftAlias != null && rightAlias != null) { + return SqlStdOperatorTable.COALESCE.createCall(POS, + new SqlIdentifier(ImmutableList.of(leftAlias, fieldName), POS), + new SqlIdentifier(ImmutableList.of(rightAlias, fieldName), POS)); + } + return qualifyJoinField(leftAlias != null ? 
leftAlias : rightAlias, fieldName, field); + case LEFT: + case LEFT_SEMI_JOIN: + case LEFT_ANTI_JOIN: + case INNER: + case CROSS: + case COMMA: + case ASOF: + case LEFT_ASOF: + default: + return qualifyJoinField(leftAlias, fieldName, field); + } + } + + private static boolean isMergedJoinKey(SqlJoin fromJoin, + @Nullable String leftAlias, @Nullable String rightAlias, String fieldName) { + final @Nullable SqlNode condition = fromJoin.getCondition(); + if (fromJoin.getConditionType() == JoinConditionType.USING) { + if (!(condition instanceof SqlNodeList)) { + return false; + } + for (SqlNode node : ((SqlNodeList) condition).getList()) { + if (node != null + && isSimpleIdentifier(node) + && fieldName.equals(((SqlIdentifier) node).getSimple())) { + return true; + } + } + return false; + } + return isMergedJoinKeyCondition(condition, leftAlias, rightAlias, fieldName); + } + + private static boolean isMergedJoinKeyCondition(@Nullable SqlNode condition, + @Nullable String leftAlias, @Nullable String rightAlias, String fieldName) { + if (!(condition instanceof SqlCall)) { + return false; + } + final SqlCall call = (SqlCall) condition; + switch (call.getKind()) { + case AND: + return call.getOperandList().stream() + .filter(Objects::nonNull) + .anyMatch(node -> + isMergedJoinKeyCondition(node, leftAlias, rightAlias, fieldName)); + case EQUALS: + return isQualifiedJoinField(call.operand(0), leftAlias, fieldName) + && isQualifiedJoinField(call.operand(1), rightAlias, fieldName) + || isQualifiedJoinField(call.operand(0), rightAlias, fieldName) + && isQualifiedJoinField(call.operand(1), leftAlias, fieldName); + default: + return false; + } + } + + private static boolean isQualifiedJoinField(@Nullable SqlNode node, + @Nullable String alias, String fieldName) { + if (!(node instanceof SqlIdentifier) || alias == null) { + return false; + } + final SqlIdentifier id = (SqlIdentifier) node; + return id.names.size() == 2 + && alias.equals(id.names.get(0)) + && 
fieldName.equals(id.names.get(1)); + } + + private static SqlNode qualifyJoinField(@Nullable String alias, + String fieldName, SqlNode fallback) { + if (alias == null) { + return fallback; + } + return new SqlIdentifier(ImmutableList.of(alias, fieldName), POS); + } + private static SqlNode groupItem(List groupKeys, ImmutableBitSet groupSet, ImmutableBitSet wholeGroupSet) { final List nodes = groupSet.asList().stream() diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java index 41054312202..6601990a6a5 100644 --- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java +++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java @@ -125,9 +125,12 @@ import static org.apache.calcite.test.Matchers.isLinux; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasToString; +import static org.hamcrest.Matchers.matchesPattern; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -147,6 +150,26 @@ private Sql sql(String sql) { return fixture().withSql(sql); } + private void assertPostgresqlSqlValid(String sql) { + try { + final SchemaPlus rootSchema = Frameworks.createRootSchema(true); + final SchemaPlus defaultSchema = + CalciteAssert.addSchema(rootSchema, CalciteAssert.SchemaSpec.JDBC_FOODMART); + final Planner planner = + getPlanner(null, + PostgresqlSqlDialect.DEFAULT.configureParser(SqlParser.config()), + defaultSchema, + SqlToRelConverter.config().withTrimUnusedFields(false), + ImmutableSet.of(), + DatabaseProduct.POSTGRESQL.getDialect().getTypeSystem(), + StandardConvertletTable.INSTANCE); + final SqlNode parsed = 
planner.parse(sql); + planner.validate(parsed); + } catch (Exception e) { + throw TestUtil.rethrow(e); + } + } + /** Initiates a test case with a given {@link RelNode} supplier. */ private Sql relFn(Function relFn) { return fixture() @@ -11887,6 +11910,178 @@ public Sql schema(CalciteAssert.SchemaSpec schemaSpec) { sql(sql).schema(CalciteAssert.SchemaSpec.JDBC_SCOTT).ok(expected); } + /** Test case for + * [CALCITE-7439] + * RelToSqlConverter emits ambiguous GROUP BY after LEFT JOIN USING with + * semi-join rewrite. */ + @Test void testPostgresqlRoundTripDistinctLeftJoinInSubqueryWithSemiJoinRules() { + final String query = "WITH product_keys AS (\n" + + " SELECT p.\"product_id\",\n" + + " (SELECT MAX(p3.\"product_id\")\n" + + " FROM \"foodmart\".\"product\" p3\n" + + " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n" + + " FROM \"foodmart\".\"product\" p\n" + + ")\n" + + "SELECT DISTINCT pk.\"product_id\"\n" + + "FROM product_keys pk\n" + + "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n" + + "WHERE pk.\"product_id\" IN (\n" + + " SELECT p4.\"product_id\"\n" + + " FROM \"foodmart\".\"product\" p4\n" + + ")"; + + final RuleSet rules = + RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE, + CoreRules.FILTER_SUB_QUERY_TO_CORRELATE, + CoreRules.JOIN_SUB_QUERY_TO_CORRELATE, + CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE, + CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE, + CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE, + CoreRules.PROJECT_TO_SEMI_JOIN); + + final String generated = sql(query).withPostgresql().optimize(rules, null).exec(); + assertThat(generated, + matchesPattern("(?s).*GROUP BY\\s+\"[^\"]+\"\\.\"product_id\".*")); + assertThat(generated, not(containsString("GROUP BY \"product_id\""))); + assertPostgresqlSqlValid(generated); + } + + @Test void testPostgresqlRoundTripDistinctLeftJoinUsingTwoKeysWithSemiJoinRules() { + final String query = "WITH product_keys AS (\n" + + " SELECT p.\"product_id\",\n" + + " p.\"net_weight\",\n" + + " 
(SELECT MAX(p3.\"product_id\")\n" + + " FROM \"foodmart\".\"product\" p3\n" + + " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n" + + " FROM \"foodmart\".\"product\" p\n" + + ")\n" + + "SELECT DISTINCT pk.\"product_id\", pk.\"net_weight\"\n" + + "FROM product_keys pk\n" + + "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\", \"net_weight\")\n" + + "WHERE pk.\"product_id\" IN (\n" + + " SELECT p4.\"product_id\"\n" + + " FROM \"foodmart\".\"product\" p4\n" + + ")"; + + final RuleSet rules = + RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE, + CoreRules.FILTER_SUB_QUERY_TO_CORRELATE, + CoreRules.JOIN_SUB_QUERY_TO_CORRELATE, + CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE, + CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE, + CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE, + CoreRules.PROJECT_TO_SEMI_JOIN); + + final String generated = sql(query).withPostgresql().optimize(rules, null).exec(); + assertThat(generated, + matchesPattern("(?s).*GROUP BY\\s+\"[^\"]+\"\\.\"product_id\",\\s*" + + "\"[^\"]+\"\\.\"net_weight\".*")); + assertThat(generated, + not(containsString("GROUP BY \"product_id\", \"net_weight\""))); + assertPostgresqlSqlValid(generated); + } + + /** Test case for + * [CALCITE-7439] + * RelToSqlConverter should not emit ambiguous GROUP BY after RIGHT JOIN USING + * with semi-join rewrite. 
*/ + @Test void testPostgresqlRoundTripDistinctRightJoinInSubqueryWithSemiJoinRules() { + final String query = "WITH product_keys AS (\n" + + " SELECT p.\"product_id\",\n" + + " (SELECT MAX(p3.\"product_id\")\n" + + " FROM \"foodmart\".\"product\" p3\n" + + " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n" + + " FROM \"foodmart\".\"product\" p\n" + + ")\n" + + "SELECT DISTINCT pk.\"product_id\"\n" + + "FROM product_keys pk\n" + + "RIGHT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n" + + "WHERE pk.\"product_id\" IN (\n" + + " SELECT p4.\"product_id\"\n" + + " FROM \"foodmart\".\"product\" p4\n" + + ")"; + + final RuleSet rules = + RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE, + CoreRules.FILTER_SUB_QUERY_TO_CORRELATE, + CoreRules.JOIN_SUB_QUERY_TO_CORRELATE, + CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE, + CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE, + CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE, + CoreRules.PROJECT_TO_SEMI_JOIN); + + final String generated = sql(query).withPostgresql().optimize(rules, null).exec(); + assertThat(generated, + matchesPattern("(?s).*GROUP BY\\s+\"[^\"]+\"\\.\"product_id\".*")); + assertThat(generated, not(containsString("GROUP BY \"product_id\""))); + assertPostgresqlSqlValid(generated); + } + + /** Test case for + * [CALCITE-7439] + * RelToSqlConverter should not emit ambiguous GROUP BY after FULL JOIN USING + * with semi-join rewrite. 
*/ + @Test void testPostgresqlRoundTripDistinctFullJoinInSubqueryWithSemiJoinRules() { + final String query = "WITH product_keys AS (\n" + + " SELECT p.\"product_id\",\n" + + " (SELECT MAX(p3.\"product_id\")\n" + + " FROM \"foodmart\".\"product\" p3\n" + + " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n" + + " FROM \"foodmart\".\"product\" p\n" + + ")\n" + + "SELECT DISTINCT pk.\"product_id\"\n" + + "FROM product_keys pk\n" + + "FULL JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n" + + "WHERE pk.\"product_id\" IN (\n" + + " SELECT p4.\"product_id\"\n" + + " FROM \"foodmart\".\"product\" p4\n" + + ")"; + + final RuleSet rules = + RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE, + CoreRules.FILTER_SUB_QUERY_TO_CORRELATE, + CoreRules.JOIN_SUB_QUERY_TO_CORRELATE, + CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE, + CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE, + CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE, + CoreRules.PROJECT_TO_SEMI_JOIN); + + final String generated = sql(query).withPostgresql().optimize(rules, null).exec(); + assertThat(generated, containsString("GROUP BY COALESCE(")); + assertThat(generated, + matchesPattern("(?s).*GROUP BY\\s+COALESCE\\(\"[^\"]+\"\\.\"product_id\",\\s*" + + "\"[^\"]+\"\\.\"product_id\"\\).*")); + assertThat(generated, not(containsString("GROUP BY \"product_id\""))); + assertPostgresqlSqlValid(generated); + } + + @Test void testPostgresqlRoundTripRollupJoinUsingQualifiesGroupKey() { + final String query = "SELECT \"product_id\", COUNT(*)\n" + + "FROM \"foodmart\".\"product\" p1\n" + + "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n" + + "GROUP BY ROLLUP(\"product_id\")"; + + final String generated = sql(query).withPostgresql().exec(); + assertThat(generated, + matchesPattern("(?s).*GROUP BY\\s+ROLLUP\\(\"[^\"]+\"\\.\"product_id\"\\).*")); + assertThat(generated, not(containsString("GROUP BY ROLLUP(\"product_id\")"))); + assertPostgresqlSqlValid(generated); + } + + @Test void 
testPostgresqlRoundTripSingletonCubeJoinUsingQualifiesGroupKey() { + final String query = "SELECT \"product_id\", COUNT(*)\n" + + "FROM \"foodmart\".\"product\" p1\n" + + "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n" + + "GROUP BY CUBE(\"product_id\")"; + + final String generated = sql(query).withPostgresql().exec(); + assertThat(generated, + matchesPattern("(?s).*GROUP BY\\s+(?:CUBE|ROLLUP)\\(\"[^\"]+\"\\.\"product_id\"\\).*")); + assertThat(generated, not(containsString("GROUP BY CUBE(\"product_id\")"))); + assertThat(generated, not(containsString("GROUP BY ROLLUP(\"product_id\")"))); + assertPostgresqlSqlValid(generated); + } + @Test void testNotBetween() { Sql f = fixture().withConvertletTable(new SqlRexConvertletTable() { @Override public @Nullable SqlRexConvertlet get(SqlCall call) {