Skip to content

Commit

Permalink
Import node debugName (FX graph node name)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Aug 9, 2023
1 parent 82687dc commit 983ce1a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 2 deletions.
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ static LogicalResult rewriteMonomorphizedFuncClone(
auto newOp = OpBuilder(op).create<GlobalSlotGetOp>(
op.getLoc(), op.getType(),
objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName());
newOp->setAttr("FXOutputName", op->getAttr("FXOutputName"));
op.replaceAllUsesWith(&*newOp);
}
toErase.push_back(op);
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,10 @@ class InlineGlobalSlotsPass
getBackwardSliceIncludingRoot(initialValue);
IRMapping mapping;
OpBuilder builder(op);
for (Operation *opInSlice : slice)
builder.clone(*opInSlice, mapping);
for (Operation *opInSlice : slice) {
auto clonedOp = builder.clone(*opInSlice, mapping);
clonedOp->setAttr("FXOutputName", op->getAttr("FXOutputName"));
}
auto inlinedInitialValue = mapping.lookup(initialValue);
inlinedInitialValue = Torch::adjustStaticInformation(
builder, op.getLoc(), inlinedInitialValue, op.getType(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,29 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
MlirLocation loc = getMlirLocationFromNode(context, node);
auto kind = node->kind();


auto outs = node->outputs();
auto output = outs.size() == 1 ? outs[0] : nullptr;

auto addFxOutputNameAttr = [&](MlirOperation& operation) {
if (output && output->hasDebugName()) {
std::string name = output->debugName();
size_t len = name.size();
if (len > 2 && name[len-2] == '.' && name[len-1] == '1')
name = name.substr(0, len-2);
auto strAttr = mlirStringAttrGet(context, toMlirStringRef(name));
mlirOperationSetAttributeByName(operation, toMlirStringRef("FXOutputName"), strAttr);
}
};

auto createAndMapTrivialNode = [&](Node *node, const std::string &opName,
InputsTransformFn t) {
std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, opName, loc,
getMlirTypesFromValues(loc, node->outputs(), importOptions),
t ? t(mappedInputs) : mappedInputs);
addFxOutputNameAttr(operation);
mapResults(node, operation);
};

Expand All @@ -102,6 +118,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
getMlirTypesFromValues(loc, node->outputs(), importOptions),
lookupMappedValues(node->inputs()),
toMlirNamedAttribute(attrName.c_str(), attr));
addFxOutputNameAttr(operation);
mapResults(node, operation);
};

Expand All @@ -112,6 +129,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
appendToBlock, loc, node->schema(),
getMlirTypesFromValues(loc, node->outputs(), importOptions),
lookupMappedValues(node->inputs()));
addFxOutputNameAttr(operation);
mapResults(node, operation);
return;
}
Expand Down

0 comments on commit 983ce1a

Please sign in to comment.