diff --git a/svf-llvm/include/SVF-LLVM/LLVMUtil.h b/svf-llvm/include/SVF-LLVM/LLVMUtil.h index 924839b2f..5348a8f9b 100644 --- a/svf-llvm/include/SVF-LLVM/LLVMUtil.h +++ b/svf-llvm/include/SVF-LLVM/LLVMUtil.h @@ -496,7 +496,7 @@ const Argument* getConstructorThisPtr(const Function* fun); const Value* getVCallThisPtr(const CallBase* cs); const Value* getVCallVtblPtr(const CallBase* cs); s32_t getVCallIdx(const CallBase* cs); -std::string getClassNameFromType(const Type* ty); +std::string getClassNameFromType(const StructType* ty); std::string getClassNameOfThisPtr(const CallBase* cs); std::string getFunNameOfVCallSite(const CallBase* cs); bool VCallInCtorOrDtor(const CallBase* cs); @@ -513,13 +513,10 @@ bool VCallInCtorOrDtor(const CallBase* cs); bool isSameThisPtrInConstructor(const Argument* thisPtr1, const Value* thisPtr2); -template -std::string llvmToString(const T& val) -{ - std::string str; - llvm::raw_string_ostream(str) << val; - return str; -} +std::string dumpValue(const Value* val); + +std::string dumpType(const Type* type); + /** * See more: https://github.com/SVF-tools/SVF/pull/1191 diff --git a/svf-llvm/include/SVF-LLVM/SymbolTableBuilder.h b/svf-llvm/include/SVF-LLVM/SymbolTableBuilder.h index d8d72912c..e013da5b6 100644 --- a/svf-llvm/include/SVF-LLVM/SymbolTableBuilder.h +++ b/svf-llvm/include/SVF-LLVM/SymbolTableBuilder.h @@ -90,7 +90,7 @@ class SymbolTableBuilder /// Analyse types of all flattened fields of this object void analyzeObjType(ObjTypeInfo* typeinfo, const Value* val); /// Analyse types of heap and static objects - void analyzeHeapObjType(ObjTypeInfo* typeinfo, const Value* val); + u32_t analyzeHeapObjType(ObjTypeInfo* typeinfo, const Value* val); /// Analyse types of heap and static objects void analyzeStaticObjType(ObjTypeInfo* typeinfo, const Value* val); diff --git a/svf-llvm/lib/LLVMUtil.cpp b/svf-llvm/lib/LLVMUtil.cpp index 67e7d34e3..cfc331f5f 100644 --- a/svf-llvm/lib/LLVMUtil.cpp +++ b/svf-llvm/lib/LLVMUtil.cpp @@ -923,7 +923,12 @@ bool LLVMUtil::isLoadVtblInst(const LoadInst* loadInst) if (const FunctionType* functy = SVFUtil::dyn_cast(elemTy)) { const Type* paramty = functy->getParamType(0); - std::string className = LLVMUtil::getClassNameFromType(paramty); + std::string className = ""; + if(const PointerType* ptrTy = SVFUtil::dyn_cast(paramty)) + { + if(const StructType* st = SVFUtil::dyn_cast(getPtrElementType(ptrTy))) + className = LLVMUtil::getClassNameFromType(st); + } if (className.size() > 0) { return true; @@ -1156,25 +1161,19 @@ bool LLVMUtil::VCallInCtorOrDtor(const CallBase* cs) return false; } -std::string LLVMUtil::getClassNameFromType(const Type* ty) +std::string LLVMUtil::getClassNameFromType(const StructType* ty) { std::string className = ""; - if (const PointerType* ptrType = SVFUtil::dyn_cast(ty)) + if (!((SVFUtil::cast(ty))->isLiteral())) { - const Type* elemType = LLVMUtil::getPtrElementType(ptrType); - if (SVFUtil::isa(elemType) && - !((SVFUtil::cast(elemType))->isLiteral())) + std::string elemTypeName = ty->getStructName().str(); + if (elemTypeName.compare(0, clsName.size(), clsName) == 0) { - std::string elemTypeName = elemType->getStructName().str(); - if (elemTypeName.compare(0, clsName.size(), clsName) == 0) - { - className = elemTypeName.substr(clsName.size()); - } - else if (elemTypeName.compare(0, structName.size(), structName) == - 0) - { - className = elemTypeName.substr(structName.size()); - } + className = elemTypeName.substr(clsName.size()); + } + else if (elemTypeName.compare(0, structName.size(), structName) == 0) + { + className = elemTypeName.substr(structName.size()); } } return className; @@ -1182,7 +1181,7 @@ std::string LLVMUtil::getClassNameFromType(const Type* ty) std::string LLVMUtil::getClassNameOfThisPtr(const CallBase* inst) { - std::string thisPtrClassName; + std::string thisPtrClassName = ""; if (const MDNode* N = inst->getMetadata("VCallPtrType")) { const MDString* mdstr = SVFUtil::cast(N->getOperand(0).get()); @@ -1191,7 +1190,9 @@ std::string LLVMUtil::getClassNameOfThisPtr(const CallBase* inst) if (thisPtrClassName.size() == 0) { const Value* thisPtr = LLVMUtil::getVCallThisPtr(inst); - thisPtrClassName = getClassNameFromType(thisPtr->getType()); + if(const PointerType* ptrTy = SVFUtil::dyn_cast(thisPtr->getType())) + if(const StructType* st = SVFUtil::dyn_cast(getPtrElementType(ptrTy))) + thisPtrClassName = getClassNameFromType(st); } size_t found = thisPtrClassName.find_last_not_of("0123456789"); @@ -1278,6 +1279,29 @@ s64_t LLVMUtil::getCaseValue(const SwitchInst &switchInst, SuccBBAndCondValPair return val; } +std::string LLVMUtil::dumpValue(const Value* val) +{ + std::string str; + llvm::raw_string_ostream rawstr(str); + if (val) + rawstr << " " << *val << " "; + else + rawstr << " llvm Value is null"; + return rawstr.str(); +} + +std::string LLVMUtil::dumpType(const Type* type) +{ + std::string str; + llvm::raw_string_ostream rawstr(str); + if (type) + rawstr << " " << *type << " "; + else + rawstr << " llvm type is null"; + return rawstr.str(); +} + + namespace SVF { diff --git a/svf-llvm/lib/SymbolTableBuilder.cpp b/svf-llvm/lib/SymbolTableBuilder.cpp index a8b2b1cd7..7c0130bf4 100644 --- a/svf-llvm/lib/SymbolTableBuilder.cpp +++ b/svf-llvm/lib/SymbolTableBuilder.cpp @@ -753,20 +753,32 @@ u32_t SymbolTableBuilder::analyzeHeapAllocByteSize(const Value* val) /*! * Analyse types of heap and static objects */ -void SymbolTableBuilder::analyzeHeapObjType(ObjTypeInfo* typeinfo, const Value* val) +u32_t SymbolTableBuilder::analyzeHeapObjType(ObjTypeInfo* typeinfo, const Value* val) { if(const Value* castUse = getUniqueUseViaCastInst(val)) { typeinfo->setFlag(ObjTypeInfo::HEAP_OBJ); - typeinfo->resetTypeForHeapStaticObj( - LLVMModuleSet::getLLVMModuleSet()->getSVFType(castUse->getType())); + const Type* objTy = getTypeOfHeapAlloc(SVFUtil::cast(val)); + typeinfo->resetTypeForHeapStaticObj(LLVMModuleSet::getLLVMModuleSet()->getSVFType(objTy)); analyzeObjType(typeinfo,castUse); + if(SVFUtil::isa(objTy)) + return getNumOfElements(objTy); + else if(const StructType* st = SVFUtil::dyn_cast(objTy)) + { + /// For an C++ class, it can have variant elements depending on the vtable size, + /// Hence we only handle non-cpp-class object, the type of the cpp class is treated as PointerType at the cast site + if(getClassNameFromType(st).empty()) + return getNumOfElements(objTy); + else + typeinfo->resetTypeForHeapStaticObj(LLVMModuleSet::getLLVMModuleSet()->getSVFType(castUse->getType())); + } } else { typeinfo->setFlag(ObjTypeInfo::HEAP_OBJ); typeinfo->setFlag(ObjTypeInfo::HASPTR_OBJ); } + return typeinfo->getMaxFieldOffsetLimit(); } /*! @@ -777,8 +789,7 @@ void SymbolTableBuilder::analyzeStaticObjType(ObjTypeInfo* typeinfo, const Value if(const Value* castUse = getUniqueUseViaCastInst(val)) { typeinfo->setFlag(ObjTypeInfo::STATIC_OBJ); - typeinfo->resetTypeForHeapStaticObj( - LLVMModuleSet::getLLVMModuleSet()->getSVFType(castUse->getType())); + typeinfo->resetTypeForHeapStaticObj(LLVMModuleSet::getLLVMModuleSet()->getSVFType(castUse->getType())); analyzeObjType(typeinfo,castUse); } else @@ -844,9 +855,7 @@ void SymbolTableBuilder::initTypeInfo(ObjTypeInfo* typeinfo, const Value* val, LLVMModuleSet::getLLVMModuleSet()->getSVFInstruction( SVFUtil::cast(val)))) { - analyzeHeapObjType(typeinfo,val); - // Heap object, label its field as infinite here - elemNum = typeinfo->getMaxFieldOffsetLimit(); + elemNum = analyzeHeapObjType(typeinfo,val); // analyze heap alloc like (malloc/calloc/...), the alloc functions have // annotation like "AllocSize:Arg1". Please refer to extapi.c. // e.g. calloc(4, 10), annotation is "AllocSize:Arg0*Arg1", diff --git a/svf/lib/SVFIR/SymbolTableInfo.cpp b/svf/lib/SVFIR/SymbolTableInfo.cpp index 63394eaa2..186654a0a 100644 --- a/svf/lib/SVFIR/SymbolTableInfo.cpp +++ b/svf/lib/SVFIR/SymbolTableInfo.cpp @@ -373,7 +373,7 @@ bool ObjTypeInfo::isNonPtrFieldObj(const APOffset& apOffset) if (hasPtrObj() == false) return true; - const SVFType* ety = getType(); + const SVFType* ety = type; if (SVFUtil::isa(ety)) {