Skip to content

Commit

Permalink
Add check to prevent injecting assisted factories for HiltViewModels
Browse files Browse the repository at this point in the history
RELNOTES=Add check to prevent injecting assisted factories for HiltViewModels
PiperOrigin-RevId: 568868720
  • Loading branch information
kuanyingchou authored and Dagger Team committed Oct 10, 2023
1 parent 8327177 commit 84c034f
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ kt_jvm_library(
"//:spi",
"//java/dagger/hilt/android/processor/internal:android_classnames",
"//java/dagger/hilt/processor/internal:dagger_models",
"//java/dagger/internal/codegen/xprocessing",
"//third_party/java/auto:service",
"//third_party/java/guava/graph",
"//third_party/java/javapoet",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,45 @@
* limitations under the License.
*/

@file:OptIn(ExperimentalProcessingApi::class)

package dagger.hilt.android.processor.internal.viewmodel

import androidx.room.compiler.processing.ExperimentalProcessingApi
import androidx.room.compiler.processing.XMethodElement
import androidx.room.compiler.processing.XProcessingEnv
import androidx.room.compiler.processing.XProcessingEnv.Companion.create
import androidx.room.compiler.processing.XType
import androidx.room.compiler.processing.XTypeElement
import androidx.room.compiler.processing.compat.XConverters.toXProcessing
import com.google.auto.service.AutoService
import com.google.common.graph.EndpointPair
import com.google.common.graph.ImmutableNetwork
import dagger.hilt.android.processor.internal.AndroidClassNames
import dagger.hilt.processor.internal.getQualifiedName
import dagger.hilt.processor.internal.hasAnnotation
import dagger.internal.codegen.xprocessing.XTypeElements
import dagger.spi.model.Binding
import dagger.spi.model.BindingGraph
import dagger.spi.model.BindingGraph.Edge
import dagger.spi.model.BindingGraph.Node
import dagger.spi.model.BindingGraphPlugin
import dagger.spi.model.BindingKind
import dagger.spi.model.DaggerProcessingEnv
import dagger.spi.model.DaggerType
import dagger.spi.model.DiagnosticReporter
import javax.tools.Diagnostic.Kind

/** Plugin to validate users do not inject @HiltViewModel classes. */
@AutoService(BindingGraphPlugin::class)
class ViewModelValidationPlugin : BindingGraphPlugin {

private lateinit var env: XProcessingEnv

override fun init(processingEnv: DaggerProcessingEnv, options: MutableMap<String, String>) {
env = processingEnv.toXProcessingEnv()
}

override fun visitGraph(bindingGraph: BindingGraph, diagnosticReporter: DiagnosticReporter) {
if (bindingGraph.rootComponentNode().isSubcomponent()) {
// This check does not work with partial graphs since it needs to take into account the source
Expand All @@ -46,9 +65,10 @@ class ViewModelValidationPlugin : BindingGraphPlugin {
val pair: EndpointPair<Node> = network.incidentNodes(edge)
val target: Node = pair.target()
val source: Node = pair.source()
if (
target is Binding && isHiltViewModelBinding(target) && !isInternalHiltViewModelUsage(source)
) {
if (target !is Binding) {
return@forEach
}
if (isHiltViewModelBinding(target) && !isInternalHiltViewModelUsage(source)) {
diagnosticReporter.reportDependency(
Kind.ERROR,
edge,
Expand All @@ -57,6 +77,17 @@ class ViewModelValidationPlugin : BindingGraphPlugin {
"(e.g. ViewModelProvider) instead." +
"\nInjected ViewModel: ${target.key().type()}\n"
)
} else if (
isViewModelAssistedFactory(target) && !isInternalViewModelAssistedFactoryUsage(source)
) {
diagnosticReporter.reportDependency(
Kind.ERROR,
edge,
"\nInjection of an assisted factory for Hilt ViewModel is prohibited since it " +
"can not be used to create a ViewModel instance correctly.\nAccess the ViewModel via " +
"the Android APIs (e.g. ViewModelProvider) instead." +
"\nInjected factory: ${target.key().type()}\n"
)
}
}
}
Expand Down Expand Up @@ -84,4 +115,54 @@ class ViewModelValidationPlugin : BindingGraphPlugin {
AndroidClassNames.HILT_VIEW_MODEL_MAP_QUALIFIER.canonicalName() &&
source.key().multibindingContributionIdentifier().isPresent()
}

private fun isViewModelAssistedFactory(target: Binding): Boolean {
if (target.kind() != BindingKind.ASSISTED_FACTORY) return false
val factoryType = target.key().type()
return getAssistedInjectTypeElement(factoryType.toXType(env)).hasAnnotation(AndroidClassNames.HILT_VIEW_MODEL)
}

private fun getAssistedInjectTypeElement(factoryType: XType): XTypeElement =
getAssistedFactoryMethods(factoryType.typeElement)
.single()
.asMemberOf(factoryType)
.returnType
.typeElement!!

private fun getAssistedFactoryMethods(factory: XTypeElement?): List<XMethodElement> {
return XTypeElements.getAllNonPrivateInstanceMethods(factory)
.filter { it.isAbstract() }
.filter { !it.isJavaDefault() }
}

private fun isInternalViewModelAssistedFactoryUsage(source: Node): Boolean {
// We expect the only usage of the assisted factory for a Hilt ViewModel is in the
// code we generate:
// @Binds
// @IntoMap
// @StringKey(...)
// @HiltViewModelAssistedMap
// public abstract Object bind(FooFactory factory);
return source is Binding &&
source.key().qualifier().isPresent() &&
source.key().qualifier().get().getQualifiedName() ==
AndroidClassNames.HILT_VIEW_MODEL_ASSISTED_FACTORY_MAP_QUALIFIER.canonicalName() &&
source.key().multibindingContributionIdentifier().isPresent()
}
}

private fun DaggerType.toXType(processingEnv: XProcessingEnv): XType {
return when (backend()) {
DaggerProcessingEnv.Backend.JAVAC -> javac().toXProcessing(processingEnv)
DaggerProcessingEnv.Backend.KSP -> ksp().toXProcessing(processingEnv)
else -> error("Backend ${ backend() } not supported yet.")
}
}

private fun DaggerProcessingEnv.toXProcessingEnv(): XProcessingEnv {
return when (backend()) {
DaggerProcessingEnv.Backend.JAVAC -> create(javac())
DaggerProcessingEnv.Backend.KSP -> create(ksp(), resolver())
else -> error("Backend ${ backend() } not supported yet.")
}
}
32 changes: 32 additions & 0 deletions javatests/dagger/hilt/android/processor/internal/viewmodel/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,38 @@ kt_compiler_test(
],
)

kt_compiler_test(
name = "ViewModelValidationPluginWithAssistedInjectTest",
srcs = [
"ViewModelValidationPluginWithAssistedInjectTest.kt",
],
compiler_deps = [
"@androidsdk//:platforms/android-32/android.jar",
"@maven//:androidx_lifecycle_lifecycle_viewmodel",
"@maven//:androidx_lifecycle_lifecycle_viewmodel_savedstate",
"//third_party/java/compile_testing",
"//third_party/java/truth",
"//java/dagger/hilt/android/lifecycle:hilt_view_model",
"//java/dagger/hilt/android:android_entry_point",
"//java/dagger/hilt/android:hilt_android_app",
],
resources = glob(["goldens/*"]),
deps = [
":test_utils",
"//:compiler_internals",
"//java/dagger/hilt/android/processor/internal/viewmodel:processor_lib",
"//java/dagger/hilt/android/processor/internal/viewmodel:validation_plugin_lib",
"//java/dagger/hilt/android/testing/compile",
"//java/dagger/internal/codegen/xprocessing",
"//java/dagger/internal/codegen/xprocessing:xprocessing-testing",
"//java/dagger/testing/golden",
"//third_party/java/compile_testing",
"//third_party/java/guava/collect",
"//third_party/java/junit",
"//third_party/java/truth",
],
)

kt_jvm_library(
name = "test_utils",
srcs = [
Expand Down
Loading

0 comments on commit 84c034f

Please sign in to comment.