• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/graph.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/op_def.pb.h"
27 #include "tensorflow/tools/graph_transforms/transform_utils.h"
28 
29 namespace tensorflow {
30 namespace graph_transforms {
31 
// Op type of the call nodes that InlinePartitionedCall replaces with the nodes
// of the function they invoke.
32 constexpr char kPartitionedCallOpName[] = "PartitionedCall";
// Name of the node attribute from which the called function's name is read.
33 constexpr char kFunctionAttrName[] = "f";
34 
35 namespace {
GetFunctionByNameFromLibrary(const GraphDef & graph,absl::string_view function_name)36 absl::optional<FunctionDef> GetFunctionByNameFromLibrary(
37     const GraphDef& graph, absl::string_view function_name) {
38   for (const auto& fct : graph.library().function()) {
39     if (fct.signature().name() == function_name) {
40       return fct;
41     }
42   }
43   return {};
44 }
45 
// Collapses an input name that contains more than one ':' down to
// "<text before the first ':'>:<text after the last ':'>", e.g.
// "node:output:0" becomes "node:0". Names with zero or one ':' are returned
// unchanged.
//
// Implemented with std::string searches only, so no intermediate vector of
// pieces is allocated and no string-splitting header is required.
std::string NormalizeNodeDefInput(const std::string& input_name) {
  const std::string::size_type first_colon = input_name.find(':');
  if (first_colon == std::string::npos) {
    return input_name;
  }
  const std::string::size_type last_colon = input_name.rfind(':');
  if (first_colon == last_colon) {
    // Exactly one ':' -- already in "<name>:<index>" form.
    return input_name;
  }

  std::string normalized;
  normalized.reserve(first_colon + (input_name.size() - last_colon));
  normalized.append(input_name, 0, first_colon);
  // Appending from last_colon keeps the ':' separator and the trailing piece.
  normalized.append(input_name, last_colon, std::string::npos);
  return normalized;
}
54 
55 }  // namespace
56 
InlinePartitionedCall(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)57 Status InlinePartitionedCall(const GraphDef& input_graph_def,
58                              const TransformFuncContext& context,
59                              GraphDef* output_graph_def) {
60   output_graph_def->Clear();
61   absl::flat_hash_map<std::string, std::string> remap_input;
62 
63   for (const NodeDef& node : input_graph_def.node()) {
64     if (node.op() == kPartitionedCallOpName) {
65       if (node.attr().count(kFunctionAttrName) == 0) {
66         return Status(
67             error::Code::NOT_FOUND,
68             "Node " + node.name() + " has no attribute: " + kFunctionAttrName);
69       }
70 
71       if (!node.attr().at(kFunctionAttrName).has_func()) {
72         return Status(error::Code::NOT_FOUND,
73                       "Cannot figure out function name");
74       }
75       const std::string function_name =
76           node.attr().at(kFunctionAttrName).func().name();
77       absl::optional<FunctionDef> function =
78           GetFunctionByNameFromLibrary(input_graph_def, function_name);
79       if (!function.has_value()) {
80         return Status(error::Code::NOT_FOUND,
81                       "function " + function_name + " Not found");
82       }
83 
84       const std::string prefix = node.name();
85 
86       const int kOutputArgumentCount =
87           function->signature().output_arg().size();
88       for (int k = 0; k < kOutputArgumentCount; ++k) {
89         const std::string function_arg_output_name =
90             function->ret().at(function->signature().output_arg()[k].name());
91         remap_input.insert_or_assign(
92             CanonicalInputName(absl::StrCat(node.name(), ":", k)),
93             absl::StrCat(prefix, "/",
94                          NormalizeNodeDefInput(function_arg_output_name)));
95       }
96 
97       const int kInputArgumentCount = function->signature().input_arg().size();
98       if (node.input().size() != kInputArgumentCount) {
99         return Status(error::Code::INVALID_ARGUMENT,
100                       "Called function  " + function_name +
101                           " has invalid input signature.");
102       }
103       absl::flat_hash_map<std::string, std::string> input_argument_map;
104       for (int k = 0; k < kInputArgumentCount; ++k) {
105         const std::string canonical_name =
106             CanonicalInputName(function->signature().input_arg()[k].name());
107         input_argument_map.insert_or_assign(canonical_name, node.input()[k]);
108       }
109 
110       for (const NodeDef& function_node : function->node_def()) {
111         NodeDef* new_node = output_graph_def->mutable_node()->Add();
112         *new_node = function_node;
113         new_node->set_name(absl::StrCat(prefix, "/", function_node.name()));
114         absl::c_transform(
115             *new_node->mutable_input(), new_node->mutable_input()->begin(),
116             [prefix, input_argument_map](const std::string& input_name) {
117               const std::string canonical_input_name =
118                   CanonicalInputName(input_name);
119               if (input_argument_map.find(canonical_input_name) !=
120                   input_argument_map.end()) {
121                 return input_argument_map.at(canonical_input_name);
122               }
123               return absl::StrCat(prefix, "/",
124                                   NormalizeNodeDefInput(input_name));
125             });
126       }
127     } else {
128       NodeDef* new_node = output_graph_def->mutable_node()->Add();
129       *new_node = node;
130     }
131   }
132 
133   // Remap PartitionCall outputs to correct nodes.
134   for (NodeDef& node : *output_graph_def->mutable_node()) {
135     absl::c_transform(
136         *node.mutable_input(), node.mutable_input()->begin(),
137         [remap_input](const std::string& input_name) {
138           const std::string canonical_input_name =
139               CanonicalInputName(input_name);
140           if (remap_input.find(canonical_input_name) != remap_input.end()) {
141             return remap_input.at(canonical_input_name);
142           }
143           return input_name;
144         });
145   }
146   return OkStatus();
147 }
148 
// Associates InlinePartitionedCall with the transform name
// "inline_partitionedcall" through the REGISTER_GRAPH_TRANSFORM macro
// (presumably so the transform can be selected by that name -- confirm against
// the macro's definition).
149 REGISTER_GRAPH_TRANSFORM("inline_partitionedcall", InlinePartitionedCall);
150 }  // namespace graph_transforms
151 }  // namespace tensorflow
152