Skip to content

Commit

Permalink
Update type and query to improve reentrancy
Browse files Browse the repository at this point in the history
  • Loading branch information
konradweiss committed May 6, 2024
1 parent aa7692b commit e06c88f
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package de.fraunhofer.aisec.cpg.frontends.solidity

import com.google.common.graph.Graphs
import de.fraunhofer.aisec.cpg.TranslationContext
import de.fraunhofer.aisec.cpg.graph.Component
import de.fraunhofer.aisec.cpg.graph.Node
import de.fraunhofer.aisec.cpg.graph.allChildren
import de.fraunhofer.aisec.cpg.graph.declarations.FunctionDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.RecordDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.TranslationUnitDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.VariableDeclaration
import de.fraunhofer.aisec.cpg.graph.newConstructExpression
import de.fraunhofer.aisec.cpg.graph.scopes.GlobalScope
import de.fraunhofer.aisec.cpg.graph.scopes.ValueDeclarationScope
import de.fraunhofer.aisec.cpg.graph.statements.expressions.CallExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.ConstructExpression
import de.fraunhofer.aisec.cpg.graph.types.ObjectType
import de.fraunhofer.aisec.cpg.helpers.SubgraphWalker
import de.fraunhofer.aisec.cpg.passes.*
import de.fraunhofer.aisec.cpg.passes.order.DependsOn
import de.fraunhofer.aisec.cpg.passes.order.ExecuteBefore

@ExecuteBefore(EvaluationOrderGraphPass::class)
@ExecuteBefore(TypeResolver::class)
class ConstructorResolutionPass(ctx: TranslationContext): TranslationUnitPass(ctx) {

override fun accept(result: TranslationUnitDeclaration) {
val all = SubgraphWalker.flattenAST(result)
val records = all.filterIsInstance<RecordDeclaration>()
val calls = all.filterIsInstance<CallExpression>()

calls.forEach {
val call = it
val record = records.filter { it.name.lastPartsMatch(call.name) }.firstOrNull()
record?.let {
call.type = ObjectType(record.name.toString(), listOf(), false, call.language) //TypeHandler.getAddressType(call.language,call.ctx!!)
}
}
}

override fun cleanup() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import de.fraunhofer.aisec.cpg.graph.statements.Statement
import de.fraunhofer.aisec.cpg.graph.statements.expressions.BinaryOperator
import de.fraunhofer.aisec.cpg.graph.statements.expressions.DeclaredReferenceExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.Expression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.InitializerListExpression
import de.fraunhofer.aisec.cpg.graph.types.Type
import de.fraunhofer.aisec.cpg.graph.types.TypeParser
import de.fraunhofer.aisec.cpg.graph.types.UnknownType
Expand Down Expand Up @@ -305,6 +306,14 @@ class DeclarationHandler(lang: SolidityLanguageFrontend) : Handler<Declaration,
}else{
(frontend).let {
it.functionsWithModifiers[ctx] = method
ctx.returnParameters()?.parameterList()?.let {
frontend.functionsWithModifiersAndNamedRet[ctx] = it
}
}
}
method.body?.let {
if(ctx.returnParameters() != null){
method.body = addReturnVariablesToBlock(it, ctx.returnParameters().parameterList())
}
}

Expand All @@ -314,6 +323,57 @@ class DeclarationHandler(lang: SolidityLanguageFrontend) : Handler<Declaration,
return method
}

private fun addReturnVariablesToBlock(block: Statement, returnParams: SolidityParser.ParameterListContext) : CompoundStatement{
val rets = mutableListOf<DeclaredReferenceExpression>()
val retBlock:CompoundStatement

if(!(block is CompoundStatement)){
val cpStmt = newCompoundStatement(block.code)
cpStmt.location = block.location
cpStmt.statements = mutableListOf(block)
retBlock = cpStmt
}else retBlock = block

val declStatement = newDeclarationStatement(frontend.getCodeFromRawNode(returnParams))
declStatement.location = frontend.getLocationFromRawNode(returnParams)

frontend.scopeManager.enterScope(retBlock)
returnParams.parameter().toList().filter { it.identifier() != null }.forEach {
val type = frontend.typeHandler.handle(it.typeName())?:newUnknownType()
val name = it.identifier().text.trim()

val varDecl = newVariableDeclaration(name, type, frontend.getCodeFromRawNode(it), false)
varDecl.location = frontend.getLocationFromRawNode(it)

frontend.scopeManager.addDeclaration(varDecl)
declStatement.addToPropertyEdgeDeclaration(varDecl)

rets.add(newDeclaredReferenceExpression(name,
type,
frontend.getCodeFromRawNode(it)))
}
if(rets.size == returnParams.parameter().toList().size){
val retStmt = newReturnStatement(frontend.getCodeFromRawNode(returnParams), returnParams)
retStmt.location = frontend.getLocationFromRawNode(returnParams)
if(rets.size == 1){
retStmt.returnValue = rets.get(0)
}else{
val retLists = newInitializerListExpression(frontend.getCodeFromRawNode(ctx), ctx)
retLists.location = frontend.getLocationFromRawNode(ctx)
retLists.initializers = rets
retStmt.returnValue = retLists
}

val tmp = retBlock.statements.toMutableList()
tmp.add(0, declStatement)
tmp.add(retStmt)
retBlock.statements = tmp
}

frontend.scopeManager.leaveScope(retBlock)
return retBlock
}

private fun handleMissingFunctionDefinition(filename: String, unit: SolidityParser.SourceUnitContext): MethodDeclaration {

val record = frontend.scopeManager.currentRecord
Expand Down Expand Up @@ -393,13 +453,20 @@ class DeclarationHandler(lang: SolidityLanguageFrontend) : Handler<Declaration,
val modifier = frontend.modifierStack.pop()
frontend.currentIdentifierMapStack.push(frontend.modifierIdentifierMap[modifier])
method.body = expandBodyWithModifiers(modifier)

// To consider named returns
method.body?.let {
val body = it
frontend.functionsWithModifiersAndNamedRet[ctx]?.let {
method.body = addReturnVariablesToBlock(body, it)
}
}
SubgraphWalker.flattenAST(method.body).forEach { it.comment = it.comment?: "" + methodId }
methodId++

frontend.modifierStack.push(modifier)
frontend.currentIdentifierMapStack.pop()
frontend.scopeManager.leaveScope(method)

}

public fun expandBodyWithModifiers(modifierOrFunction: ParserRuleContext): Statement {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,26 @@ class ExpressionHandler(lang: SolidityLanguageFrontend) : Handler<Expression, Pa
if(op != null && op.text.trim() == ".") {
val base = this.handle(expressions.first())
base?.let {
return newMemberExpression(ctx.text.trim(), base, UnknownType.getUnknownType(language), operatorCode = op.text.trim(), code = frontend.getCodeFromRawNode(ctx), rawNode = ctx)
val name = ctx.identifier()?.text?.trim()?:ctx.text.trim()
val type = if(base.type.name.toString().equals("address") && TypeHandler.addressExtenderMembers.contains(name)){
TypeHandler.getAddressType(language,frontend.ctx)
}else{
TypeHandler.getPredefinedTypes(ctx.text.trim(), language, frontend.ctx)
}
val memberExpression = newMemberExpression(name, base, type, operatorCode = op.text.trim(), code = frontend.getCodeFromRawNode(ctx), rawNode = ctx)
if(TypeHandler.addressExtenderMembers.contains(name) && base.type.equals(UnknownType.getUnknownType(language))){
base.registerTypeListener(object : HasType.TypeListener{
override fun typeChanged(src: HasType, root: MutableList<HasType>, oldType: Type) {
if(src.type.name.toString().equals("address")){
memberExpression.type = TypeHandler.getAddressType(language,frontend.ctx)
}
}

override fun possibleSubTypesChanged(src: HasType, root: MutableList<HasType>) { }
}
)
}
return memberExpression
}

}
Expand Down Expand Up @@ -311,8 +330,10 @@ class ExpressionHandler(lang: SolidityLanguageFrontend) : Handler<Expression, Pa
replacingReferenceWithExpression = false
}

val type = TypeHandler.getPredefinedTypes(name, language, frontend.ctx)

val ref = newDeclaredReferenceExpression(name,
newUnknownType(),
type,
frontend.getCodeFromRawNode(ctx))
return ref
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import de.fraunhofer.aisec.cpg.graph.declarations.*
import de.fraunhofer.aisec.cpg.graph.types.Type
import de.fraunhofer.aisec.cpg.graph.types.TypeParser
import de.fraunhofer.aisec.cpg.passes.EvaluationOrderGraphPass
import de.fraunhofer.aisec.cpg.passes.order.RegisterExtraPass
import de.fraunhofer.aisec.cpg.passes.order.ReplacePass
import de.fraunhofer.aisec.cpg.sarif.PhysicalLocation
import de.fraunhofer.aisec.cpg.sarif.Region
Expand All @@ -32,6 +33,7 @@ import java.io.File
import java.io.FileInputStream
import java.util.*
@ReplacePass(EvaluationOrderGraphPass::class, SolidityLanguage::class, EOGExtensionPass::class)
@RegisterExtraPass(ConstructorResolutionPass::class)
class SolidityLanguageFrontend(language: Language<SolidityLanguageFrontend>, ctx: TranslationContext) : LanguageFrontend(language, ctx) {

private val logger = LoggerFactory.getLogger(SolidityLanguageFrontend.javaClass)
Expand All @@ -45,6 +47,7 @@ class SolidityLanguageFrontend(language: Language<SolidityLanguageFrontend>, ctx
var modifierStack: Stack<ParserRuleContext> = Stack<ParserRuleContext>()

val functionsWithModifiers: MutableMap<SolidityParser.FunctionDefinitionContext, MethodDeclaration> = mutableMapOf()
val functionsWithModifiersAndNamedRet: MutableMap<SolidityParser.FunctionDefinitionContext, SolidityParser.ParameterListContext> = mutableMapOf()
val modifierMap: MutableMap<ModifierDefinition, SolidityParser.ModifierDefinitionContext> = mutableMapOf()
val modifierIdentifierMap: MutableMap<SolidityParser.ModifierDefinitionContext, MutableMap<String, SolidityParser.ExpressionContext>> = mutableMapOf()
var currentIdentifierMapStack: Stack<MutableMap<String, SolidityParser.ExpressionContext>> = Stack<MutableMap<String, SolidityParser.ExpressionContext>>()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package de.fraunhofer.aisec.cpg.frontends.solidity

import de.fraunhofer.aisec.cpg.TranslationContext
import de.fraunhofer.aisec.cpg.frontends.Handler
import de.fraunhofer.aisec.cpg.frontends.Language
import de.fraunhofer.aisec.cpg.frontends.LanguageFrontend
import de.fraunhofer.aisec.cpg.graph.declarations.Declaration
import de.fraunhofer.aisec.cpg.graph.newUnknownType
import de.fraunhofer.aisec.cpg.graph.types.*
import org.antlr.v4.runtime.ParserRuleContext
import org.antlr.v4.runtime.tree.TerminalNode
import org.slf4j.LoggerFactory
import java.util.concurrent.ConcurrentHashMap

class TypeHandler(lang: SolidityLanguageFrontend) : Handler<Type, ParserRuleContext, SolidityLanguageFrontend>({ UnknownType.getUnknownType(language = lang.language) }, lang) {

Expand Down Expand Up @@ -85,12 +89,57 @@ class TypeHandler(lang: SolidityLanguageFrontend) : Handler<Type, ParserRuleCont
return TypeParser.createFrom(it.text.trim(), language, false, frontend.ctx)
}
if(ctx.getStart().equals("address")) {
return TypeParser.createFrom("address", language, false, frontend.ctx)
return getAddressType(language,frontend.ctx)
}


logger.warn("Empty type name could not be translated properly")

return newUnknownType()
}
companion object {
/** A map of [UnknownType] and their respective [Language]. */
private val addressTypes = ConcurrentHashMap<Language<*>?, Type>()

var memberTypeMap = mapOf<String,Type>()
var addressExtenderMembers = listOf("call", "value", "gas")
@JvmStatic
fun getPredefinedTypes(name: String, language: Language<out LanguageFrontend>?, transContext: TranslationContext): Type {
if(memberTypeMap.isEmpty()){
memberTypeMap = mapOf(
// "blockhash(uint blockNumber)" returns (bytes32)
// "blobhash(uint index)" returns (bytes32)
"block.basefee" to IntegerType("uint256", modifier = NumericType.Modifier.UNSIGNED, bitWidth = 256),
"block.blobbasefee" to IntegerType("uint256", modifier = NumericType.Modifier.UNSIGNED, bitWidth = 256),
"block.chainid" to IntegerType("uint256", modifier = NumericType.Modifier.UNSIGNED, bitWidth = 256),
"block.coinbase" to getAddressType(language,transContext),
"block.difficulty" to IntegerType("uint256", modifier = NumericType.Modifier.UNSIGNED, bitWidth = 256),
"block.gaslimit" to IntegerType("uint256", modifier = NumericType.Modifier.UNSIGNED, bitWidth = 256),
"block.number" to IntegerType("uint256", modifier = NumericType.Modifier.UNSIGNED, bitWidth = 256),
"block.prevrandao" to IntegerType("uint256", modifier = NumericType.Modifier.UNSIGNED, bitWidth = 256),
"block.timestamp" to IntegerType("uint256", modifier = NumericType.Modifier.UNSIGNED, bitWidth = 256),
// "gasleft()" returns (uint256)
"msg.data" to NumericType("bytes", modifier = NumericType.Modifier.UNSIGNED),
"msg.sender" to getAddressType(language,transContext),
"msg.sig" to NumericType("bytes4", modifier = NumericType.Modifier.UNSIGNED, bitWidth = 4 * 8),
"msg.value" to IntegerType("uint256", modifier = NumericType.Modifier.UNSIGNED, bitWidth = 256),
"tx.gasprice" to IntegerType("uint256", modifier = NumericType.Modifier.UNSIGNED, bitWidth = 256),
"tx.origin" to getAddressType(language,transContext))
}
return memberTypeMap.getOrDefault(name, UnknownType.getUnknownType(language))
}


/** Use this function to obtain an [UnknownType] for the particular [language]. */
@JvmStatic
fun getAddressType(language: Language<out LanguageFrontend>?, transContext: TranslationContext): Type {

return addressTypes.computeIfAbsent(language) {
val addressType = TypeParser.createFrom("address", language, false, transContext)
addressType.language = language
addressType
}
}
}

}
19 changes: 14 additions & 5 deletions cpg-solidity/src/main/resources/Reentrancy
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
match p=(base:MemberExpression)-[:BASE|CALLEE]-(c:CallExpression)-[e:EOG|INVOKES|RETURNS*]->(n)
where not exists {(c)<--(em:EmitStatement)}
and not exists{
()-[r:RETURNS]->()-[i:INVOKES]->()
where r in relationships(p) and apoc.coll.indexOf(relationships(p), r) + 1 = apoc.coll.indexOf(relationships(p), i)
}
and (
exists{
(n)-[d:DFG*#]->(:FieldDeclaration)
(n)-[d:DFG*#]->(field:FieldDeclaration)
where exists ((field)<-[:FIELDS]-(:RecordDeclaration)-[:AST*]->(c))
} or exists {
(n)-[d:DFG*#]->(bin:BinaryOperator)-[:LHS]->()-[:BASE|CALLEE|LHS|ARRAY_EXPRESSION*]->()<-[:DFG*#]-(:FieldDeclaration)
(n)-[d:DFG*#]->(bin:BinaryOperator)-[:LHS]->()-[:BASE|CALLEE|LHS|ARRAY_EXPRESSION*]->()<-[:DFG*#]-(field:FieldDeclaration)
where bin.operatorCode in ['=', '|=', '^=', '&=', '<<=','>>=','+=', '-=', '*=', '/=', '%=']
and exists ((field)<-[:FIELDS]-(:RecordDeclaration)-[:AST*]->(c))
}
or exists {
(n)-[d:DFG*#]->(bin:UnaryOperator)-[:INPUT|BASE|CALLEE|LHS|ARRAY_EXPRESSION]->()<-[:DFG*#]-(:FieldDeclaration)
(n)-[d:DFG*#]->(bin:UnaryOperator)-[:INPUT|BASE|CALLEE|LHS|ARRAY_EXPRESSION]->()<-[:DFG*#]-(field:FieldDeclaration)
where bin.operatorCode in ['++','--']
and exists ((field)<-[:FIELDS]-(:RecordDeclaration)-[:AST*]->(c))
}
)
and(not exists {()-[:DFG]->(b1)<-[:BASE|CALLEE*]-(c)}
or exists {
dflow=(s)-[:DFG*#]->(b2)<-[:BASE|CALLEE*]-(c)
where not exists (()-[:DFG]->(s)) and not 'Literal' in labels(s) and not exists((s)<-[:PARAMETERS]-(:ConstructorDeclaration)) and (not s.isInferred or s.localName in ['msg', 'tx'] )
dflow=(s)-[:DFG*#]->(b2)<-[:BASE]-(callee)<-[:CALLEE]-(c)
where
(exists((b2)-[:TYPE]->({name: "address"})) or exists((b2)-[:TYPE]->(:ObjectType)-[:RECORD_DECLARATION]->()))
and not exists (()-[:DFG]->(s)) and not 'Literal' in labels(s) and not exists((s)<-[:PARAMETERS]-(:ConstructorDeclaration)) and (not s.isInferred or s.localName in ['msg', 'tx'] )
and not exists{(sub)-[:DFG]->(array)-[:SUBSCRIPT_EXPRESSION]->(sub) where sub in nodes(dflow) and array in nodes(dflow)}
}) and (
exists{((d:DeclaredReferenceExpression)-[:DFG*#]->(base)) where d.code in ['msg.sender', 'tx.origin']}
Expand Down

0 comments on commit e06c88f

Please sign in to comment.