Skip to content

Commit

Permalink
frontend: Fix and simplify specializeGenericTypes
Browse files Browse the repository at this point in the history
Always place the specialized structs *right before* the original.

Signed-off-by: Vladimír Štill <[email protected]>
  • Loading branch information
vlstill committed Feb 19, 2025
1 parent f44f445 commit d1eb186
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 39 deletions.
34 changes: 12 additions & 22 deletions frontends/p4/specializeGenericTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ bool TypeSpecializationMap::same(const TypeSpecialization *spec,
}

void TypeSpecializationMap::add(const IR::Type_Specialized *t, const IR::Type_StructLike *decl,
const IR::Node *insertion, NameGenerator *nameGen) {
NameGenerator *nameGen) {
auto it = map.find(t);
if (it != map.end()) return;

Expand All @@ -51,10 +51,10 @@ void TypeSpecializationMap::add(const IR::Type_Specialized *t, const IR::Type_St

cstring name = nameGen->newName(decl->getName().string_view());
LOG3("Found to specialize: " << dbp(t) << "(" << t << ") with name " << name
<< " insert before " << dbp(insertion));
<< " insert before " << dbp(decl));
auto argTypes = new IR::Vector<IR::Type>();
for (auto a : *t->arguments) argTypes->push_back(typeMap->getType(a, true));
TypeSpecialization *s = new TypeSpecialization(name, t, decl, insertion, argTypes);
TypeSpecialization *s = new TypeSpecialization(name, t, decl, argTypes);
map.emplace(t, s);
}

Expand Down Expand Up @@ -97,8 +97,10 @@ Visitor::profile_t FindTypeSpecializations::init_apply(const IR::Node *node) {
}

void FindTypeSpecializations::postorder(const IR::Type_Specialized *type) {
auto baseType = specMap->typeMap->getTypeType(type->baseType, true);
auto st = baseType->to<IR::Type_StructLike>();
// Look for the declaration using the resolution context (not type map) to find it always
// at the program's top level. This way we can also use the declaration as insertion point.
const auto *baseType = getDeclaration(type->baseType->path);
const auto *st = baseType->to<IR::Type_StructLike>();
if (st == nullptr || st->typeParameters->size() == 0)
// nothing to specialize
return;
Expand All @@ -113,37 +115,25 @@ void FindTypeSpecializations::postorder(const IR::Type_Specialized *type) {
// specialized instances of G, e.g., G<bit<32>>.
return;
}
// Find location where the specialization is to be inserted.
// This can be before a Parser, Control, or a toplevel instance declaration
const IR::Node *insert = findContext<IR::P4Parser>();
if (!insert) insert = findContext<IR::Function>();
if (!insert) insert = findContext<IR::P4Control>();
if (!insert) insert = findContext<IR::Type_Declaration>();
if (!insert) insert = findContext<IR::Declaration_Constant>();
if (!insert) insert = findContext<IR::Declaration_Variable>();
if (!insert) insert = findContext<IR::Declaration_Instance>();
if (!insert) insert = findContext<IR::P4Action>();
CHECK_NULL(insert);
specMap->add(type, st, insert, &nameGen);
specMap->add(type, st, &nameGen);
}

///////////////////////////////////////////////////////////////////////////////////////

const IR::Node *CreateSpecializedTypes::postorder(IR::Type_Declaration *type) {
for (auto it : specMap->map) {
if (it.second->declaration->name == type->name) {
auto specialized = it.first;
for (const auto &[specialized, specialization] : specMap->map) {
if (specialization->declaration->name == type->name) {
auto genDecl = type->to<IR::IMayBeGenericType>();
TypeVariableSubstitution ts;
ts.setBindings(type, genDecl->getTypeParameters(), specialized->arguments);
TypeSubstitutionVisitor tsv(specMap->typeMap, &ts);
tsv.setCalledBy(this);
auto renamed = type->apply(tsv)->to<IR::Type_StructLike>()->clone();
cstring name = it.second->name;
cstring name = specialization->name;
auto empty = new IR::TypeParameters();
renamed->name = name;
renamed->typeParameters = empty;
it.second->replacement = postorder(renamed)->to<IR::Type_StructLike>();
specialization->replacement = renamed;
LOG3("CST Specializing " << dbp(type) << " with " << ts << " as " << dbp(renamed));
}
}
Expand Down
29 changes: 12 additions & 17 deletions frontends/p4/specializeGenericTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,29 @@ struct TypeSpecialization : public IHasDbPrint {
cstring name;
/// Type that is being specialized
const IR::Type_Specialized *specialized;
/// Declaration of specialized type, which will be replaced
/// Declaration of specialized type (in the program top-level), specialized type will be
/// inserted before it.
const IR::Type_Declaration *declaration;
/// New synthesized type (created later)
const IR::Type_StructLike *replacement;
/// Insertion point
const IR::Node *insertion;
/// Save here the canonical types of the type arguments of 'specialized'.
/// The typeMap will be cleared, so we cannot look them up later.
const IR::Vector<IR::Type> *argumentTypes;

TypeSpecialization(cstring name, const IR::Type_Specialized *specialized,
const IR::Type_Declaration *decl, const IR::Node *insertion,
const IR::Vector<IR::Type> *argTypes)
const IR::Type_Declaration *decl, const IR::Vector<IR::Type> *argTypes)
: name(name),
specialized(specialized),
declaration(decl),
replacement(nullptr),
insertion(insertion),
argumentTypes(argTypes) {
CHECK_NULL(specialized);
CHECK_NULL(decl);
CHECK_NULL(insertion);
CHECK_NULL(argTypes);
}
void dbprint(std::ostream &out) const override {
out << "Specializing:" << dbp(specialized) << " from " << dbp(declaration) << " as "
<< dbp(replacement) << " inserted at " << dbp(insertion);
<< dbp(replacement) << " inserted at " << dbp(declaration);
}
};

Expand All @@ -66,7 +62,7 @@ struct TypeSpecializationMap : public IHasDbPrint {
std::set<TypeSpecialization *> inserted;

void add(const IR::Type_Specialized *t, const IR::Type_StructLike *decl,
const IR::Node *insertion, NameGenerator *nameGen);
NameGenerator *nameGen);
TypeSpecialization *get(const IR::Type_Specialized *t) const;
bool same(const TypeSpecialization *left, const IR::Type_Specialized *right) const;
void dbprint(std::ostream &out) const override {
Expand All @@ -76,14 +72,14 @@ struct TypeSpecializationMap : public IHasDbPrint {
}
IR::Vector<IR::Node> *getSpecializations(const IR::Node *insertionPoint) {
IR::Vector<IR::Node> *result = nullptr;
for (auto s : map) {
if (inserted.find(s.second) != inserted.end()) continue;
if (s.second->insertion == insertionPoint) {
for (const auto &[_, specialization] : map) {
if (inserted.find(specialization) != inserted.end()) continue;
if (specialization->declaration == insertionPoint) {
if (result == nullptr) result = new IR::Vector<IR::Node>();
LOG2("Will insert " << dbp(s.second->replacement) << " before "
LOG2("Will insert " << dbp(specialization->replacement) << " before "
<< dbp(insertionPoint));
result->push_back(s.second->replacement);
inserted.emplace(s.second);
result->push_back(specialization->replacement);
inserted.emplace(specialization);
}
}
return result;
Expand All @@ -93,7 +89,7 @@ struct TypeSpecializationMap : public IHasDbPrint {
/**
* Find all generic type instantiations and their type arguments.
*/
class FindTypeSpecializations : public Inspector {
class FindTypeSpecializations : public Inspector, ResolutionContext {
TypeSpecializationMap *specMap;
MinimalNameGenerator nameGen;

Expand Down Expand Up @@ -122,7 +118,6 @@ class CreateSpecializedTypes : public Transform {

const IR::Node *insert(const IR::Node *before);
const IR::Node *postorder(IR::Type_Declaration *type) override;
const IR::Node *postorder(IR::Declaration *decl) override { return insert(decl); }
};

/**
Expand Down

0 comments on commit d1eb186

Please sign in to comment.