@@ -25,7 +25,10 @@ import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.InjectionPoint
2525import com.demonwav.mcdev.platform.mixin.inspection.injector.MethodSignature
2626import com.demonwav.mcdev.platform.mixin.inspection.injector.ParameterGroup
2727import com.demonwav.mcdev.util.findAnnotations
28+ import com.intellij.openapi.project.Project
29+ import com.intellij.psi.JavaPsiFacade
2830import com.intellij.psi.PsiAnnotation
31+ import com.intellij.psi.PsiElement
2932import com.intellij.psi.PsiManager
3033import com.intellij.psi.PsiMethod
3134import com.intellij.psi.PsiType
@@ -61,6 +64,8 @@ class ModifyConstantHandler : InjectorAnnotationHandler() {
6164 Opcodes .IFGE ,
6265 Opcodes .IFGT ,
6366 Opcodes .IFLE ,
67+ Opcodes .CHECKCAST ,
68+ Opcodes .INSTANCEOF ,
6469 )
6570
6671 private fun getConstantInfos (modifyConstant : PsiAnnotation ): List <ConstantInjectionPoint .ConstantInfo >? {
@@ -103,30 +108,89 @@ class ModifyConstantHandler : InjectorAnnotationHandler() {
103108 }
104109
105110 val psiManager = PsiManager .getInstance(annotation.project)
106- return constantInfos.asSequence().map {
107- when (it.constant) {
108- null -> PsiType .getJavaLangObject(psiManager, annotation.resolveScope)
109- is Int -> PsiTypes .intType()
110- is Float -> PsiTypes .floatType()
111- is Long -> PsiTypes .longType()
112- is Double -> PsiTypes .doubleType()
113- is String -> PsiType .getJavaLangString(psiManager, annotation.resolveScope)
114- is Type -> PsiType .getJavaLangClass(psiManager, annotation.resolveScope)
115- else -> throw IllegalStateException (" Unknown constant type: ${it.constant.javaClass.name} " )
111+ return constantInfos.asSequence()
112+ .distinctBy { it.constant?.javaClass }
113+ .flatMap {
114+ when (it.constant) {
115+ null -> sequenceOf(
116+ makeMethodSignature(annotation.project, targetClass, targetMethod, PsiType .getJavaLangObject(psiManager, annotation.resolveScope))
117+ )
118+ is Int -> sequenceOf(
119+ makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes .intType()),
120+ makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes .booleanType()),
121+ makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes .byteType()),
122+ makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes .charType()),
123+ makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes .shortType()),
124+ )
125+ is Long -> sequenceOf(
126+ makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes .longType())
127+ )
128+ is Float -> sequenceOf(
129+ makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes .floatType())
130+ )
131+ is Double -> sequenceOf(
132+ makeMethodSignature(annotation.project, targetClass, targetMethod, PsiTypes .doubleType())
133+ )
134+ is String -> sequenceOf(
135+ makeMethodSignature(annotation.project, targetClass, targetMethod, PsiType .getJavaLangString(psiManager, annotation.resolveScope))
136+ )
137+ is Type -> sequenceOf(
138+ makeTypeCheckMethodSignature(annotation.project, psiManager, annotation, targetClass, targetMethod, getClassType(psiManager, annotation)),
139+ makeTypeCheckMethodSignature(annotation.project, psiManager, annotation, targetClass, targetMethod, PsiTypes .booleanType()),
140+ )
141+ else -> throw IllegalStateException (" Unknown constant type: ${it.constant.javaClass.name} " )
142+ }
116143 }
117- }.distinct().map { type ->
118- MethodSignature (
119- listOf (
120- ParameterGroup (listOf (sanitizedParameter(type, " constant" ))),
121- ParameterGroup (
122- collectTargetMethodParameters(annotation.project, targetClass, targetMethod),
123- isVarargs = true ,
124- required = ParameterGroup .RequiredLevel .OPTIONAL ,
125- ),
144+ .toList()
145+ }
146+
147+ private fun makeMethodSignature (
148+ project : Project ,
149+ targetClass : ClassNode ,
150+ targetMethod : MethodNode ,
151+ type : PsiType ,
152+ ): MethodSignature {
153+ return MethodSignature (
154+ listOf (
155+ ParameterGroup (listOf (sanitizedParameter(type, " constant" ))),
156+ ParameterGroup (
157+ collectTargetMethodParameters(project, targetClass, targetMethod),
158+ isVarargs = true ,
159+ required = ParameterGroup .RequiredLevel .OPTIONAL ,
126160 ),
127- type,
128- )
129- }.toList()
161+ ),
162+ type,
163+ )
164+ }
165+
166+ private fun makeTypeCheckMethodSignature (
167+ project : Project ,
168+ psiManager : PsiManager ,
169+ context : PsiElement ,
170+ targetClass : ClassNode ,
171+ targetMethod : MethodNode ,
172+ returnType : PsiType ,
173+ ): MethodSignature {
174+ return MethodSignature (
175+ listOf (
176+ ParameterGroup (
177+ listOf (
178+ sanitizedParameter(PsiType .getJavaLangObject(psiManager, context.resolveScope), " instance" ),
179+ sanitizedParameter(getClassType(psiManager, context), " type" ),
180+ )
181+ ),
182+ ParameterGroup (
183+ collectTargetMethodParameters(project, targetClass, targetMethod),
184+ isVarargs = true ,
185+ required = ParameterGroup .RequiredLevel .OPTIONAL ,
186+ ),
187+ ),
188+ returnType,
189+ )
190+ }
191+
192+ private fun getClassType (psiManager : PsiManager , context : PsiElement ): PsiType {
193+ return JavaPsiFacade .getElementFactory(psiManager.project).createTypeFromText(" java.lang.Class<?>" , context)
130194 }
131195
132196 override fun isInsnAllowed (insn : AbstractInsnNode , decorations : Map <String , Any ?>): Boolean {
0 commit comments