Skip to content

Commit fc48e1c

Browse files
committed
operation
1 parent 36f685c commit fc48e1c

2 files changed

Lines changed: 180 additions & 86 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 58 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ import M2
271271

272272
private module Input3 implements InputSig3 {
273273
private import rust as Rust
274+
private import codeql.rust.elements.internal.OperationImpl::Impl as OperationImpl
274275

275276
predicate cacheRevRef() {
276277
Stages::TypeInferenceStage::ref()
@@ -420,7 +421,7 @@ private module Input3 implements InputSig3 {
420421

421422
class CallResolutionContext = FunctionCallMatchingInput::AccessEnvironment;
422423

423-
class Callable extends FunctionCallMatchingInput::Declaration {
424+
final class Callable extends FunctionCallMatchingInput::Declaration {
424425
TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp) {
425426
result =
426427
tp.(TypeParamTypeParameter)
@@ -527,6 +528,60 @@ private module Input3 implements InputSig3 {
527528
)
528529
}
529530

531+
class Operator extends Callable {
532+
private Method getSelfOrImpl() {
533+
result = f
534+
or
535+
f.implements(result)
536+
}
537+
538+
pragma[nomagic]
539+
private predicate borrowsAt(int pos) {
540+
exists(TraitItemNode t, string path, string method |
541+
this.getSelfOrImpl() = t.getAssocItem(method) and
542+
path = t.getCanonicalPath(_) and
543+
exists(int borrows | OperationImpl::isOverloaded(_, _, path, method, borrows) |
544+
pos = 0 and borrows >= 1
545+
or
546+
pos = 1 and
547+
borrows >= 2
548+
)
549+
)
550+
}
551+
552+
pragma[nomagic]
553+
private predicate derefsReturn() { this.getSelfOrImpl() = any(DerefTrait t).getDerefFunction() }
554+
555+
Type getParameterType(int pos, TypePath path) {
556+
exists(TypePath path0 | result = super.getParameterType(pos, path0) |
557+
if this.borrowsAt(pos) then path0.isCons(getRefTypeParameter(_), path) else path0 = path
558+
)
559+
}
560+
561+
Type getReturnType(TypePath path) {
562+
exists(TypePath path0 | result = super.getReturnType(path0) |
563+
if this.derefsReturn() then path0.isCons(getRefTypeParameter(_), path) else path0 = path
564+
)
565+
}
566+
}
567+
568+
class Operation extends AssocFunctionResolution::OperationAssocFunctionCall {
569+
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }
570+
571+
AstNode getOperand(int i) {
572+
exists(FunctionPosition pos |
573+
result = this.getNodeAt(pos) and
574+
i = pos.asPosition()
575+
)
576+
}
577+
578+
Operator getTarget() {
579+
exists(ImplOrTraitItemNode i |
580+
result.isAssocFunction(i, this.resolveCallTarget(i, _, _, _), false) // mutual recursion
581+
)
582+
}
583+
}
584+
530585
predicate inferStepCertain(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
531586
n1 =
532587
any(IdentPat ip |
@@ -701,11 +756,7 @@ private module Input3 implements InputSig3 {
701756
or
702757
result = inferAssignmentOperationType(n, path)
703758
or
704-
exists(FunctionPosition pos | pos.isReturn() |
705-
result = inferConstructionType(n, pos, path)
706-
or
707-
result = inferOperationType(n, pos, path)
708-
)
759+
exists(FunctionPosition pos | pos.isReturn() | result = inferConstructionType(n, pos, path))
709760
or
710761
result = inferTryExprType(n, path)
711762
or
@@ -727,11 +778,7 @@ private module Input3 implements InputSig3 {
727778
}
728779

729780
Type inferTypeTopDown(AstNode n, TypePath path) {
730-
exists(FunctionPosition pos | not pos.isReturn() |
731-
result = inferConstructionType(n, pos, path)
732-
or
733-
result = inferOperationType(n, pos, path)
734-
)
781+
exists(FunctionPosition pos | not pos.isReturn() | result = inferConstructionType(n, pos, path))
735782
}
736783
}
737784

@@ -3149,81 +3196,6 @@ private Type inferUnknownType(AstNode n, TypePath path) {
31493196
)
31503197
}
31513198

3152-
/**
3153-
* A matching configuration for resolving types of operations like `a + b`.
3154-
*/
3155-
private module OperationMatchingInput implements MatchingInputSig {
3156-
private import codeql.rust.elements.internal.OperationImpl::Impl as OperationImpl
3157-
import FunctionPositionMatchingInput
3158-
3159-
class Declaration extends FunctionCallMatchingInput::Declaration {
3160-
private Method getSelfOrImpl() {
3161-
result = f
3162-
or
3163-
f.implements(result)
3164-
}
3165-
3166-
pragma[nomagic]
3167-
private predicate borrowsAt(FunctionPosition pos) {
3168-
exists(TraitItemNode t, string path, string method |
3169-
this.getSelfOrImpl() = t.getAssocItem(method) and
3170-
path = t.getCanonicalPath(_) and
3171-
exists(int borrows | OperationImpl::isOverloaded(_, _, path, method, borrows) |
3172-
pos.asPosition() = 0 and borrows >= 1
3173-
or
3174-
pos.asPosition() = 1 and
3175-
borrows >= 2
3176-
)
3177-
)
3178-
}
3179-
3180-
pragma[nomagic]
3181-
private predicate derefsReturn() { this.getSelfOrImpl() = any(DerefTrait t).getDerefFunction() }
3182-
3183-
Type getDeclaredType(FunctionPosition pos, TypePath path) {
3184-
exists(TypePath path0 |
3185-
result = super.getParameterType(pos.asPosition(), path0)
3186-
or
3187-
pos.isReturn() and
3188-
result = super.getReturnType(path0)
3189-
|
3190-
if
3191-
this.borrowsAt(pos)
3192-
or
3193-
pos.isReturn() and this.derefsReturn()
3194-
then path0.isCons(getRefTypeParameter(_), path)
3195-
else path0 = path
3196-
)
3197-
}
3198-
}
3199-
3200-
class Access extends AssocFunctionResolution::OperationAssocFunctionCall {
3201-
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }
3202-
3203-
pragma[nomagic]
3204-
Type getInferredType(FunctionPosition pos, TypePath path) {
3205-
result = inferType(this.getNodeAt(pos), path)
3206-
}
3207-
3208-
Declaration getTarget() {
3209-
exists(ImplOrTraitItemNode i |
3210-
result.isAssocFunction(i, this.resolveCallTarget(i, _, _, _), false) // mutual recursion
3211-
)
3212-
}
3213-
}
3214-
}
3215-
3216-
private module OperationMatching = Matching<OperationMatchingInput>;
3217-
3218-
pragma[nomagic]
3219-
private Type inferOperationType(AstNode n, FunctionPosition pos, TypePath path) {
3220-
exists(OperationMatchingInput::Access a |
3221-
n = a.getNodeAt(pos) and
3222-
result = OperationMatching::inferAccessType(a, pos, path) and
3223-
if pos.asPosition() = 0 then not path.isEmpty() else any()
3224-
)
3225-
}
3226-
32273199
pragma[nomagic]
32283200
private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefChain) {
32293201
exists(TypePath path |

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2336,6 +2336,61 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
23362336
*/
23372337
Type inferMemberAccessReceiverTypeContextual(AstNode n, TypePath path);
23382338

2339+
/** An operator. */
2340+
class Operator {
2341+
/**
2342+
* Gets the type parameter at position `ppos` of this operator, if any.
2343+
*
2344+
* This should include type parameters declared on the operator itself,
2345+
* as well as type parameters declared on the enclosing declaration(s).
2346+
*/
2347+
TypeParameter getTypeParameter(TypeParameterPosition ppos);
2348+
2349+
/**
2350+
* Gets an additional type parameter constraint for the given type parameter,
2351+
* which applies to this operator. For example, in Rust, a function can apply
2352+
* additional constraints on type parameters belonging to the `impl` block
2353+
* that the function is defined in:
2354+
*
2355+
* ```rust
2356+
* impl<T> MyThing<T> {
2357+
* fn foo(self) where T: MyTrait { ... }
2358+
* }
2359+
*/
2360+
TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp);
2361+
2362+
/**
2363+
* Gets the declared type of the `i`th parameter of this operator at `path`.
2364+
*
2365+
* This should also include (possibly implicit) `this`/`self` parameters,
2366+
* using index `0`.
2367+
*/
2368+
Type getParameterType(int i, TypePath path);
2369+
2370+
/** Gets the declared return type of this operator at `path`. */
2371+
Type getReturnType(TypePath path);
2372+
2373+
/** Gets a textual representation of this operator. */
2374+
string toString();
2375+
2376+
/** Gets the location of this operator. */
2377+
Location getLocation();
2378+
}
2379+
2380+
/** An overloaded operation, for example `a + b`. */
2381+
class Operation extends AstNode {
2382+
/** Gets the explicit type argument at position `apos` and `path` for this call, if any. */
2383+
Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
2384+
2385+
/**
2386+
* Gets the AST node corresponding to the `i`th operand.
2387+
*/
2388+
AstNode getOperand(int i);
2389+
2390+
/** Gets the target of this operation in the given resolution context. */
2391+
Operator getTarget();
2392+
}
2393+
23392394
/**
23402395
* Holds if `n1` having certain type `t` at `path1` implies that `n2` has
23412396
* certain type `t` at `path2`, but not necessarily the other way around.
@@ -2545,6 +2600,12 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
25452600
or
25462601
result = inferMemberAccessReceiverTypeContextual(n, path)
25472602
or
2603+
exists(Operation o, int pos |
2604+
n = o.getOperand(pos) and
2605+
result = OperationMatching::inferAccessType(o, pos, path) and
2606+
hasUnknownType(n)
2607+
)
2608+
or
25482609
exists(TypePath path1, AstNode n2, TypePath path2, TypePath suffix |
25492610
result = inferType(n2, path2.appendInverse(suffix)) and
25502611
path = path1.append(suffix) and
@@ -2577,6 +2638,8 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
25772638
or
25782639
result = inferMemberAccessType(n, path)
25792640
or
2641+
result = inferOperationReturnType(n, path)
2642+
or
25802643
// contextual typing: only propagate type information from surrounding context
25812644
// into a node which has an explicitly unknown type
25822645
exists(TypePath prefix |
@@ -2794,6 +2857,65 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
27942857
hasUnknownType(receiver)
27952858
}
27962859

2860+
/**
2861+
* A matching configuration for resolving types of operations.
2862+
*/
2863+
private module OperationMatchingInput implements MatchingInputSig {
2864+
class DeclarationPosition = CallMatchingInput::DeclarationPosition;
2865+
2866+
class AccessPosition = CallMatchingInput::AccessPosition;
2867+
2868+
predicate accessDeclarationPositionMatch =
2869+
CallMatchingInput::accessDeclarationPositionMatch/2;
2870+
2871+
additional predicate getReturnPosition = CallMatchingInput::getReturnPosition/0;
2872+
2873+
final private class OperatorFinal = Operator;
2874+
2875+
class Declaration extends OperatorFinal {
2876+
Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
2877+
result = this.getParameterType(dpos, path)
2878+
or
2879+
dpos = getReturnPosition() and
2880+
result = this.getReturnType(path)
2881+
}
2882+
}
2883+
2884+
bindingset[decl]
2885+
TypeMention getATypeParameterConstraint(TypeParameter tp, Declaration decl) {
2886+
result = Input2::getATypeParameterConstraint(tp) and
2887+
exists(decl)
2888+
or
2889+
result = decl.getAdditionalTypeParameterConstraint(tp)
2890+
}
2891+
2892+
final private class OperationFinal = Operation;
2893+
2894+
class Access extends OperationFinal {
2895+
pragma[nomagic]
2896+
private Type getInferredResultType(AccessPosition apos, TypePath path) {
2897+
result = inferType(this, path) and
2898+
apos = getReturnPosition()
2899+
}
2900+
2901+
Type getInferredType(AccessPosition apos, TypePath path) {
2902+
result = inferType(this.getOperand(apos), path)
2903+
or
2904+
result = this.getInferredResultType(apos, path)
2905+
}
2906+
2907+
Declaration getTarget() { result = super.getTarget() }
2908+
}
2909+
}
2910+
2911+
private module OperationMatching = Matching<OperationMatchingInput>;
2912+
2913+
pragma[nomagic]
2914+
private Type inferOperationReturnType(Operation op, TypePath path) {
2915+
result =
2916+
OperationMatching::inferAccessType(op, OperationMatchingInput::getReturnPosition(), path)
2917+
}
2918+
27972919
/**
27982920
* Gets the inferred root type of `n`, if any.
27992921
*/

0 commit comments

Comments
 (0)