diff --git a/kotlin-checks-test-sources/src/main/kotlin/checks/EqualsArgumentTypeCheckSample.kt b/kotlin-checks-test-sources/src/main/kotlin/checks/EqualsArgumentTypeCheckSample.kt index 830f6fb2b..2d57b47b5 100644 --- a/kotlin-checks-test-sources/src/main/kotlin/checks/EqualsArgumentTypeCheckSample.kt +++ b/kotlin-checks-test-sources/src/main/kotlin/checks/EqualsArgumentTypeCheckSample.kt @@ -233,4 +233,59 @@ abstract class EqualsArgumentTypeCheckSample { } } + class MyClass9 { + override fun equals(other: Any?): Boolean { // compliant + when(other){ + is MyClass9 -> return true + else -> return false + } + } + } + + class MyClass12 { + override fun equals(other: Any?): Boolean { // compliant + when(other){ + !is MyClass12 -> return false + else -> return true + } + } + } + + class MyClass10 { + override fun equals(other: Any?): Boolean { // Noncompliant + when(other){ + !is MyClass9 -> return false + else -> return true + } + } + } + + class MyClass11 { + override fun equals(other: Any?): Boolean { // compliant + when{ + other is MyClass11 -> return true + else -> return false + } + } + } + + class MyClass13 { + override fun equals(other: Any?): Boolean { // Noncompliant + val me = MyClass13() + when(me){ + is MyClass13 -> return true + else -> return false + } + } + } + + class MyClass14() { + override fun equals(other: Any?): Boolean { // Noncompliant + val me = MyClass14() + when{ + me is MyClass14 -> return true + else -> return false + } + } + } } diff --git a/sonar-kotlin-checks/src/main/java/org/sonarsource/kotlin/checks/EqualsArgumentTypeCheck.kt b/sonar-kotlin-checks/src/main/java/org/sonarsource/kotlin/checks/EqualsArgumentTypeCheck.kt index 7f7444491..e09019096 100644 --- a/sonar-kotlin-checks/src/main/java/org/sonarsource/kotlin/checks/EqualsArgumentTypeCheck.kt +++ b/sonar-kotlin-checks/src/main/java/org/sonarsource/kotlin/checks/EqualsArgumentTypeCheck.kt @@ -30,8 +30,12 @@ import org.jetbrains.kotlin.psi.KtNameReferenceExpression import org.jetbrains.kotlin.psi.KtNamedFunction import org.jetbrains.kotlin.psi.KtParameter import org.jetbrains.kotlin.psi.KtSafeQualifiedExpression +import org.jetbrains.kotlin.psi.KtTypeReference +import org.jetbrains.kotlin.psi.KtWhenConditionIsPattern +import org.jetbrains.kotlin.psi.KtWhenExpression import org.jetbrains.kotlin.psi.psiUtil.collectDescendantsOfType import org.jetbrains.kotlin.psi.psiUtil.containingClass +import org.jetbrains.kotlin.resolve.BindingContext import org.sonar.check.Rule import org.sonarsource.kotlin.api.checks.ANY_TYPE import org.sonarsource.kotlin.api.checks.AbstractCheck @@ -56,27 +60,65 @@ class EqualsArgumentTypeCheck : AbstractCheck() { val klass = function.containingClass() ?: return val parameter = function.valueParameters.first() - val parentNames = klass.superTypeListEntries.mapNotNull { it.typeReference!!.nameForReceiverLabel() } - - if (function.collectDescendantsOfType { parameter.name == (it.leftHandSide as? KtNameReferenceExpression)?.getReferencedName() } - .none { - // typeReference is always present - val name = it.typeReference!!.nameForReceiverLabel() - klass.name == name || parentNames.contains(name) || - it.typeReference!!.determineType(bindingContext) - ?.let { type -> klass.determineType(bindingContext)?.isSupertypeOf(type) } == true - } && - function.collectDescendantsOfType { it.operationToken == KtTokens.EQEQ || it.operationToken == KtTokens.EXCLEQ } - .none { binaryExpression -> isBinaryExpressionCorrect(binaryExpression, parameter, klass) } && - function.collectDescendantsOfType { it.operationReference.getReferencedName() == "as?" } - .none { binaryExpression -> isBinaryExpressionWithTypeCorrect(binaryExpression, parameter, klass) } + if (checkIsExpression(function, parameter, klass, bindingContext) && + checkWhenExpression(function, parameter, klass, bindingContext) && + checkBinaryExpression(function, parameter, klass) && + checkBinaryExpressionRHS(function, parameter, klass) ) { ctx.reportIssue(function.nameIdentifier!!, "Add a type test to this method.") } } + private fun checkBinaryExpressionRHS( + function: KtNamedFunction, + parameter: KtParameter, + klass: KtClass + ) = function.collectDescendantsOfType { it.operationReference.getReferencedName() == "as?" } + .none { binaryExpression -> isBinaryExpressionWithTypeCorrect(binaryExpression, parameter, klass) } + + private fun checkBinaryExpression( + function: KtNamedFunction, + parameter: KtParameter, + klass: KtClass + ) = function.collectDescendantsOfType { it.operationToken == KtTokens.EQEQ || it.operationToken == KtTokens.EXCLEQ } + .none { binaryExpression -> isBinaryExpressionCorrect(binaryExpression, parameter, klass) } + + private fun checkWhenExpression( + function: KtNamedFunction, + parameter: KtParameter, + klass: KtClass, + bindingContext: BindingContext + ) = + function.collectDescendantsOfType { parameter.name == (it.subjectExpression as? KtNameReferenceExpression)?.getReferencedName() } + .none { + it.collectDescendantsOfType().any { whenConditionIsPattern -> + // typeReference is always present + isExpressionCorrectType(whenConditionIsPattern.typeReference!!, klass, bindingContext) + } + } + + private fun checkIsExpression( + function: KtNamedFunction, + parameter: KtParameter, + klass: KtClass, + bindingContext: BindingContext + ) = + function.collectDescendantsOfType { parameter.name == (it.leftHandSide as? KtNameReferenceExpression)?.getReferencedName() } + .none { + // typeReference is always present + isExpressionCorrectType(it.typeReference!!, klass, bindingContext) + } + + private fun isExpressionCorrectType(typeReference: KtTypeReference, klass: KtClass, bindingContext: BindingContext): Boolean { + val name = typeReference.nameForReceiverLabel() + val parentNames = klass.superTypeListEntries.mapNotNull { it.typeReference!!.nameForReceiverLabel() } + return klass.name == name || parentNames.contains(name) || + typeReference.determineType(bindingContext) + ?.let { type -> klass.determineType(bindingContext)?.isSupertypeOf(type) } == true + } + private fun isBinaryExpressionWithTypeCorrect( binaryExpression: KtBinaryExpressionWithTypeRHS, parameter: KtParameter,