From 90786adade22784a52856a0e8b545ec6710b47f6 Mon Sep 17 00:00:00 2001 From: Krystian Stasiowski Date: Wed, 30 Oct 2024 12:50:40 -0600 Subject: [PATCH] [Clang][Sema] Always use latest redeclaration of primary template (#114258) This patch fixes a couple of regressions introduced in #111852. Consider: ``` template struct A { template static constexpr bool f() requires U { return true; } }; template<> template constexpr bool A::f() requires U { return A::f(); } template<> template constexpr bool A::f() requires U { return true; } static_assert(A::f()); // crash here ``` This crashes because when collecting template arguments from the _first_ declaration of `A::f` for constraint checking, we don't add the template arguments from the enclosing class template specialization because there exists another redeclaration that is a member specialization. This also fixes the following example, which happens for a similar reason: ``` // input.cppm export module input; export template constexpr int f(); template struct A { template friend constexpr int f(); }; template struct A<0>; template constexpr int f() { return N; } ``` ``` // input.cpp import input; static_assert(f<1>() == 1); // error: static assertion failed ``` --- clang/include/clang/AST/DeclTemplate.h | 52 ++--------- clang/lib/AST/Decl.cpp | 10 +-- clang/lib/AST/DeclCXX.cpp | 4 +- clang/lib/AST/DeclTemplate.cpp | 56 +++++++++++- clang/lib/Sema/SemaDecl.cpp | 4 +- clang/lib/Sema/SemaInit.cpp | 2 +- clang/lib/Sema/SemaTemplateInstantiate.cpp | 14 +-- clang/test/AST/ast-dump-decl.cpp | 2 +- .../CXX/temp/temp.spec/temp.expl.spec/p7.cpp | 87 +++++++++++++++++++ 9 files changed, 165 insertions(+), 66 deletions(-) diff --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h index a572e3380f1655..0ca3fd48e81cf4 100644 --- a/clang/include/clang/AST/DeclTemplate.h +++ b/clang/include/clang/AST/DeclTemplate.h @@ -857,16 +857,6 @@ class RedeclarableTemplateDecl : public TemplateDecl, /// \endcode bool isMemberSpecialization() const { return Common.getInt(); } - /// Determines whether any redeclaration of this template was - /// a specialization of a member template. - bool hasMemberSpecialization() const { - for (const auto *D : redecls()) { - if (D->isMemberSpecialization()) - return true; - } - return false; - } - /// Note that this member template is a specialization. void setMemberSpecialization() { assert(!isMemberSpecialization() && "already a member specialization"); @@ -1965,13 +1955,7 @@ class ClassTemplateSpecializationDecl : public CXXRecordDecl, /// specialization which was specialized by this. llvm::PointerUnion - getSpecializedTemplateOrPartial() const { - if (const auto *PartialSpec = - SpecializedTemplate.dyn_cast()) - return PartialSpec->PartialSpecialization; - - return SpecializedTemplate.get(); - } + getSpecializedTemplateOrPartial() const; /// Retrieve the set of template arguments that should be used /// to instantiate members of the class template or class template partial @@ -2208,17 +2192,6 @@ class ClassTemplatePartialSpecializationDecl return InstantiatedFromMember.getInt(); } - /// Determines whether any redeclaration of this this class template partial - /// specialization was a specialization of a member partial specialization. - bool hasMemberSpecialization() const { - for (const auto *D : redecls()) { - if (cast(D) - ->isMemberSpecialization()) - return true; - } - return false; - } - /// Note that this member template is a specialization. void setMemberSpecialization() { return InstantiatedFromMember.setInt(true); } @@ -2740,13 +2713,7 @@ class VarTemplateSpecializationDecl : public VarDecl, /// Retrieve the variable template or variable template partial /// specialization which was specialized by this. llvm::PointerUnion - getSpecializedTemplateOrPartial() const { - if (const auto *PartialSpec = - SpecializedTemplate.dyn_cast()) - return PartialSpec->PartialSpecialization; - - return SpecializedTemplate.get(); - } + getSpecializedTemplateOrPartial() const; /// Retrieve the set of template arguments that should be used /// to instantiate the initializer of the variable template or variable @@ -2980,18 +2947,6 @@ class VarTemplatePartialSpecializationDecl return InstantiatedFromMember.getInt(); } - /// Determines whether any redeclaration of this this variable template - /// partial specialization was a specialization of a member partial - /// specialization. - bool hasMemberSpecialization() const { - for (const auto *D : redecls()) { - if (cast(D) - ->isMemberSpecialization()) - return true; - } - return false; - } - /// Note that this member template is a specialization. void setMemberSpecialization() { return InstantiatedFromMember.setInt(true); } @@ -3164,6 +3119,9 @@ class VarTemplateDecl : public RedeclarableTemplateDecl { return makeSpecIterator(getSpecializations(), true); } + /// Merge \p Prev with our RedeclarableTemplateDecl::Common. + void mergePrevDecl(VarTemplateDecl *Prev); + // Implement isa/cast/dyncast support static bool classof(const Decl *D) { return classofKind(D->getKind()); } static bool classofKind(Kind K) { return K == VarTemplate; } diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp index 86913763ef9ff5..cd173d17263792 100644 --- a/clang/lib/AST/Decl.cpp +++ b/clang/lib/AST/Decl.cpp @@ -2708,7 +2708,7 @@ VarDecl *VarDecl::getTemplateInstantiationPattern() const { if (isTemplateInstantiation(VDTemplSpec->getTemplateSpecializationKind())) { auto From = VDTemplSpec->getInstantiatedFrom(); if (auto *VTD = From.dyn_cast()) { - while (!VTD->hasMemberSpecialization()) { + while (!VTD->isMemberSpecialization()) { if (auto *NewVTD = VTD->getInstantiatedFromMemberTemplate()) VTD = NewVTD; else @@ -2718,7 +2718,7 @@ VarDecl *VarDecl::getTemplateInstantiationPattern() const { } if (auto *VTPSD = From.dyn_cast()) { - while (!VTPSD->hasMemberSpecialization()) { + while (!VTPSD->isMemberSpecialization()) { if (auto *NewVTPSD = VTPSD->getInstantiatedFromMember()) VTPSD = NewVTPSD; else @@ -2732,7 +2732,7 @@ VarDecl *VarDecl::getTemplateInstantiationPattern() const { // If this is the pattern of a variable template, find where it was // instantiated from. FIXME: Is this necessary? if (VarTemplateDecl *VTD = VD->getDescribedVarTemplate()) { - while (!VTD->hasMemberSpecialization()) { + while (!VTD->isMemberSpecialization()) { if (auto *NewVTD = VTD->getInstantiatedFromMemberTemplate()) VTD = NewVTD; else @@ -4153,7 +4153,7 @@ FunctionDecl::getTemplateInstantiationPattern(bool ForDefinition) const { if (FunctionTemplateDecl *Primary = getPrimaryTemplate()) { // If we hit a point where the user provided a specialization of this // template, we're done looking. - while (!ForDefinition || !Primary->hasMemberSpecialization()) { + while (!ForDefinition || !Primary->isMemberSpecialization()) { if (auto *NewPrimary = Primary->getInstantiatedFromMemberTemplate()) Primary = NewPrimary; else @@ -4170,7 +4170,7 @@ FunctionTemplateDecl *FunctionDecl::getPrimaryTemplate() const { if (FunctionTemplateSpecializationInfo *Info = TemplateOrSpecialization .dyn_cast()) { - return Info->getTemplate(); + return Info->getTemplate()->getMostRecentDecl(); } return nullptr; } diff --git a/clang/lib/AST/DeclCXX.cpp b/clang/lib/AST/DeclCXX.cpp index db0ea62a2323eb..1c92fd9e3ff067 100644 --- a/clang/lib/AST/DeclCXX.cpp +++ b/clang/lib/AST/DeclCXX.cpp @@ -2030,7 +2030,7 @@ const CXXRecordDecl *CXXRecordDecl::getTemplateInstantiationPattern() const { if (auto *TD = dyn_cast(this)) { auto From = TD->getInstantiatedFrom(); if (auto *CTD = From.dyn_cast()) { - while (!CTD->hasMemberSpecialization()) { + while (!CTD->isMemberSpecialization()) { if (auto *NewCTD = CTD->getInstantiatedFromMemberTemplate()) CTD = NewCTD; else @@ -2040,7 +2040,7 @@ const CXXRecordDecl *CXXRecordDecl::getTemplateInstantiationPattern() const { } if (auto *CTPSD = From.dyn_cast()) { - while (!CTPSD->hasMemberSpecialization()) { + while (!CTPSD->isMemberSpecialization()) { if (auto *NewCTPSD = CTPSD->getInstantiatedFromMemberTemplate()) CTPSD = NewCTPSD; else diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp index 755ec72f00bf77..1db02d0d04448c 100644 --- a/clang/lib/AST/DeclTemplate.cpp +++ b/clang/lib/AST/DeclTemplate.cpp @@ -993,7 +993,17 @@ ClassTemplateSpecializationDecl::getSpecializedTemplate() const { if (const auto *PartialSpec = SpecializedTemplate.dyn_cast()) return PartialSpec->PartialSpecialization->getSpecializedTemplate(); - return SpecializedTemplate.get(); + return SpecializedTemplate.get()->getMostRecentDecl(); +} + +llvm::PointerUnion +ClassTemplateSpecializationDecl::getSpecializedTemplateOrPartial() const { + if (const auto *PartialSpec = + SpecializedTemplate.dyn_cast()) + return PartialSpec->PartialSpecialization->getMostRecentDecl(); + + return SpecializedTemplate.get()->getMostRecentDecl(); } SourceRange @@ -1283,6 +1293,39 @@ VarTemplateDecl::newCommon(ASTContext &C) const { return CommonPtr; } +void VarTemplateDecl::mergePrevDecl(VarTemplateDecl *Prev) { + // If we haven't created a common pointer yet, then it can just be created + // with the usual method. + if (!getCommonPtrInternal()) + return; + + Common *ThisCommon = static_cast(getCommonPtrInternal()); + Common *PrevCommon = nullptr; + SmallVector PreviousDecls; + for (; Prev; Prev = Prev->getPreviousDecl()) { + if (CommonBase *C = Prev->getCommonPtrInternal()) { + PrevCommon = static_cast(C); + break; + } + PreviousDecls.push_back(Prev); + } + + // If the previous redecl chain hasn't created a common pointer yet, then just + // use this common pointer. + if (!PrevCommon) { + for (auto *D : PreviousDecls) + D->setCommonPtr(ThisCommon); + return; + } + + // Ensure we don't leak any important state. + assert(ThisCommon->Specializations.empty() && + ThisCommon->PartialSpecializations.empty() && + "Can't merge incompatible declarations!"); + + setCommonPtr(PrevCommon); +} + VarTemplateSpecializationDecl * VarTemplateDecl::findSpecialization(ArrayRef Args, void *&InsertPos) { @@ -1405,7 +1448,16 @@ VarTemplateDecl *VarTemplateSpecializationDecl::getSpecializedTemplate() const { if (const auto *PartialSpec = SpecializedTemplate.dyn_cast()) return PartialSpec->PartialSpecialization->getSpecializedTemplate(); - return SpecializedTemplate.get(); + return SpecializedTemplate.get()->getMostRecentDecl(); +} + +llvm::PointerUnion +VarTemplateSpecializationDecl::getSpecializedTemplateOrPartial() const { + if (const auto *PartialSpec = + SpecializedTemplate.dyn_cast()) + return PartialSpec->PartialSpecialization->getMostRecentDecl(); + + return SpecializedTemplate.get()->getMostRecentDecl(); } SourceRange VarTemplateSpecializationDecl::getSourceRange() const { diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index f8e5f3c6d309d6..3e8b76e8dfd625 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -4696,8 +4696,10 @@ void Sema::MergeVarDecl(VarDecl *New, LookupResult &Previous) { // Keep a chain of previous declarations. New->setPreviousDecl(Old); - if (NewTemplate) + if (NewTemplate) { + NewTemplate->mergePrevDecl(OldTemplate); NewTemplate->setPreviousDecl(OldTemplate); + } // Inherit access appropriately. New->setAccess(Old->getAccess()); diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp index 573e90aced3eea..e2a59f63ccf589 100644 --- a/clang/lib/Sema/SemaInit.cpp +++ b/clang/lib/Sema/SemaInit.cpp @@ -9954,7 +9954,7 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer( auto SynthesizeAggrGuide = [&](InitListExpr *ListInit) { auto *Pattern = Template; while (Pattern->getInstantiatedFromMemberTemplate()) { - if (Pattern->hasMemberSpecialization()) + if (Pattern->isMemberSpecialization()) break; Pattern = Pattern->getInstantiatedFromMemberTemplate(); } diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp index b63063813f1b56..de0ec0128905ff 100644 --- a/clang/lib/Sema/SemaTemplateInstantiate.cpp +++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp @@ -343,7 +343,7 @@ struct TemplateInstantiationArgumentCollecter // If this function was instantiated from a specialized member that is // a function template, we're done. assert(FD->getPrimaryTemplate() && "No function template?"); - if (FD->getPrimaryTemplate()->hasMemberSpecialization()) + if (FD->getPrimaryTemplate()->isMemberSpecialization()) return Done(); // If this function is a generic lambda specialization, we are done. @@ -442,11 +442,11 @@ struct TemplateInstantiationArgumentCollecter Specialized = CTSD->getSpecializedTemplateOrPartial(); if (auto *CTPSD = Specialized.dyn_cast()) { - if (CTPSD->hasMemberSpecialization()) + if (CTPSD->isMemberSpecialization()) return Done(); } else { auto *CTD = Specialized.get(); - if (CTD->hasMemberSpecialization()) + if (CTD->isMemberSpecialization()) return Done(); } return UseNextDecl(CTSD); @@ -478,11 +478,11 @@ struct TemplateInstantiationArgumentCollecter Specialized = VTSD->getSpecializedTemplateOrPartial(); if (auto *VTPSD = Specialized.dyn_cast()) { - if (VTPSD->hasMemberSpecialization()) + if (VTPSD->isMemberSpecialization()) return Done(); } else { auto *VTD = Specialized.get(); - if (VTD->hasMemberSpecialization()) + if (VTD->isMemberSpecialization()) return Done(); } return UseNextDecl(VTSD); @@ -4141,7 +4141,7 @@ getPatternForClassTemplateSpecialization( CXXRecordDecl *Pattern = nullptr; Specialized = ClassTemplateSpec->getSpecializedTemplateOrPartial(); if (auto *CTD = Specialized.dyn_cast()) { - while (!CTD->hasMemberSpecialization()) { + while (!CTD->isMemberSpecialization()) { if (auto *NewCTD = CTD->getInstantiatedFromMemberTemplate()) CTD = NewCTD; else @@ -4151,7 +4151,7 @@ getPatternForClassTemplateSpecialization( } else if (auto *CTPSD = Specialized .dyn_cast()) { - while (!CTPSD->hasMemberSpecialization()) { + while (!CTPSD->isMemberSpecialization()) { if (auto *NewCTPSD = CTPSD->getInstantiatedFromMemberTemplate()) CTPSD = NewCTPSD; else diff --git a/clang/test/AST/ast-dump-decl.cpp b/clang/test/AST/ast-dump-decl.cpp index e84241cee922f5..7b998f20944f49 100644 --- a/clang/test/AST/ast-dump-decl.cpp +++ b/clang/test/AST/ast-dump-decl.cpp @@ -530,7 +530,7 @@ namespace testCanonicalTemplate { // CHECK-NEXT: | `-ClassTemplateDecl 0x{{.+}} parent 0x{{.+}} col:40 friend_undeclared TestClassTemplate{{$}} // CHECK-NEXT: | |-TemplateTypeParmDecl 0x{{.+}} col:23 typename depth 1 index 0 T2{{$}} // CHECK-NEXT: | `-CXXRecordDecl 0x{{.+}} parent 0x{{.+}} col:40 class TestClassTemplate{{$}} - // CHECK-NEXT: `-ClassTemplateSpecializationDecl 0x{{.+}} line:[[@LINE-19]]:31 class TestClassTemplate definition implicit_instantiation{{$}} + // CHECK-NEXT: `-ClassTemplateSpecializationDecl 0x{{.+}} line:[[@LINE-19]]:31 class TestClassTemplate definition implicit_instantiation{{$}} // CHECK-NEXT: |-DefinitionData pass_in_registers empty aggregate standard_layout trivially_copyable pod trivial literal has_constexpr_non_copy_move_ctor can_const_default_init{{$}} // CHECK-NEXT: | |-DefaultConstructor exists trivial constexpr defaulted_is_constexpr{{$}} // CHECK-NEXT: | |-CopyConstructor simple trivial has_const_param implicit_has_const_param{{$}} diff --git a/clang/test/CXX/temp/temp.spec/temp.expl.spec/p7.cpp b/clang/test/CXX/temp/temp.spec/temp.expl.spec/p7.cpp index 87127366eb58a5..e7e4738032f647 100644 --- a/clang/test/CXX/temp/temp.spec/temp.expl.spec/p7.cpp +++ b/clang/test/CXX/temp/temp.spec/temp.expl.spec/p7.cpp @@ -177,6 +177,93 @@ namespace Defined { static_assert(A::B::y == 2); } // namespace Defined +namespace Constrained { + template + struct A { + template requires V + static constexpr int f(); // expected-note {{declared here}} + + template requires V + static const int x; // expected-note {{declared here}} + + template requires V + static const int x; // expected-note {{declared here}} + + template requires V + struct B; // expected-note {{template is declared here}} + + template requires V + struct B; // expected-note {{template is declared here}} + }; + + template<> + template requires V + constexpr int A::f() { + return A::f(); + } + + template<> + template requires V + constexpr int A::x = A::x; + + template<> + template requires V + constexpr int A::x = A::x; + + template<> + template requires V + struct A::B { + static constexpr int y = A::B::y; + }; + + template<> + template requires V + struct A::B { + static constexpr int y = A::B::y; + }; + + template<> + template requires V + constexpr int A::f() { + return 1; + } + + template<> + template requires V + constexpr int A::x = 1; + + template<> + template requires V + constexpr int A::x = 2; + + template<> + template requires V + struct A::B { + static constexpr int y = 1; + }; + + template<> + template requires V + struct A::B { + static constexpr int y = 2; + }; + + static_assert(A::f() == 0); // expected-error {{static assertion expression is not an integral constant expression}} + // expected-note@-1 {{undefined function 'f' cannot be used in a constant expression}} + static_assert(A::x == 0); // expected-error {{static assertion expression is not an integral constant expression}} + // expected-note@-1 {{initializer of 'x' is unknown}} + static_assert(A::x == 0); // expected-error {{static assertion expression is not an integral constant expression}} + // expected-note@-1 {{initializer of 'x' is unknown}} + static_assert(A::B::y == 0); // expected-error {{implicit instantiation of undefined template 'Constrained::A::B'}} + static_assert(A::B::y == 0); // expected-error {{implicit instantiation of undefined template 'Constrained::A::B'}} + + static_assert(A::f() == 1); + static_assert(A::x == 1); + static_assert(A::x == 2); + static_assert(A::B::y == 1); + static_assert(A::B::y == 2); +} // namespace Constrained + namespace Dependent { template struct A {