• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <algorithm>
#include <limits>
#include <unordered_map>
#include <unordered_set>

#include "absl/strings/match.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/base64.h"
#include "tensorflow/core/platform/strcat.h"
32 
33 using tensorflow::errors::InvalidArgument;
34 
35 namespace tensorflow {
36 namespace {
37 
ValidateNonRefOutput(const Node * node,int idx)38 Status ValidateNonRefOutput(const Node* node, int idx) {
39   const DataType& dt = node->output_type(idx);
40   return IsRefType(dt)
41              ? InvalidArgument("Output ", idx, " of node '", node->name(),
42                                "' has a reference type ", DataTypeString(dt))
43              : Status::OK();
44 }
45 
46 // Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
47 // does various checks while doing so. `input_nodes` will contain the same
48 // information as input_tensors just in a different structure to make
49 // following processing easier. TODO(iga): Simplify this nested structure.
ProcessInputs(const TF_Graph * fn_body,const char * fn_name,int ninputs,const TF_Output * inputs,std::vector<OutputTensor> * input_tensors,std::unordered_map<const Node *,std::vector<int>> * input_nodes)50 Status ProcessInputs(
51     const TF_Graph* fn_body, const char* fn_name, int ninputs,
52     const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
53     std::unordered_map<const Node*, std::vector<int>>* input_nodes)
54     TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
55   input_tensors->reserve(ninputs);
56   for (int i = 0; i < ninputs; ++i) {
57     Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
58     int idx = inputs[i].index;
59 
60     TF_RETURN_WITH_CONTEXT_IF_ERROR(
61         fn_body->graph.IsValidOutputTensor(node, idx),
62         "Encountered while processing input ", i, " into function '", fn_name,
63         "'");
64     TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
65                                     "Encountered while processing input ", i,
66                                     " into function '", fn_name, "'");
67 
68     input_tensors->emplace_back(node, idx);
69 
70     const auto& iter = input_nodes->find(node);
71     if (iter == input_nodes->end()) {
72       input_nodes->insert({node, {idx}});
73     } else {
74       auto& indices = iter->second;
75       if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
76         return InvalidArgument("TF_Output ", node->name(), ":", idx,
77                                " appears more than once in the input list");
78       }
79       indices.push_back(idx);
80     }
81   }
82   return Status::OK();
83 }
84 
85 // Converts `noutputs` and `outputs` into `outputs_tensors` and does various
86 // checks while doing so.
ProcessOutputs(const TF_Graph * fn_body,const char * fn_name,int noutputs,const TF_Output * outputs,std::vector<OutputTensor> * output_tensors)87 Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
88                       int noutputs, const TF_Output* outputs,
89                       std::vector<OutputTensor>* output_tensors)
90     TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
91   output_tensors->reserve(noutputs);
92   for (int i = 0; i < noutputs; ++i) {
93     Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
94     int idx = outputs[i].index;
95     TF_RETURN_WITH_CONTEXT_IF_ERROR(
96         fn_body->graph.IsValidOutputTensor(node, idx),
97         "Encountered while processing output ", i, " from function '", fn_name,
98         "'");
99     TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
100                                     "Encountered while creating function '",
101                                     fn_name, "'");
102     output_tensors->emplace_back(node, idx);
103   }
104   return Status::OK();
105 }
106 
107 // Populates `body_nodes` with the nodes that will become function's body.
108 // Performs various checks.
ComputeBodyNodes(const TF_Graph * fn_body,const char * fn_name,int num_opers,const TF_Operation * const * opers,const std::unordered_map<const Node *,std::vector<int>> & input_nodes,std::vector<const Node * > * body_nodes)109 Status ComputeBodyNodes(
110     const TF_Graph* fn_body, const char* fn_name, int num_opers,
111     const TF_Operation* const* opers,
112     const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
113     std::vector<const Node*>* body_nodes)
114     TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
115   if (num_opers == -1) {
116     for (const Node* node : fn_body->graph.op_nodes()) {
117       const auto& iter = input_nodes.find(node);
118       if (iter == input_nodes.end()) {
119         // This node is not referenced in inputs. Add it to the body.
120         body_nodes->push_back(node);
121       } else {
122         // This node is referenced in inputs. Currently, we place an
123         // artificial restriction and require that when num_opers=-1, such
124         // nodes must have a single output.
125         if (node->num_outputs() != 1) {
126           return InvalidArgument(
127               "When `num_opers` is set to -1, nodes referenced in `inputs` "
128               "must have a single output. Node ",
129               node->name(), " has ", node->num_outputs(),
130               " outputs. Encountered while creating function '", fn_name, "'");
131         }
132       }
133     }
134   } else {
135     body_nodes->reserve(num_opers);
136     for (int i = 0; i < num_opers; ++i) {
137       const Node* node = &opers[i]->node;
138       body_nodes->push_back(node);
139     }
140   }
141   return Status::OK();
142 }
143 
144 }  // namespace
145 }  // namespace tensorflow
146 
147 using tensorflow::Node;
148 using tensorflow::string;
149 
TF_GraphToFunctionWithControlOutputs(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,int ncontrol_outputs,const TF_Operation * const * control_outputs,const char * const * control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)150 TF_Function* TF_GraphToFunctionWithControlOutputs(
151     const TF_Graph* fn_body, const char* fn_name,
152     unsigned char append_hash_to_fn_name, int num_opers,
153     const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
154     int noutputs, const TF_Output* outputs, const char* const* output_names,
155     int ncontrol_outputs, const TF_Operation* const* control_outputs,
156     const char* const* control_output_names, const TF_FunctionOptions* opts,
157     const char* description, TF_Status* status) {
158   tensorflow::mutex_lock l(fn_body->mu);
159 
160   // Process inputs.
161   std::vector<tensorflow::OutputTensor> input_tensors;
162   std::unordered_map<const Node*, std::vector<int>> input_nodes;
163   status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
164                                              &input_tensors, &input_nodes);
165   if (TF_GetCode(status) != TF_OK) return nullptr;
166 
167   // Process outputs.
168   std::vector<tensorflow::OutputTensor> output_tensors;
169   status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
170                                               outputs, &output_tensors);
171   if (TF_GetCode(status) != TF_OK) return nullptr;
172 
173   // Process output names.
174   std::vector<string> output_names_vec;
175   if (output_names) {
176     output_names_vec.reserve(noutputs);
177     for (int i = 0; i < noutputs; ++i) {
178       output_names_vec.push_back(string(output_names[i]));
179     }
180   }
181 
182   // Process control output names.
183   std::vector<string> control_output_names_vec;
184   if (control_output_names) {
185     control_output_names_vec.reserve(ncontrol_outputs);
186     for (int i = 0; i < ncontrol_outputs; ++i) {
187       control_output_names_vec.push_back(string(output_names[i]));
188     }
189   }
190 
191   // Compute body nodes.
192   std::vector<const Node*> body_nodes;
193   status->status = tensorflow::ComputeBodyNodes(
194       fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
195   if (TF_GetCode(status) != TF_OK) return nullptr;
196 
197   // Compute body nodes.
198   std::vector<const Node*> control_output_nodes;
199   control_output_nodes.reserve(ncontrol_outputs);
200   for (int i = 0; i < ncontrol_outputs; ++i) {
201     control_output_nodes.push_back(&control_outputs[i]->node);
202   }
203 
204   // Do the actual function creation.
205   TF_Function* tf_function = new TF_Function();
206   DCHECK(append_hash_to_fn_name <= 1);
207   status->status = tensorflow::GraphToFunctionDef(
208       fn_body->graph, fn_name, append_hash_to_fn_name != 0,
209       /*set_stateful_from_nodes=*/true,
210       /*copy_placeholder_attrs_from_nodes=*/true, body_nodes, input_tensors,
211       output_tensors, output_names_vec, control_output_nodes,
212       control_output_names_vec, description, &tf_function->fdef);
213   if (TF_GetCode(status) != TF_OK) {
214     TF_DeleteFunction(tf_function);
215     return nullptr;
216   }
217 
218   for (const Node* n : fn_body->graph.nodes()) {
219     tf_function->stack_traces[n->name()] = n->GetStackTrace();
220   }
221 
222   return tf_function;
223 }
224 
TF_GraphToFunction(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)225 TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
226                                 unsigned char append_hash_to_fn_name,
227                                 int num_opers, const TF_Operation* const* opers,
228                                 int ninputs, const TF_Output* inputs,
229                                 int noutputs, const TF_Output* outputs,
230                                 const char* const* output_names,
231                                 const TF_FunctionOptions* opts,
232                                 const char* description, TF_Status* status) {
233   return TF_GraphToFunctionWithControlOutputs(
234       fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs,
235       inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts,
236       description, status);
237 }
238 
TF_FunctionName(TF_Function * func)239 const char* TF_FunctionName(TF_Function* func) {
240   return func->fdef.signature().name().c_str();
241 }
242 
TF_GraphCopyFunction(TF_Graph * g,const TF_Function * func,const TF_Function * grad,TF_Status * status)243 void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
244                           const TF_Function* grad, TF_Status* status) {
245   if (func == nullptr) {
246     status->status = InvalidArgument(
247         "'func' argument to TF_GraphCopyFunction cannot be null");
248     return;
249   }
250 
251   // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
252   // to avoid the extra copy here.
253   tensorflow::FunctionDefLibrary fdef_lib;
254   *fdef_lib.add_function() = func->fdef;
255   if (grad) {
256     *fdef_lib.add_function() = grad->fdef;
257     tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
258     gdef->set_function_name(func->fdef.signature().name());
259     gdef->set_gradient_func(grad->fdef.signature().name());
260   }
261 
262   tensorflow::mutex_lock l(g->mu);
263   status->status = g->graph.AddFunctionLibrary(fdef_lib);
264 }
265 
TF_GraphNumFunctions(TF_Graph * g)266 int TF_GraphNumFunctions(TF_Graph* g) {
267   tensorflow::mutex_lock l(g->mu);
268   return g->graph.flib_def().num_functions();
269 }
270 
TF_GraphGetFunctions(TF_Graph * g,TF_Function ** funcs,int max_func,TF_Status * status)271 int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
272                          TF_Status* status) {
273   tensorflow::FunctionDefLibrary lib;
274   {
275     tensorflow::mutex_lock l(g->mu);
276     lib = g->graph.flib_def().ToProto();
277   }
278   const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
279   for (int i = 0; i < len; ++i) {
280     TF_Function* func = new TF_Function();
281     func->fdef = lib.function(i);
282     funcs[i] = func;
283   }
284   status->status = tensorflow::Status::OK();
285   return len;
286 }
287 
TF_FunctionToFunctionDef(TF_Function * func,TF_Buffer * output_func_def,TF_Status * status)288 void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
289                               TF_Status* status) {
290   status->status = MessageToBuffer(func->fdef, output_func_def);
291 }
292 
TF_FunctionImportFunctionDef(const void * proto,size_t proto_len,TF_Status * status)293 TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
294                                           TF_Status* status) {
295   TF_Function* func = new TF_Function();
296   if (!func->fdef.ParseFromArray(proto, proto_len)) {
297     status->status = InvalidArgument(
298         "Invalid FunctionDef given to TF_FunctionImportFunctionDef");
299     TF_DeleteFunction(func);
300     return nullptr;
301   }
302   status->status = tensorflow::Status::OK();
303   return func;
304 }
305 
TF_FunctionSetAttrValueProto(TF_Function * func,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)306 void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
307                                   const void* proto, size_t proto_len,
308                                   TF_Status* status) {
309   tensorflow::AttrValue attr_value;
310   if (!attr_value.ParseFromArray(proto, proto_len)) {
311     status->status = InvalidArgument(
312         "Unparseable AttrValue proto passed to "
313         "TF_FunctionSetAttrValueProto");
314     return;
315   }
316   (*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
317   status->status = tensorflow::Status::OK();
318 }
319 
TF_FunctionGetAttrValueProto(TF_Function * func,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)320 void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
321                                   TF_Buffer* output_attr_value,
322                                   TF_Status* status) {
323   const auto& it = func->fdef.attr().find(attr_name);
324   if (it == func->fdef.attr().end()) {
325     status->status =
326         InvalidArgument("Function '", func->fdef.signature().name(),
327                         "' has no attr named '", attr_name, "'.");
328     return;
329   }
330   status->status = MessageToBuffer(it->second, output_attr_value);
331 }
332 
TF_DeleteFunction(TF_Function * func)333 void TF_DeleteFunction(TF_Function* func) { delete func; }
334