• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <algorithm>
#include <limits>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/strings/match.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_buffer_internal.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/base64.h"
#include "tensorflow/core/platform/strcat.h"
33 
34 using tensorflow::errors::InvalidArgument;
35 
36 namespace tensorflow {
37 namespace {
38 
ValidateNonRefOutput(const Node * node,int idx)39 Status ValidateNonRefOutput(const Node* node, int idx) {
40   const DataType& dt = node->output_type(idx);
41   return IsRefType(dt)
42              ? InvalidArgument("Output ", idx, " of node '", node->name(),
43                                "' has a reference type ", DataTypeString(dt))
44              : OkStatus();
45 }
46 
47 // Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
48 // does various checks while doing so. `input_nodes` will contain the same
49 // information as input_tensors just in a different structure to make
50 // following processing easier. TODO(iga): Simplify this nested structure.
ProcessInputs(const TF_Graph * fn_body,const char * fn_name,int ninputs,const TF_Output * inputs,std::vector<OutputTensor> * input_tensors,std::unordered_map<const Node *,std::vector<int>> * input_nodes)51 Status ProcessInputs(
52     const TF_Graph* fn_body, const char* fn_name, int ninputs,
53     const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
54     std::unordered_map<const Node*, std::vector<int>>* input_nodes)
55     TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
56   input_tensors->reserve(ninputs);
57   for (int i = 0; i < ninputs; ++i) {
58     Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
59     int idx = inputs[i].index;
60 
61     TF_RETURN_WITH_CONTEXT_IF_ERROR(
62         fn_body->graph.IsValidOutputTensor(node, idx),
63         "Encountered while processing input ", i, " into function '", fn_name,
64         "'");
65     TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
66                                     "Encountered while processing input ", i,
67                                     " into function '", fn_name, "'");
68 
69     input_tensors->emplace_back(node, idx);
70 
71     const auto& iter = input_nodes->find(node);
72     if (iter == input_nodes->end()) {
73       input_nodes->insert({node, {idx}});
74     } else {
75       auto& indices = iter->second;
76       if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
77         return InvalidArgument("TF_Output ", node->name(), ":", idx,
78                                " appears more than once in the input list");
79       }
80       indices.push_back(idx);
81     }
82   }
83   return OkStatus();
84 }
85 
86 // Converts `noutputs` and `outputs` into `outputs_tensors` and does various
87 // checks while doing so.
ProcessOutputs(const TF_Graph * fn_body,const char * fn_name,int noutputs,const TF_Output * outputs,std::vector<OutputTensor> * output_tensors)88 Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
89                       int noutputs, const TF_Output* outputs,
90                       std::vector<OutputTensor>* output_tensors)
91     TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
92   output_tensors->reserve(noutputs);
93   for (int i = 0; i < noutputs; ++i) {
94     Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
95     int idx = outputs[i].index;
96     TF_RETURN_WITH_CONTEXT_IF_ERROR(
97         fn_body->graph.IsValidOutputTensor(node, idx),
98         "Encountered while processing output ", i, " from function '", fn_name,
99         "'");
100     TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
101                                     "Encountered while creating function '",
102                                     fn_name, "'");
103     output_tensors->emplace_back(node, idx);
104   }
105   return OkStatus();
106 }
107 
108 // Populates `body_nodes` with the nodes that will become function's body.
109 // Performs various checks.
ComputeBodyNodes(const TF_Graph * fn_body,const char * fn_name,int num_opers,const TF_Operation * const * opers,const std::unordered_map<const Node *,std::vector<int>> & input_nodes,std::vector<const Node * > * body_nodes)110 Status ComputeBodyNodes(
111     const TF_Graph* fn_body, const char* fn_name, int num_opers,
112     const TF_Operation* const* opers,
113     const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
114     std::vector<const Node*>* body_nodes)
115     TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
116   if (num_opers == -1) {
117     for (const Node* node : fn_body->graph.op_nodes()) {
118       const auto& iter = input_nodes.find(node);
119       if (iter == input_nodes.end()) {
120         // This node is not referenced in inputs. Add it to the body.
121         body_nodes->push_back(node);
122       } else {
123         // This node is referenced in inputs. Currently, we place an
124         // artificial restriction and require that when num_opers=-1, such
125         // nodes must have a single output.
126         if (node->num_outputs() != 1) {
127           return InvalidArgument(
128               "When `num_opers` is set to -1, nodes referenced in `inputs` "
129               "must have a single output. Node ",
130               node->name(), " has ", node->num_outputs(),
131               " outputs. Encountered while creating function '", fn_name, "'");
132         }
133       }
134     }
135   } else {
136     body_nodes->reserve(num_opers);
137     for (int i = 0; i < num_opers; ++i) {
138       const Node* node = &opers[i]->node;
139       body_nodes->push_back(node);
140     }
141   }
142   return OkStatus();
143 }
144 
145 }  // namespace
146 }  // namespace tensorflow
147 
148 using tensorflow::Node;
149 using tensorflow::string;
150 
TF_GraphToFunctionWithControlOutputs(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,int ncontrol_outputs,const TF_Operation * const * control_outputs,const char * const * control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)151 TF_Function* TF_GraphToFunctionWithControlOutputs(
152     const TF_Graph* fn_body, const char* fn_name,
153     unsigned char append_hash_to_fn_name, int num_opers,
154     const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
155     int noutputs, const TF_Output* outputs, const char* const* output_names,
156     int ncontrol_outputs, const TF_Operation* const* control_outputs,
157     const char* const* control_output_names, const TF_FunctionOptions* opts,
158     const char* description, TF_Status* status) {
159   tensorflow::mutex_lock l(fn_body->mu);
160 
161   // Process inputs.
162   std::vector<tensorflow::OutputTensor> input_tensors;
163   std::unordered_map<const Node*, std::vector<int>> input_nodes;
164   status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
165                                              &input_tensors, &input_nodes);
166   if (TF_GetCode(status) != TF_OK) return nullptr;
167 
168   // Process outputs.
169   std::vector<tensorflow::OutputTensor> output_tensors;
170   status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
171                                               outputs, &output_tensors);
172   if (TF_GetCode(status) != TF_OK) return nullptr;
173 
174   // Process output names.
175   std::vector<string> output_names_vec;
176   if (output_names) {
177     output_names_vec.reserve(noutputs);
178     for (int i = 0; i < noutputs; ++i) {
179       output_names_vec.push_back(string(output_names[i]));
180     }
181   }
182 
183   // Process control output names.
184   std::vector<string> control_output_names_vec;
185   if (control_output_names) {
186     control_output_names_vec.reserve(ncontrol_outputs);
187     for (int i = 0; i < ncontrol_outputs; ++i) {
188       control_output_names_vec.push_back(string(output_names[i]));
189     }
190   }
191 
192   // Compute body nodes.
193   std::vector<const Node*> body_nodes;
194   status->status = tensorflow::ComputeBodyNodes(
195       fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
196   if (TF_GetCode(status) != TF_OK) return nullptr;
197 
198   // Compute body nodes.
199   std::vector<const Node*> control_output_nodes;
200   control_output_nodes.reserve(ncontrol_outputs);
201   for (int i = 0; i < ncontrol_outputs; ++i) {
202     control_output_nodes.push_back(&control_outputs[i]->node);
203   }
204 
205   // Do the actual function creation.
206   TF_Function* tf_function = new TF_Function();
207   DCHECK(append_hash_to_fn_name <= 1);
208   status->status = tensorflow::GraphToFunctionDef(
209       fn_body->graph, fn_name, append_hash_to_fn_name != 0,
210       /*set_stateful_from_nodes=*/true,
211       /*copy_placeholder_attrs_from_nodes=*/true, body_nodes, input_tensors,
212       output_tensors, output_names_vec, control_output_nodes,
213       control_output_names_vec, description, &tf_function->fdef);
214   if (TF_GetCode(status) != TF_OK) {
215     TF_DeleteFunction(tf_function);
216     return nullptr;
217   }
218 
219   for (const Node* n : fn_body->graph.nodes()) {
220     tf_function->stack_traces[n->name()] = n->GetStackTrace();
221   }
222 
223   return tf_function;
224 }
225 
TF_GraphToFunction(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)226 TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
227                                 unsigned char append_hash_to_fn_name,
228                                 int num_opers, const TF_Operation* const* opers,
229                                 int ninputs, const TF_Output* inputs,
230                                 int noutputs, const TF_Output* outputs,
231                                 const char* const* output_names,
232                                 const TF_FunctionOptions* opts,
233                                 const char* description, TF_Status* status) {
234   return TF_GraphToFunctionWithControlOutputs(
235       fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs,
236       inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts,
237       description, status);
238 }
239 
TF_FunctionName(TF_Function * func)240 const char* TF_FunctionName(TF_Function* func) {
241   return func->fdef.signature().name().c_str();
242 }
243 
TF_GraphCopyFunction(TF_Graph * g,const TF_Function * func,const TF_Function * grad,TF_Status * status)244 void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
245                           const TF_Function* grad, TF_Status* status) {
246   if (func == nullptr) {
247     status->status = InvalidArgument(
248         "'func' argument to TF_GraphCopyFunction cannot be null");
249     return;
250   }
251 
252   // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
253   // to avoid the extra copy here.
254   tensorflow::FunctionDefLibrary fdef_lib;
255   *fdef_lib.add_function() = func->fdef;
256   if (grad) {
257     *fdef_lib.add_function() = grad->fdef;
258     tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
259     gdef->set_function_name(func->fdef.signature().name());
260     gdef->set_gradient_func(grad->fdef.signature().name());
261   }
262 
263   tensorflow::mutex_lock l(g->mu);
264   status->status = g->graph.AddFunctionLibrary(fdef_lib);
265 }
266 
TF_GraphNumFunctions(TF_Graph * g)267 int TF_GraphNumFunctions(TF_Graph* g) {
268   tensorflow::mutex_lock l(g->mu);
269   return g->graph.flib_def().num_functions();
270 }
271 
TF_GraphGetFunctions(TF_Graph * g,TF_Function ** funcs,int max_func,TF_Status * status)272 int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
273                          TF_Status* status) {
274   tensorflow::FunctionDefLibrary lib;
275   {
276     tensorflow::mutex_lock l(g->mu);
277     lib = g->graph.flib_def().ToProto();
278   }
279   const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
280   for (int i = 0; i < len; ++i) {
281     TF_Function* func = new TF_Function();
282     func->fdef = lib.function(i);
283     funcs[i] = func;
284   }
285   status->status = ::tensorflow::OkStatus();
286   return len;
287 }
288 
TF_FunctionToFunctionDef(TF_Function * func,TF_Buffer * output_func_def,TF_Status * status)289 void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
290                               TF_Status* status) {
291   status->status = MessageToBuffer(func->fdef, output_func_def);
292 }
293 
TF_FunctionImportFunctionDef(const void * proto,size_t proto_len,TF_Status * status)294 TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
295                                           TF_Status* status) {
296   TF_Function* func = new TF_Function();
297   if (!func->fdef.ParseFromArray(proto, proto_len)) {
298     status->status = InvalidArgument(
299         "Invalid FunctionDef given to TF_FunctionImportFunctionDef");
300     TF_DeleteFunction(func);
301     return nullptr;
302   }
303   status->status = ::tensorflow::OkStatus();
304   return func;
305 }
306 
TF_FunctionSetAttrValueProto(TF_Function * func,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)307 void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
308                                   const void* proto, size_t proto_len,
309                                   TF_Status* status) {
310   tensorflow::AttrValue attr_value;
311   if (!attr_value.ParseFromArray(proto, proto_len)) {
312     status->status = InvalidArgument(
313         "Unparseable AttrValue proto passed to "
314         "TF_FunctionSetAttrValueProto");
315     return;
316   }
317   (*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
318   status->status = ::tensorflow::OkStatus();
319 }
320 
TF_FunctionGetAttrValueProto(TF_Function * func,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)321 void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
322                                   TF_Buffer* output_attr_value,
323                                   TF_Status* status) {
324   const auto& it = func->fdef.attr().find(attr_name);
325   if (it == func->fdef.attr().end()) {
326     status->status =
327         InvalidArgument("Function '", func->fdef.signature().name(),
328                         "' has no attr named '", attr_name, "'.");
329     return;
330   }
331   status->status = MessageToBuffer(it->second, output_attr_value);
332 }
333 
TF_DeleteFunction(TF_Function * func)334 void TF_DeleteFunction(TF_Function* func) { delete func; }
335