Skip to content

Commit

Permalink
Reapply "[Clang][Sema] Always use latest redeclaration of primary tem…
Browse files Browse the repository at this point in the history
…plate" (llvm#114569)

This patch reapplies llvm#114258, fixing an infinite recursion bug in
`ASTImporter` that occurs when importing the primary template of a class
template specialization when the latest redeclaration of that template
is a friend declaration in the primary template.
  • Loading branch information
sdkrystian authored and NoumanAmir657 committed Nov 4, 2024
1 parent 8d0e660 commit 2e1b618
Show file tree
Hide file tree
Showing 12 changed files with 222 additions and 67 deletions.
52 changes: 5 additions & 47 deletions clang/include/clang/AST/DeclTemplate.h
Original file line number Diff line number Diff line change
Expand Up @@ -857,16 +857,6 @@ class RedeclarableTemplateDecl : public TemplateDecl,
/// \endcode
bool isMemberSpecialization() const { return Common.getInt(); }

/// Determines whether any redeclaration of this template was
/// a specialization of a member template.
bool hasMemberSpecialization() const {
for (const auto *D : redecls()) {
if (D->isMemberSpecialization())
return true;
}
return false;
}

/// Note that this member template is a specialization.
void setMemberSpecialization() {
assert(!isMemberSpecialization() && "already a member specialization");
Expand Down Expand Up @@ -1965,13 +1955,7 @@ class ClassTemplateSpecializationDecl : public CXXRecordDecl,
/// specialization which was specialized by this.
llvm::PointerUnion<ClassTemplateDecl *,
ClassTemplatePartialSpecializationDecl *>
getSpecializedTemplateOrPartial() const {
if (const auto *PartialSpec =
SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>())
return PartialSpec->PartialSpecialization;

return SpecializedTemplate.get<ClassTemplateDecl*>();
}
getSpecializedTemplateOrPartial() const;

/// Retrieve the set of template arguments that should be used
/// to instantiate members of the class template or class template partial
Expand Down Expand Up @@ -2208,17 +2192,6 @@ class ClassTemplatePartialSpecializationDecl
return InstantiatedFromMember.getInt();
}

/// Determines whether any redeclaration of this this class template partial
/// specialization was a specialization of a member partial specialization.
bool hasMemberSpecialization() const {
for (const auto *D : redecls()) {
if (cast<ClassTemplatePartialSpecializationDecl>(D)
->isMemberSpecialization())
return true;
}
return false;
}

/// Note that this member template is a specialization.
void setMemberSpecialization() { return InstantiatedFromMember.setInt(true); }

Expand Down Expand Up @@ -2740,13 +2713,7 @@ class VarTemplateSpecializationDecl : public VarDecl,
/// Retrieve the variable template or variable template partial
/// specialization which was specialized by this.
llvm::PointerUnion<VarTemplateDecl *, VarTemplatePartialSpecializationDecl *>
getSpecializedTemplateOrPartial() const {
if (const auto *PartialSpec =
SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>())
return PartialSpec->PartialSpecialization;

return SpecializedTemplate.get<VarTemplateDecl *>();
}
getSpecializedTemplateOrPartial() const;

/// Retrieve the set of template arguments that should be used
/// to instantiate the initializer of the variable template or variable
Expand Down Expand Up @@ -2980,18 +2947,6 @@ class VarTemplatePartialSpecializationDecl
return InstantiatedFromMember.getInt();
}

/// Determines whether any redeclaration of this this variable template
/// partial specialization was a specialization of a member partial
/// specialization.
bool hasMemberSpecialization() const {
for (const auto *D : redecls()) {
if (cast<VarTemplatePartialSpecializationDecl>(D)
->isMemberSpecialization())
return true;
}
return false;
}

/// Note that this member template is a specialization.
void setMemberSpecialization() { return InstantiatedFromMember.setInt(true); }

Expand Down Expand Up @@ -3164,6 +3119,9 @@ class VarTemplateDecl : public RedeclarableTemplateDecl {
return makeSpecIterator(getSpecializations(), true);
}

/// Merge \p Prev with our RedeclarableTemplateDecl::Common.
void mergePrevDecl(VarTemplateDecl *Prev);

// Implement isa/cast/dyncast support
static bool classof(const Decl *D) { return classofKind(D->getKind()); }
static bool classofKind(Kind K) { return K == VarTemplate; }
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/AST/ASTImporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6190,7 +6190,8 @@ ExpectedDecl ASTNodeImporter::VisitClassTemplateDecl(ClassTemplateDecl *D) {
ExpectedDecl ASTNodeImporter::VisitClassTemplateSpecializationDecl(
ClassTemplateSpecializationDecl *D) {
ClassTemplateDecl *ClassTemplate;
if (Error Err = importInto(ClassTemplate, D->getSpecializedTemplate()))
if (Error Err = importInto(ClassTemplate,
D->getSpecializedTemplate()->getCanonicalDecl()))
return std::move(Err);

// Import the context of this declaration.
Expand Down
10 changes: 5 additions & 5 deletions clang/lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2708,7 +2708,7 @@ VarDecl *VarDecl::getTemplateInstantiationPattern() const {
if (isTemplateInstantiation(VDTemplSpec->getTemplateSpecializationKind())) {
auto From = VDTemplSpec->getInstantiatedFrom();
if (auto *VTD = From.dyn_cast<VarTemplateDecl *>()) {
while (!VTD->hasMemberSpecialization()) {
while (!VTD->isMemberSpecialization()) {
if (auto *NewVTD = VTD->getInstantiatedFromMemberTemplate())
VTD = NewVTD;
else
Expand All @@ -2718,7 +2718,7 @@ VarDecl *VarDecl::getTemplateInstantiationPattern() const {
}
if (auto *VTPSD =
From.dyn_cast<VarTemplatePartialSpecializationDecl *>()) {
while (!VTPSD->hasMemberSpecialization()) {
while (!VTPSD->isMemberSpecialization()) {
if (auto *NewVTPSD = VTPSD->getInstantiatedFromMember())
VTPSD = NewVTPSD;
else
Expand All @@ -2732,7 +2732,7 @@ VarDecl *VarDecl::getTemplateInstantiationPattern() const {
// If this is the pattern of a variable template, find where it was
// instantiated from. FIXME: Is this necessary?
if (VarTemplateDecl *VTD = VD->getDescribedVarTemplate()) {
while (!VTD->hasMemberSpecialization()) {
while (!VTD->isMemberSpecialization()) {
if (auto *NewVTD = VTD->getInstantiatedFromMemberTemplate())
VTD = NewVTD;
else
Expand Down Expand Up @@ -4153,7 +4153,7 @@ FunctionDecl::getTemplateInstantiationPattern(bool ForDefinition) const {
if (FunctionTemplateDecl *Primary = getPrimaryTemplate()) {
// If we hit a point where the user provided a specialization of this
// template, we're done looking.
while (!ForDefinition || !Primary->hasMemberSpecialization()) {
while (!ForDefinition || !Primary->isMemberSpecialization()) {
if (auto *NewPrimary = Primary->getInstantiatedFromMemberTemplate())
Primary = NewPrimary;
else
Expand All @@ -4170,7 +4170,7 @@ FunctionTemplateDecl *FunctionDecl::getPrimaryTemplate() const {
if (FunctionTemplateSpecializationInfo *Info
= TemplateOrSpecialization
.dyn_cast<FunctionTemplateSpecializationInfo*>()) {
return Info->getTemplate();
return Info->getTemplate()->getMostRecentDecl();
}
return nullptr;
}
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/AST/DeclCXX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2030,7 +2030,7 @@ const CXXRecordDecl *CXXRecordDecl::getTemplateInstantiationPattern() const {
if (auto *TD = dyn_cast<ClassTemplateSpecializationDecl>(this)) {
auto From = TD->getInstantiatedFrom();
if (auto *CTD = From.dyn_cast<ClassTemplateDecl *>()) {
while (!CTD->hasMemberSpecialization()) {
while (!CTD->isMemberSpecialization()) {
if (auto *NewCTD = CTD->getInstantiatedFromMemberTemplate())
CTD = NewCTD;
else
Expand All @@ -2040,7 +2040,7 @@ const CXXRecordDecl *CXXRecordDecl::getTemplateInstantiationPattern() const {
}
if (auto *CTPSD =
From.dyn_cast<ClassTemplatePartialSpecializationDecl *>()) {
while (!CTPSD->hasMemberSpecialization()) {
while (!CTPSD->isMemberSpecialization()) {
if (auto *NewCTPSD = CTPSD->getInstantiatedFromMemberTemplate())
CTPSD = NewCTPSD;
else
Expand Down
56 changes: 54 additions & 2 deletions clang/lib/AST/DeclTemplate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,17 @@ ClassTemplateSpecializationDecl::getSpecializedTemplate() const {
if (const auto *PartialSpec =
SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization*>())
return PartialSpec->PartialSpecialization->getSpecializedTemplate();
return SpecializedTemplate.get<ClassTemplateDecl*>();
return SpecializedTemplate.get<ClassTemplateDecl *>()->getMostRecentDecl();
}

llvm::PointerUnion<ClassTemplateDecl *,
ClassTemplatePartialSpecializationDecl *>
ClassTemplateSpecializationDecl::getSpecializedTemplateOrPartial() const {
if (const auto *PartialSpec =
SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>())
return PartialSpec->PartialSpecialization->getMostRecentDecl();

return SpecializedTemplate.get<ClassTemplateDecl *>()->getMostRecentDecl();
}

SourceRange
Expand Down Expand Up @@ -1283,6 +1293,39 @@ VarTemplateDecl::newCommon(ASTContext &C) const {
return CommonPtr;
}

void VarTemplateDecl::mergePrevDecl(VarTemplateDecl *Prev) {
// If we haven't created a common pointer yet, then it can just be created
// with the usual method.
if (!getCommonPtrInternal())
return;

Common *ThisCommon = static_cast<Common *>(getCommonPtrInternal());
Common *PrevCommon = nullptr;
SmallVector<VarTemplateDecl *, 8> PreviousDecls;
for (; Prev; Prev = Prev->getPreviousDecl()) {
if (CommonBase *C = Prev->getCommonPtrInternal()) {
PrevCommon = static_cast<Common *>(C);
break;
}
PreviousDecls.push_back(Prev);
}

// If the previous redecl chain hasn't created a common pointer yet, then just
// use this common pointer.
if (!PrevCommon) {
for (auto *D : PreviousDecls)
D->setCommonPtr(ThisCommon);
return;
}

// Ensure we don't leak any important state.
assert(ThisCommon->Specializations.empty() &&
ThisCommon->PartialSpecializations.empty() &&
"Can't merge incompatible declarations!");

setCommonPtr(PrevCommon);
}

VarTemplateSpecializationDecl *
VarTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
void *&InsertPos) {
Expand Down Expand Up @@ -1405,7 +1448,16 @@ VarTemplateDecl *VarTemplateSpecializationDecl::getSpecializedTemplate() const {
if (const auto *PartialSpec =
SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>())
return PartialSpec->PartialSpecialization->getSpecializedTemplate();
return SpecializedTemplate.get<VarTemplateDecl *>();
return SpecializedTemplate.get<VarTemplateDecl *>()->getMostRecentDecl();
}

llvm::PointerUnion<VarTemplateDecl *, VarTemplatePartialSpecializationDecl *>
VarTemplateSpecializationDecl::getSpecializedTemplateOrPartial() const {
if (const auto *PartialSpec =
SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>())
return PartialSpec->PartialSpecialization->getMostRecentDecl();

return SpecializedTemplate.get<VarTemplateDecl *>()->getMostRecentDecl();
}

SourceRange VarTemplateSpecializationDecl::getSourceRange() const {
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/Sema/SemaDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4694,8 +4694,10 @@ void Sema::MergeVarDecl(VarDecl *New, LookupResult &Previous) {

// Keep a chain of previous declarations.
New->setPreviousDecl(Old);
if (NewTemplate)
if (NewTemplate) {
NewTemplate->mergePrevDecl(OldTemplate);
NewTemplate->setPreviousDecl(OldTemplate);
}

// Inherit access appropriately.
New->setAccess(Old->getAccess());
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/Sema/SemaInit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9954,7 +9954,7 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer(
auto SynthesizeAggrGuide = [&](InitListExpr *ListInit) {
auto *Pattern = Template;
while (Pattern->getInstantiatedFromMemberTemplate()) {
if (Pattern->hasMemberSpecialization())
if (Pattern->isMemberSpecialization())
break;
Pattern = Pattern->getInstantiatedFromMemberTemplate();
}
Expand Down
14 changes: 7 additions & 7 deletions clang/lib/Sema/SemaTemplateInstantiate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ struct TemplateInstantiationArgumentCollecter
// If this function was instantiated from a specialized member that is
// a function template, we're done.
assert(FD->getPrimaryTemplate() && "No function template?");
if (FD->getPrimaryTemplate()->hasMemberSpecialization())
if (FD->getPrimaryTemplate()->isMemberSpecialization())
return Done();

// If this function is a generic lambda specialization, we are done.
Expand Down Expand Up @@ -442,11 +442,11 @@ struct TemplateInstantiationArgumentCollecter
Specialized = CTSD->getSpecializedTemplateOrPartial();
if (auto *CTPSD =
Specialized.dyn_cast<ClassTemplatePartialSpecializationDecl *>()) {
if (CTPSD->hasMemberSpecialization())
if (CTPSD->isMemberSpecialization())
return Done();
} else {
auto *CTD = Specialized.get<ClassTemplateDecl *>();
if (CTD->hasMemberSpecialization())
if (CTD->isMemberSpecialization())
return Done();
}
return UseNextDecl(CTSD);
Expand Down Expand Up @@ -478,11 +478,11 @@ struct TemplateInstantiationArgumentCollecter
Specialized = VTSD->getSpecializedTemplateOrPartial();
if (auto *VTPSD =
Specialized.dyn_cast<VarTemplatePartialSpecializationDecl *>()) {
if (VTPSD->hasMemberSpecialization())
if (VTPSD->isMemberSpecialization())
return Done();
} else {
auto *VTD = Specialized.get<VarTemplateDecl *>();
if (VTD->hasMemberSpecialization())
if (VTD->isMemberSpecialization())
return Done();
}
return UseNextDecl(VTSD);
Expand Down Expand Up @@ -4141,7 +4141,7 @@ getPatternForClassTemplateSpecialization(
CXXRecordDecl *Pattern = nullptr;
Specialized = ClassTemplateSpec->getSpecializedTemplateOrPartial();
if (auto *CTD = Specialized.dyn_cast<ClassTemplateDecl *>()) {
while (!CTD->hasMemberSpecialization()) {
while (!CTD->isMemberSpecialization()) {
if (auto *NewCTD = CTD->getInstantiatedFromMemberTemplate())
CTD = NewCTD;
else
Expand All @@ -4151,7 +4151,7 @@ getPatternForClassTemplateSpecialization(
} else if (auto *CTPSD =
Specialized
.dyn_cast<ClassTemplatePartialSpecializationDecl *>()) {
while (!CTPSD->hasMemberSpecialization()) {
while (!CTPSD->isMemberSpecialization()) {
if (auto *NewCTPSD = CTPSD->getInstantiatedFromMemberTemplate())
CTPSD = NewCTPSD;
else
Expand Down
2 changes: 1 addition & 1 deletion clang/test/AST/ast-dump-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ namespace testCanonicalTemplate {
// CHECK-NEXT: | `-ClassTemplateDecl 0x{{.+}} parent 0x{{.+}} <col:5, col:40> col:40 friend_undeclared TestClassTemplate{{$}}
// CHECK-NEXT: | |-TemplateTypeParmDecl 0x{{.+}} <col:14, col:23> col:23 typename depth 1 index 0 T2{{$}}
// CHECK-NEXT: | `-CXXRecordDecl 0x{{.+}} parent 0x{{.+}} <col:34, col:40> col:40 class TestClassTemplate{{$}}
// CHECK-NEXT: `-ClassTemplateSpecializationDecl 0x{{.+}} <line:[[@LINE-19]]:3, line:[[@LINE-17]]:3> line:[[@LINE-19]]:31 class TestClassTemplate definition implicit_instantiation{{$}}
// CHECK-NEXT: `-ClassTemplateSpecializationDecl 0x{{.+}} <col:5, col:40> line:[[@LINE-19]]:31 class TestClassTemplate definition implicit_instantiation{{$}}
// CHECK-NEXT: |-DefinitionData pass_in_registers empty aggregate standard_layout trivially_copyable pod trivial literal has_constexpr_non_copy_move_ctor can_const_default_init{{$}}
// CHECK-NEXT: | |-DefaultConstructor exists trivial constexpr defaulted_is_constexpr{{$}}
// CHECK-NEXT: | |-CopyConstructor simple trivial has_const_param implicit_has_const_param{{$}}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
namespace N0 {
template<typename T>
struct A {
template<typename U>
friend struct A;
};

template struct A<long>;
} // namespace N0

namespace N1 {
template<typename T>
struct A;

template<typename T>
struct A {
template<typename U>
friend struct A;
};

template struct A<long>;
} // namespace N1

namespace N2 {
template<typename T>
struct A {
template<typename U>
friend struct A;
};

template<typename T>
struct A;

template struct A<long>;
} // namespace N2

namespace N3 {
struct A {
template<typename T>
friend struct B;
};

template<typename T>
struct B { };

template struct B<long>;
} // namespace N3
8 changes: 8 additions & 0 deletions clang/test/ASTMerge/class-template-spec/test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: %clang_cc1 -emit-pch -o %t.1.ast %S/Inputs/class-template-spec.cpp
// RUN: %clang_cc1 -ast-merge %t.1.ast -fsyntax-only -verify %s
// expected-no-diagnostics

template struct N0::A<short>;
template struct N1::A<short>;
template struct N2::A<short>;
template struct N3::B<short>;
Loading

0 comments on commit 2e1b618

Please sign in to comment.