feat: HashJoinOp mlir implementation#129
Conversation
HashJoinOp mlir implementation
There was a problem hiding this comment.
Very nice work! left a bunch of comments. Obviously there needs to be import/export/IR tests, but given the draft state, i assume those are in the pipeline :).
The post-join filter is definitely a funny thing here, and principally i agree with your approach (creating a filter node after the hash-join). It's a normalization of the IR, and it avoids introducing redundant operations (i.e. if we'd actually create a post-join filter in the hash-join op).
However, it does present an issue for round-tripping through the IR, which this would break. It's obviously not impossible to fix this (one way would be to add extra logic in the export pass to detect how a filter op is being used, and infer the post-join filter from there)... @ingomueller-net thoughts?
EDIT: guess we aren't the first to pose this question!
https://github.com/substrait-io/substrait/blob/413c7c8c8ea149ea1596c9c3b2e57151d6ce63f7/site/docs/faq.md?plain=1#L7-L12
One way we could add it in there is to have an optional filter region, containing a substrait.filter node:
%2 = subtrait.hash_join %0, %1 on {
^bb0(%arg0: tuple<si32, si32>, %arg1: tuple<si32, si32, si32>):
%3 = field_reference %arg0[0] : tuple<si32, si32> // corresponds to `left`
%4 = field_reference %arg1[0] : tuple<si32, si32, si32> // corresponds to `right`
%5 = call @cmp(%3, %4) : (si32, si32) -> si1 // corresponds to `custom_function_reference`
// ... or ...
%5 = compare not_distinct_from %3, %4 : (si32, si32) -> si1 // corresponds to `simple`
yield %5 : si1
} filter {
^bb0(%arg0: tuple<si32, si32>, %arg1: tuple<si32, si32, si32>):
%res = substrait.filter ...
yield %res : si1
}| Substrait_ExpressionType:$lhs, | ||
| Substrait_ExpressionType:$rhs, | ||
| OptionalAttr<SimpleComparisonType>:$comparison_type, | ||
| OptionalAttr<UI32Attr>:$custom_function_id |
There was a problem hiding this comment.
If custom_function_id is deliberately not implemented, i'd say remove this argument and write a TODO in the description field.
| let results = (outs SI1:$result); | ||
|
|
||
| let assemblyFormat = [{ | ||
| $comparison_type $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($result) |
There was a problem hiding this comment.
`(` type($lhs) `,` type($rhs) `)` `->` type($result)
could be
functional-type(operands, $result)
| let assemblyFormat = [{ | ||
| $join_type $left `,` $right | ||
| (`advanced_extension` `` $advanced_extension^)? | ||
| attr-dict `:` type($left) `,` type($right) `->` type($result) $condition |
There was a problem hiding this comment.
as above (functional-type).
| } | ||
|
|
||
| LogicalResult KeyComparisonOp::verify() { | ||
| auto &res = |
There was a problem hiding this comment.
Since this may be a bit unintuitive at first (i.e. one might expect that the comparison op should always have identical types for both operands), a comment might be warranted here.
Do we have some documentation about substraits stance here? That is, this code is assuming that the comparison op is able to perform comparisons against types that are "cast-compatible" (int <=> decimal, string <=> varchar).
| if (failed(res)) | ||
| return res; | ||
|
|
||
| return success(); |
There was a problem hiding this comment.
Could just return res; here.
| return op->emitOpError("missing join condition"); | ||
| } | ||
|
|
||
| Block &conditionBlock = op.getCondition().front(); |
There was a problem hiding this comment.
As above (op.getBody()).
| if (!compareOp) { | ||
| return op->emitOpError("join condition must be a KeyComparisonOp"); | ||
| } |
There was a problem hiding this comment.
This could be removed, since you're already checking for this in your verifier.
nit: remove braces.
| if (auto leftFieldRef = | ||
| dyn_cast_or_null<FieldReferenceOp>(leftKey.getDefiningOp())) { | ||
| leftKeyExpr = exportOperation(leftFieldRef); | ||
| } else { | ||
| return op->emitOpError() << "left key must be a field reference"; | ||
| } | ||
|
|
||
| FailureOr<std::unique_ptr<Expression>> rightKeyExpr; | ||
| if (auto rightFieldRef = | ||
| dyn_cast_or_null<FieldReferenceOp>(rightKey.getDefiningOp())) { | ||
| rightKeyExpr = exportOperation(rightFieldRef); | ||
| } else { | ||
| return op->emitOpError() << "right key must be a field reference"; | ||
| } |
There was a problem hiding this comment.
This logic should be moved to a verifier of the KeyComparisonOp (i.e. set hasVerifier = 1 for the op).
| if (failed(hashJoinOp)) { | ||
| return failure(); | ||
| } |
| return mlir::emitError(builder.getLoc(), | ||
| "custom comparison functions not yet supported"); |
There was a problem hiding this comment.
Nit: change to "custom comparison functions for hash_join not yet supported" to make it a bit more clear, when one has a very large substrait input file, and it may be a bit hard to decipher what the error message refers to.
This commit adds MLIR implementation for HashJoinOp
This tries to follow a similar MLIR pattern as in issue #97
Skipped work:
Doesn't implement custom_function_id. It is not needed by us at the moment, so left a TODO for now.