-
Notifications
You must be signed in to change notification settings - Fork 235
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
17 changed files
with
313 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 3 additions & 0 deletions
3
lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_H | ||
#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_H | ||
|
||
#include "pcg/layer_attrs.dtg.h" | ||
#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" | ||
|
||
namespace FlexFlow { | ||
|
||
OperatorType get_op_type(ParallelLayerAttrs const &); | ||
|
||
ParallelLayerAttrs parallel_layer_attrs_from_layer_attrs(LayerAttrs const &); | ||
|
||
} // namespace FlexFlow | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_CG_TO_PCG_H | ||
#define _FLEXFLOW_PCG_INCLUDE_PCG_CG_TO_PCG_H | ||
|
||
#include "pcg/computation_graph.dtg.h" | ||
#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" | ||
|
||
namespace FlexFlow { | ||
|
||
ParallelComputationGraph pcg_from_computation_graph(ComputationGraph const &cg); | ||
|
||
} // namespace FlexFlow | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 8 additions & 0 deletions
8
lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,18 @@ | ||
#include "pcg/parallel_computation_graph/parallel_layer_attrs.h" | ||
#include "op-attrs/pcg_operator_attrs.h" | ||
#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" | ||
|
||
namespace FlexFlow { | ||
|
||
OperatorType get_op_type(ParallelLayerAttrs const &a) { | ||
return get_op_type(a.op_attrs); | ||
} | ||
|
||
ParallelLayerAttrs | ||
parallel_layer_attrs_from_layer_attrs(LayerAttrs const &layer_attrs) { | ||
return ParallelLayerAttrs{ | ||
pcg_op_attrs_from_compgraph_op_attrs(layer_attrs.attrs), | ||
layer_attrs.name}; | ||
} | ||
|
||
} // namespace FlexFlow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,22 @@ | ||
#include "pcg/parallel_tensor_attrs.h" | ||
#include "op-attrs/parallel_tensor_shape.h" | ||
#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" | ||
|
||
namespace FlexFlow { | ||
|
||
TensorAttrs get_piece_attrs(ParallelTensorAttrs const ¶llel_attrs) { | ||
return TensorAttrs{get_piece_shape(parallel_attrs.shape), | ||
parallel_attrs.initializer, | ||
parallel_attrs.sync_type, | ||
parallel_attrs.initializer, | ||
parallel_attrs.create_gradients}; | ||
} | ||
|
||
ParallelTensorAttrs | ||
parallel_tensor_attrs_from_tensor_attrs(TensorAttrs const &tensor_attrs) { | ||
return ParallelTensorAttrs{lift_to_parallel(tensor_attrs.shape), | ||
tensor_attrs.sync_type, | ||
tensor_attrs.initializer, | ||
tensor_attrs.create_gradients}; | ||
} | ||
|
||
} // namespace FlexFlow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#include "pcg/pcg_from_computation_graph.h" | ||
#include "op-attrs/pcg_operator_attrs.h" | ||
#include "pcg/computation_graph.dtg.h" | ||
#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" | ||
#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" | ||
#include "pcg/parallel_computation_graph/parallel_layer_attrs.h" | ||
#include "pcg/parallel_tensor_attrs.h" | ||
#include "pcg/tensor_attrs.dtg.h" | ||
#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" | ||
#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" | ||
#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_value_labels.h" | ||
#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" | ||
#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" | ||
#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" | ||
|
||
namespace FlexFlow { | ||
|
||
ParallelComputationGraph | ||
pcg_from_computation_graph(ComputationGraph const &cg) { | ||
auto layer_map = [&](Node const &_, LayerAttrs const &layer) { | ||
return parallel_layer_attrs_from_layer_attrs(layer); | ||
}; | ||
auto tensor_map = [&](OpenDataflowValue const &_, TensorAttrs const &tensor) { | ||
return parallel_tensor_attrs_from_tensor_attrs(tensor); | ||
}; | ||
auto graph_view = rewrite_value_labels( | ||
rewrite_node_labels(cg.raw_graph, layer_map), tensor_map); | ||
return ParallelComputationGraph{ | ||
LabelledDataflowGraph<ParallelLayerAttrs, ParallelTensorAttrs>:: | ||
create_copy_of< | ||
UnorderedSetLabelledOpenDataflowGraph<ParallelLayerAttrs, | ||
ParallelTensorAttrs>>( | ||
graph_view)}; | ||
} | ||
|
||
} // namespace FlexFlow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
#include "pcg/pcg_from_computation_graph.h" | ||
#include "pcg/computation_graph.h" | ||
#include "pcg/parallel_computation_graph/parallel_computation_graph.h" | ||
#include "utils/containers/get_only.h" | ||
#include <doctest/doctest.h> | ||
|
||
using namespace ::FlexFlow; | ||
|
||
TEST_SUITE(FF_TEST_SUITE) { | ||
TEST_CASE("pcg_from_computation_graph") { | ||
std::string input_name = "input"; | ||
std::string dense_name = "dense"; | ||
std::string relu_name = "relu"; | ||
|
||
ComputationGraph cg = [&] { | ||
ComputationGraph cg = make_empty_computation_graph(); | ||
|
||
TensorShape input_shape = TensorShape{ | ||
TensorDims{FFOrdered<nonnegative_int>{10_n, 12_n}}, DataType::FLOAT}; | ||
TensorAttrs input_attrs = TensorAttrs{input_shape, | ||
/*sync_type=*/std::nullopt, | ||
/*initializer=*/std::nullopt, | ||
CreateGrad::YES}; | ||
LayerAttrs input_layer_attrs = | ||
LayerAttrs{ComputationGraphOpAttrs{InputAttrs{}}, input_name}; | ||
LayerAddedResult input_added = add_layer( | ||
cg, input_layer_attrs, /*inputs=*/{}, /*outputs=*/{input_attrs}); | ||
tensor_guid_t input_tensor = get_only(input_added.outputs); | ||
|
||
LinearAttrs linear_attrs = LinearAttrs{/*out_channels=*/8_n, | ||
/*use_bias=*/true, | ||
/*data_type=*/DataType::FLOAT, | ||
/*activation=*/Activation::RELU, | ||
/*regularizer=*/std::nullopt}; | ||
TensorShape dense_output_shape = TensorShape{ | ||
TensorDims{FFOrdered<nonnegative_int>{10_n, 8_n}}, DataType::FLOAT}; | ||
LayerAttrs dense_layer_attrs = | ||
LayerAttrs{ComputationGraphOpAttrs{linear_attrs}, dense_name}; | ||
LayerAddedResult dense_added = | ||
add_layer(cg, | ||
/*attrs=*/dense_layer_attrs, | ||
/*inputs=*/{input_tensor}, | ||
/*outputs=*/ | ||
{TensorAttrs{dense_output_shape, | ||
/*sync_type=*/std::nullopt, | ||
/*initializer=*/std::nullopt, | ||
CreateGrad::YES}}); | ||
tensor_guid_t dense_output = get_only(dense_added.outputs); | ||
|
||
ElementUnaryAttrs relu_attrs = | ||
ElementUnaryAttrs{OperatorType::RELU, /*scalar=*/std::nullopt}; | ||
LayerAttrs relu_layer_attrs = | ||
LayerAttrs{ComputationGraphOpAttrs{relu_attrs}, relu_name}; | ||
add_layer(cg, | ||
/*attrs=*/relu_layer_attrs, | ||
/*inputs=*/{dense_output}, | ||
/*outputs=*/ | ||
{TensorAttrs{dense_output_shape, | ||
/*sync_type=*/std::nullopt, | ||
/*initializer=*/std::nullopt, | ||
CreateGrad::YES}}); | ||
|
||
return cg; | ||
}(); | ||
|
||
ParallelComputationGraph correct = [&] { | ||
ParallelComputationGraph pcg = empty_parallel_computation_graph(); | ||
|
||
ParallelTensorShape input_shape = ParallelTensorShape{ | ||
ParallelTensorDims{ | ||
FFOrdered<ShardParallelDim>{ShardParallelDim{10_n, 1_n}, | ||
ShardParallelDim{12_n, 1_n}}, | ||
ReplicaParallelDimSet{SumDegree{1_n}, DiscardCopyDegree{1_n}}}, | ||
DataType::FLOAT}; | ||
|
||
ParallelLayerAttrs input_layer_attrs = | ||
ParallelLayerAttrs{PCGOperatorAttrs{InputAttrs{}}, input_name}; | ||
|
||
ParallelLayerAddedResult input_added = add_parallel_layer( | ||
pcg, | ||
/*attrs=*/input_layer_attrs, | ||
/*inputs=*/{}, | ||
/*outputs=*/ | ||
{ParallelTensorAttrs{ | ||
input_shape, std::nullopt, std::nullopt, CreateGrad::YES}}); | ||
|
||
parallel_tensor_guid_t input_tensor = get_only(input_added.outputs); | ||
|
||
LinearAttrs linear_attrs = LinearAttrs{/*out_channels=*/8_n, | ||
/*use_bias=*/true, | ||
/*data_type=*/DataType::FLOAT, | ||
/*activation=*/Activation::RELU, | ||
/*regularizer=*/std::nullopt}; | ||
|
||
ParallelLayerAttrs dense_layer_attrs = | ||
ParallelLayerAttrs{PCGOperatorAttrs{linear_attrs}, dense_name}; | ||
|
||
ParallelTensorShape dense_output_shape = ParallelTensorShape{ | ||
ParallelTensorDims{ | ||
FFOrdered<ShardParallelDim>{ShardParallelDim{10_n, 1_n}, | ||
ShardParallelDim{8_n, 1_n}}, | ||
ReplicaParallelDimSet{SumDegree{1_n}, DiscardCopyDegree{1_n}}}, | ||
DataType::FLOAT}; | ||
|
||
ParallelLayerAddedResult dense_added = | ||
add_parallel_layer(pcg, | ||
/*attrs=*/dense_layer_attrs, | ||
/*inputs=*/{input_tensor}, | ||
/*outputs=*/ | ||
{ParallelTensorAttrs{dense_output_shape, | ||
/*sync_type=*/std::nullopt, | ||
/*initializer=*/std::nullopt, | ||
CreateGrad::YES}}); | ||
|
||
parallel_tensor_guid_t dense_output = get_only(dense_added.outputs); | ||
|
||
ElementUnaryAttrs relu_attrs = | ||
ElementUnaryAttrs{OperatorType::RELU, std::nullopt}; | ||
ParallelLayerAttrs relu_layer_attrs = | ||
ParallelLayerAttrs{PCGOperatorAttrs{relu_attrs}, relu_name}; | ||
|
||
add_parallel_layer(pcg, | ||
/*attrs=*/relu_layer_attrs, | ||
/*inputs=*/{dense_output}, | ||
/*outputs=*/ | ||
{ParallelTensorAttrs{dense_output_shape, | ||
/*sync_type=*/std::nullopt, | ||
/*initializer=*/std::nullopt, | ||
CreateGrad::YES}}); | ||
return pcg; | ||
}(); | ||
|
||
ParallelComputationGraph result = pcg_from_computation_graph(cg); | ||
|
||
CHECK(pcgs_are_isomorphic(result, correct)); | ||
} | ||
} |
1 change: 1 addition & 0 deletions
1
lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
23 changes: 23 additions & 0 deletions
23
lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_value_labels.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_VALUE_LABELS_H | ||
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_VALUE_LABELS_H | ||
|
||
#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" | ||
#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h" | ||
|
||
namespace FlexFlow { | ||
|
||
template < | ||
typename NodeLabel, | ||
typename ValueLabel, | ||
typename F, | ||
typename NewValueLabel = | ||
std::invoke_result_t<F, OpenDataflowValue const &, ValueLabel const &>> | ||
LabelledDataflowGraphView<NodeLabel, NewValueLabel> rewrite_value_labels( | ||
LabelledDataflowGraphView<NodeLabel, ValueLabel> const &g, F f) { | ||
return rewrite_value_labels<NodeLabel, ValueLabel, F, NewValueLabel>( | ||
view_as_labelled_open_dataflow_graph(g), f); | ||
} | ||
|
||
} // namespace FlexFlow | ||
|
||
#endif |
14 changes: 14 additions & 0 deletions
14
lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,15 @@ | ||
#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" | ||
#include "utils/archetypes/value_type.h" | ||
|
||
namespace FlexFlow { | ||
|
||
using NodeLabel = value_type<0>; | ||
using ValueLabel = value_type<1>; | ||
using NewNodeLabel = value_type<2>; | ||
using F = std::function<NewNodeLabel(Node const &, NodeLabel const &)>; | ||
|
||
template LabelledDataflowGraphView<NewNodeLabel, ValueLabel> | ||
rewrite_node_labels( | ||
LabelledDataflowGraphView<NodeLabel, ValueLabel> const &, F); | ||
|
||
} // namespace FlexFlow |
16 changes: 16 additions & 0 deletions
16
lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_value_labels.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_value_labels.h" | ||
#include "utils/archetypes/value_type.h" | ||
|
||
namespace FlexFlow { | ||
|
||
using NodeLabel = value_type<0>; | ||
using ValueLabel = value_type<1>; | ||
using NewValueLabel = value_type<2>; | ||
using F = | ||
std::function<NewValueLabel(OpenDataflowValue const &, ValueLabel const &)>; | ||
|
||
template LabelledDataflowGraphView<NodeLabel, NewValueLabel> | ||
rewrite_value_labels( | ||
LabelledDataflowGraphView<NodeLabel, ValueLabel> const &, F); | ||
|
||
} // namespace FlexFlow |
Oops, something went wrong.