Skip to content

Commit 8fcf793

Browse files
committed
Fix @ModifyConstant's expected method signature for class types. Closes #2231, closes #2258
1 parent ceb6439 commit 8fcf793

File tree

2 files changed

+91
-24
lines changed

2 files changed

+91
-24
lines changed

src/main/kotlin/platform/mixin/handlers/ModifyConstantHandler.kt

Lines changed: 86 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.InjectionPoint
2525
import com.demonwav.mcdev.platform.mixin.inspection.injector.MethodSignature
2626
import com.demonwav.mcdev.platform.mixin.inspection.injector.ParameterGroup
2727
import com.demonwav.mcdev.util.findAnnotations
28+
import com.intellij.openapi.project.Project
29+
import com.intellij.psi.JavaPsiFacade
2830
import com.intellij.psi.PsiAnnotation
31+
import com.intellij.psi.PsiElement
2932
import com.intellij.psi.PsiManager
3033
import com.intellij.psi.PsiMethod
3134
import com.intellij.psi.PsiType
@@ -61,6 +64,8 @@ class ModifyConstantHandler : InjectorAnnotationHandler() {
6164
Opcodes.IFGE,
6265
Opcodes.IFGT,
6366
Opcodes.IFLE,
67+
Opcodes.CHECKCAST,
68+
Opcodes.INSTANCEOF,
6469
)
6570

6671
private fun getConstantInfos(modifyConstant: PsiAnnotation): List<ConstantInjectionPoint.ConstantInfo>? {
@@ -103,30 +108,89 @@ class ModifyConstantHandler : InjectorAnnotationHandler() {
103108
}
104109

105110
val psiManager = PsiManager.getInstance(annotation.project)
106-
return constantInfos.asSequence().map {
107-
when (it.constant) {
108-
null -> PsiType.getJavaLangObject(psiManager, annotation.resolveScope)
109-
is Int -> PsiTypes.intType()
110-
is Float -> PsiTypes.floatType()
111-
is Long -> PsiTypes.longType()
112-
is Double -> PsiTypes.doubleType()
113-
is String -> PsiType.getJavaLangString(psiManager, annotation.resolveScope)
114-
is Type -> PsiType.getJavaLangClass(psiManager, annotation.resolveScope)
115-
else -> throw IllegalStateException("Unknown constant type: ${it.constant.javaClass.name}")
111+
return constantInfos.asSequence()
112+
.distinctBy { it.constant?.javaClass }
113+
.flatMap {
114+
when (it.constant) {
115+
null -> sequenceOf(
116+
makeMethodSignature(annotation.project, targetClass, targetMethod, PsiType.getJavaLangObject(psiManager, annotation.resolveScope))
117+
)
118+
is Int -> sequenceOf(
119+
makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes.intType()),
120+
makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes.booleanType()),
121+
makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes.byteType()),
122+
makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes.charType()),
123+
makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes.shortType()),
124+
)
125+
is Long -> sequenceOf(
126+
makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes.longType())
127+
)
128+
is Float -> sequenceOf(
129+
makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes.floatType())
130+
)
131+
is Double -> sequenceOf(
132+
makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes.doubleType())
133+
)
134+
is String -> sequenceOf(
135+
makeMethodSignature(annotation.project, targetClass, targetMethod, PsiType.getJavaLangString(psiManager, annotation.resolveScope))
136+
)
137+
is Type -> sequenceOf(
138+
makeTypeCheckMethodSignature(annotation.project, psiManager, annotation, targetClass, targetMethod, getClassType(psiManager, annotation)),
139+
makeTypeCheckMethodSignature(annotation.project, psiManager, annotation, targetClass, targetMethod, PsiTypes.booleanType()),
140+
)
141+
else -> throw IllegalStateException("Unknown constant type: ${it.constant.javaClass.name}")
142+
}
116143
}
117-
}.distinct().map { type ->
118-
MethodSignature(
119-
listOf(
120-
ParameterGroup(listOf(sanitizedParameter(type, "constant"))),
121-
ParameterGroup(
122-
collectTargetMethodParameters(annotation.project, targetClass, targetMethod),
123-
isVarargs = true,
124-
required = ParameterGroup.RequiredLevel.OPTIONAL,
125-
),
144+
.toList()
145+
}
146+
147+
private fun makeMethodSignature(
148+
project: Project,
149+
targetClass: ClassNode,
150+
targetMethod: MethodNode,
151+
type: PsiType,
152+
): MethodSignature {
153+
return MethodSignature(
154+
listOf(
155+
ParameterGroup(listOf(sanitizedParameter(type, "constant"))),
156+
ParameterGroup(
157+
collectTargetMethodParameters(project, targetClass, targetMethod),
158+
isVarargs = true,
159+
required = ParameterGroup.RequiredLevel.OPTIONAL,
126160
),
127-
type,
128-
)
129-
}.toList()
161+
),
162+
type,
163+
)
164+
}
165+
166+
private fun makeTypeCheckMethodSignature(
167+
project: Project,
168+
psiManager: PsiManager,
169+
context: PsiElement,
170+
targetClass: ClassNode,
171+
targetMethod: MethodNode,
172+
returnType: PsiType,
173+
): MethodSignature {
174+
return MethodSignature(
175+
listOf(
176+
ParameterGroup(
177+
listOf(
178+
sanitizedParameter(PsiType.getJavaLangObject(psiManager, context.resolveScope), "instance"),
179+
sanitizedParameter(getClassType(psiManager, context), "type"),
180+
)
181+
),
182+
ParameterGroup(
183+
collectTargetMethodParameters(project, targetClass, targetMethod),
184+
isVarargs = true,
185+
required = ParameterGroup.RequiredLevel.OPTIONAL,
186+
),
187+
),
188+
returnType,
189+
)
190+
}
191+
192+
private fun getClassType(psiManager: PsiManager, context: PsiElement): PsiType {
193+
return JavaPsiFacade.getElementFactory(psiManager.project).createTypeFromText("java.lang.Class<?>", context)
130194
}
131195

132196
override fun isInsnAllowed(insn: AbstractInsnNode, decorations: Map<String, Any?>): Boolean {

src/main/kotlin/platform/mixin/handlers/injectionPoint/AtResolver.kt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ import com.demonwav.mcdev.platform.mixin.util.memberReference
3737
import com.demonwav.mcdev.util.computeStringArray
3838
import com.demonwav.mcdev.util.constantStringValue
3939
import com.demonwav.mcdev.util.constantValue
40+
import com.demonwav.mcdev.util.descriptor
4041
import com.demonwav.mcdev.util.equivalentTo
4142
import com.demonwav.mcdev.util.findMethods
42-
import com.demonwav.mcdev.util.fullQualifiedName
43+
import com.demonwav.mcdev.util.internalName
4344
import com.intellij.codeInsight.lookup.LookupElementBuilder
4445
import com.intellij.psi.JavaPsiFacade
4546
import com.intellij.psi.PsiAnnotation
@@ -56,6 +57,7 @@ import com.intellij.psi.PsiParameterListOwner
5657
import com.intellij.psi.PsiQualifiedReference
5758
import com.intellij.psi.PsiReference
5859
import com.intellij.psi.PsiReferenceExpression
60+
import com.intellij.psi.PsiType
5961
import com.intellij.psi.search.GlobalSearchScope
6062
import com.intellij.psi.util.PsiUtil
6163
import com.intellij.psi.util.parents
@@ -140,7 +142,8 @@ class AtResolver(
140142
return value.initializers.map { valueToString(it) ?: return null }.joinToString(",")
141143
}
142144
return when (val constant = value.constantValue) {
143-
is PsiClassType -> constant.fullQualifiedName?.replace('.', '/')
145+
is PsiClassType -> constant.resolve()?.internalName
146+
is PsiType -> constant.descriptor
144147
null -> when (value) {
145148
is PsiReferenceExpression -> value.referenceName
146149
else -> null

0 commit comments

Comments
 (0)