Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
Expand Down Expand Up @@ -858,12 +859,21 @@ private List<SqlNode> generateGroupList(Builder builder,
+ aggregate.getGroupSet() + ", just possibly a different order";

final List<SqlNode> groupKeys = new ArrayList<>();
final Join aggregateJoinInput =
aggregate.getInput() instanceof Join ? (Join) aggregate.getInput() : null;
final SqlJoin fromJoin =
builder.select.getFrom() instanceof SqlJoin ? (SqlJoin) builder.select.getFrom() : null;
final int leftFieldCount = aggregateJoinInput == null
? -1
: aggregateJoinInput.getLeft().getRowType().getFieldCount();
for (int key : groupList) {
final SqlNode field = builder.context.field(key);
SqlNode field = builder.context.field(key);
field = maybeQualifyJoinKey(field, key, fromJoin, leftFieldCount);
groupKeys.add(field);
}
for (int key : sortedGroupList) {
final SqlNode field = builder.context.field(key);
SqlNode field =
maybeQualifyJoinKey(builder.context.field(key), key, fromJoin, leftFieldCount);
addSelect(selectList, field, aggregate.getRowType());
}
switch (aggregate.getGroupType()) {
Expand All @@ -880,7 +890,8 @@ private List<SqlNode> generateGroupList(Builder builder,
final List<Integer> rollupBits = Aggregate.Group.getRollup(aggregate.groupSets);
final List<SqlNode> rollupKeys = rollupBits
.stream()
.map(bit -> builder.context.field(bit))
.map(bit ->
maybeQualifyJoinKey(builder.context.field(bit), bit, fromJoin, leftFieldCount))
.collect(Collectors.toList());
return ImmutableList.of(
SqlStdOperatorTable.ROLLUP.createCall(SqlParserPos.ZERO, rollupKeys));
Expand All @@ -905,6 +916,115 @@ private List<SqlNode> generateGroupList(Builder builder,
}
}

private SqlNode maybeQualifyJoinKey(SqlNode field, int key,
@Nullable SqlJoin fromJoin, int leftFieldCount) {
if (!isSimpleIdentifier(field) || fromJoin == null) {
return field;
}

final String fieldName = ((SqlIdentifier) field).getSimple();
if (leftFieldCount >= 0) {
final SqlNode side = key < leftFieldCount ? fromJoin.getLeft() : fromJoin.getRight();
return qualifyJoinField(SqlValidatorUtil.alias(side), fieldName, field);
}
return maybeQualifyJoinKeyWithoutInputJoin(field, fromJoin, fieldName);
}

private static boolean isSimpleIdentifier(SqlNode node) {
return node instanceof SqlIdentifier
&& ((SqlIdentifier) node).names.size() == 1;
}

private SqlNode maybeQualifyJoinKeyWithoutInputJoin(SqlNode field,
SqlJoin fromJoin, String fieldName) {
final String leftAlias = SqlValidatorUtil.alias(fromJoin.getLeft());
final String rightAlias = SqlValidatorUtil.alias(fromJoin.getRight());
if (!isMergedJoinKey(fromJoin, leftAlias, rightAlias, fieldName)) {
return field;
}
switch (fromJoin.getJoinType()) {
case RIGHT:
return qualifyJoinField(rightAlias, fieldName, field);
case FULL:
if (leftAlias != null && rightAlias != null) {
return SqlStdOperatorTable.COALESCE.createCall(POS,
new SqlIdentifier(ImmutableList.of(leftAlias, fieldName), POS),
new SqlIdentifier(ImmutableList.of(rightAlias, fieldName), POS));
}
return qualifyJoinField(leftAlias != null ? leftAlias : rightAlias, fieldName, field);
case LEFT:
case LEFT_SEMI_JOIN:
case LEFT_ANTI_JOIN:
case INNER:
case CROSS:
case COMMA:
case ASOF:
case LEFT_ASOF:
default:
return qualifyJoinField(leftAlias, fieldName, field);
}
}

private static boolean isMergedJoinKey(SqlJoin fromJoin,
@Nullable String leftAlias, @Nullable String rightAlias, String fieldName) {
final @Nullable SqlNode condition = fromJoin.getCondition();
if (fromJoin.getConditionType() == JoinConditionType.USING) {
if (!(condition instanceof SqlNodeList)) {
return false;
}
for (SqlNode node : ((SqlNodeList) condition).getList()) {
if (node != null
&& isSimpleIdentifier(node)
&& fieldName.equals(((SqlIdentifier) node).getSimple())) {
return true;
}
}
return false;
}
return isMergedJoinKeyCondition(condition, leftAlias, rightAlias, fieldName);
}

private static boolean isMergedJoinKeyCondition(@Nullable SqlNode condition,
@Nullable String leftAlias, @Nullable String rightAlias, String fieldName) {
if (!(condition instanceof SqlCall)) {
return false;
}
final SqlCall call = (SqlCall) condition;
switch (call.getKind()) {
case AND:
return call.getOperandList().stream()
.filter(Objects::nonNull)
.anyMatch(node ->
isMergedJoinKeyCondition(node, leftAlias, rightAlias, fieldName));
case EQUALS:
return isQualifiedJoinField(call.operand(0), leftAlias, fieldName)
&& isQualifiedJoinField(call.operand(1), rightAlias, fieldName)
|| isQualifiedJoinField(call.operand(0), rightAlias, fieldName)
&& isQualifiedJoinField(call.operand(1), leftAlias, fieldName);
default:
return false;
}
}

private static boolean isQualifiedJoinField(@Nullable SqlNode node,
@Nullable String alias, String fieldName) {
if (!(node instanceof SqlIdentifier) || alias == null) {
return false;
}
final SqlIdentifier id = (SqlIdentifier) node;
return id.names.size() == 2
&& alias.equals(id.names.get(0))
&& fieldName.equals(id.names.get(1));
}

private static SqlNode qualifyJoinField(@Nullable String alias,
String fieldName, SqlNode fallback) {
if (alias == null) {
return fallback;
}
return new SqlIdentifier(ImmutableList.of(alias, fieldName), POS);
}

private static SqlNode groupItem(List<SqlNode> groupKeys,
ImmutableBitSet groupSet, ImmutableBitSet wholeGroupSet) {
final List<SqlNode> nodes = groupSet.asList().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,12 @@
import static org.apache.calcite.test.Matchers.isLinux;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.matchesPattern;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand All @@ -147,6 +150,26 @@ private Sql sql(String sql) {
return fixture().withSql(sql);
}

private void assertPostgresqlSqlValid(String sql) {
try {
final SchemaPlus rootSchema = Frameworks.createRootSchema(true);
final SchemaPlus defaultSchema =
CalciteAssert.addSchema(rootSchema, CalciteAssert.SchemaSpec.JDBC_FOODMART);
final Planner planner =
getPlanner(null,
PostgresqlSqlDialect.DEFAULT.configureParser(SqlParser.config()),
defaultSchema,
SqlToRelConverter.config().withTrimUnusedFields(false),
ImmutableSet.of(),
DatabaseProduct.POSTGRESQL.getDialect().getTypeSystem(),
StandardConvertletTable.INSTANCE);
final SqlNode parsed = planner.parse(sql);
planner.validate(parsed);
} catch (Exception e) {
throw TestUtil.rethrow(e);
}
}

/** Initiates a test case with a given {@link RelNode} supplier. */
private Sql relFn(Function<RelBuilder, RelNode> relFn) {
return fixture()
Expand Down Expand Up @@ -11887,6 +11910,178 @@ public Sql schema(CalciteAssert.SchemaSpec schemaSpec) {
sql(sql).schema(CalciteAssert.SchemaSpec.JDBC_SCOTT).ok(expected);
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
* RelToSqlConverter emits ambiguous GROUP BY after LEFT JOIN USING with
* semi-join rewrite.</a>. */
@Test void testPostgresqlRoundTripDistinctLeftJoinInSubqueryWithSemiJoinRules() {
final String query = "WITH product_keys AS (\n"
+ " SELECT p.\"product_id\",\n"
+ " (SELECT MAX(p3.\"product_id\")\n"
+ " FROM \"foodmart\".\"product\" p3\n"
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
+ " FROM \"foodmart\".\"product\" p\n"
+ ")\n"
+ "SELECT DISTINCT pk.\"product_id\"\n"
+ "FROM product_keys pk\n"
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
+ "WHERE pk.\"product_id\" IN (\n"
+ " SELECT p4.\"product_id\"\n"
+ " FROM \"foodmart\".\"product\" p4\n"
+ ")";

final RuleSet rules =
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
CoreRules.PROJECT_TO_SEMI_JOIN);

final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
assertThat(generated,
matchesPattern("(?s).*GROUP BY\\s+\"[^\"]+\"\\.\"product_id\".*"));
assertThat(generated, not(containsString("GROUP BY \"product_id\"")));
assertPostgresqlSqlValid(generated);
}

@Test void testPostgresqlRoundTripDistinctLeftJoinUsingTwoKeysWithSemiJoinRules() {
final String query = "WITH product_keys AS (\n"
+ " SELECT p.\"product_id\",\n"
+ " p.\"net_weight\",\n"
+ " (SELECT MAX(p3.\"product_id\")\n"
+ " FROM \"foodmart\".\"product\" p3\n"
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
+ " FROM \"foodmart\".\"product\" p\n"
+ ")\n"
+ "SELECT DISTINCT pk.\"product_id\", pk.\"net_weight\"\n"
+ "FROM product_keys pk\n"
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\", \"net_weight\")\n"
+ "WHERE pk.\"product_id\" IN (\n"
+ " SELECT p4.\"product_id\"\n"
+ " FROM \"foodmart\".\"product\" p4\n"
+ ")";

final RuleSet rules =
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
CoreRules.PROJECT_TO_SEMI_JOIN);

final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
assertThat(generated,
matchesPattern("(?s).*GROUP BY\\s+\"[^\"]+\"\\.\"product_id\",\\s*"
+ "\"[^\"]+\"\\.\"net_weight\".*"));
assertThat(generated,
not(containsString("GROUP BY \"product_id\", \"net_weight\"")));
assertPostgresqlSqlValid(generated);
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
* RelToSqlConverter should not emit ambiguous GROUP BY after RIGHT JOIN USING
* with semi-join rewrite.</a>. */
@Test void testPostgresqlRoundTripDistinctRightJoinInSubqueryWithSemiJoinRules() {
final String query = "WITH product_keys AS (\n"
+ " SELECT p.\"product_id\",\n"
+ " (SELECT MAX(p3.\"product_id\")\n"
+ " FROM \"foodmart\".\"product\" p3\n"
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
+ " FROM \"foodmart\".\"product\" p\n"
+ ")\n"
+ "SELECT DISTINCT pk.\"product_id\"\n"
+ "FROM product_keys pk\n"
+ "RIGHT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
+ "WHERE pk.\"product_id\" IN (\n"
+ " SELECT p4.\"product_id\"\n"
+ " FROM \"foodmart\".\"product\" p4\n"
+ ")";

final RuleSet rules =
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
CoreRules.PROJECT_TO_SEMI_JOIN);

final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
assertThat(generated,
matchesPattern("(?s).*GROUP BY\\s+\"[^\"]+\"\\.\"product_id\".*"));
assertThat(generated, not(containsString("GROUP BY \"product_id\"")));
assertPostgresqlSqlValid(generated);
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
* RelToSqlConverter should not emit ambiguous GROUP BY after FULL JOIN USING
* with semi-join rewrite.</a>. */
@Test void testPostgresqlRoundTripDistinctFullJoinInSubqueryWithSemiJoinRules() {
final String query = "WITH product_keys AS (\n"
+ " SELECT p.\"product_id\",\n"
+ " (SELECT MAX(p3.\"product_id\")\n"
+ " FROM \"foodmart\".\"product\" p3\n"
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
+ " FROM \"foodmart\".\"product\" p\n"
+ ")\n"
+ "SELECT DISTINCT pk.\"product_id\"\n"
+ "FROM product_keys pk\n"
+ "FULL JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
+ "WHERE pk.\"product_id\" IN (\n"
+ " SELECT p4.\"product_id\"\n"
+ " FROM \"foodmart\".\"product\" p4\n"
+ ")";

final RuleSet rules =
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
CoreRules.PROJECT_TO_SEMI_JOIN);

final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
assertThat(generated, containsString("GROUP BY COALESCE("));
assertThat(generated,
matchesPattern("(?s).*GROUP BY\\s+COALESCE\\(\"[^\"]+\"\\.\"product_id\",\\s*"
+ "\"[^\"]+\"\\.\"product_id\"\\).*"));
assertThat(generated, not(containsString("GROUP BY \"product_id\"")));
assertPostgresqlSqlValid(generated);
}

@Test void testPostgresqlRoundTripRollupJoinUsingQualifiesGroupKey() {
final String query = "SELECT \"product_id\", COUNT(*)\n"
+ "FROM \"foodmart\".\"product\" p1\n"
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
+ "GROUP BY ROLLUP(\"product_id\")";

final String generated = sql(query).withPostgresql().exec();
assertThat(generated,
matchesPattern("(?s).*GROUP BY\\s+ROLLUP\\(\"[^\"]+\"\\.\"product_id\"\\).*"));
assertThat(generated, not(containsString("GROUP BY ROLLUP(\"product_id\")")));
assertPostgresqlSqlValid(generated);
}

@Test void testPostgresqlRoundTripSingletonCubeJoinUsingQualifiesGroupKey() {
final String query = "SELECT \"product_id\", COUNT(*)\n"
+ "FROM \"foodmart\".\"product\" p1\n"
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
+ "GROUP BY CUBE(\"product_id\")";

final String generated = sql(query).withPostgresql().exec();
assertThat(generated,
matchesPattern("(?s).*GROUP BY\\s+(?:CUBE|ROLLUP)\\(\"[^\"]+\"\\.\"product_id\"\\).*"));
assertThat(generated, not(containsString("GROUP BY CUBE(\"product_id\")")));
assertThat(generated, not(containsString("GROUP BY ROLLUP(\"product_id\")")));
assertPostgresqlSqlValid(generated);
}

@Test void testNotBetween() {
Sql f = fixture().withConvertletTable(new SqlRexConvertletTable() {
@Override public @Nullable SqlRexConvertlet get(SqlCall call) {
Expand Down
Loading