/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <unordered_map>
#include <unordered_set>

#include "absl/strings/match.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/base64.h"
#include "tensorflow/core/lib/strings/strcat.h"

using tensorflow::errors::InvalidArgument;

namespace tensorflow {
namespace {

Status ValidateNonRefOutput(const Node* node, int idx) {
  const DataType& dt = node->output_type(idx);
  return IsRefType(dt)
             ? InvalidArgument("Output ", idx, " of node '", node->name(),
                               "' has a reference type ", DataTypeString(dt))
             : Status::OK();
}

// Converts `ninputs` and `inputs` into `input_tensors` and `input_nodes` and
// performs various checks while doing so. `input_nodes` holds the same
// information as `input_tensors`, just in a different structure that makes
// the subsequent processing easier. TODO(iga): Simplify this nested structure.
Status ProcessInputs(
    const TF_Graph* fn_body, const char* fn_name, int ninputs,
    const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
    std::unordered_map<const Node*, std::vector<int>>* input_nodes)
    EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
  input_tensors->reserve(ninputs);
  for (int i = 0; i < ninputs; ++i) {
    Node* node = &inputs[i].oper->node;
    int idx = inputs[i].index;

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        fn_body->graph.IsValidOutputTensor(node, idx),
        "Encountered while processing input ", i, " into function '", fn_name,
        "'");
    TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
                                    "Encountered while processing input ", i,
                                    " into function '", fn_name, "'");

    input_tensors->emplace_back(node, idx);

    const auto& iter = input_nodes->find(node);
    if (iter == input_nodes->end()) {
      input_nodes->insert({node, {idx}});
    } else {
      auto& indices = iter->second;
      if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
        return InvalidArgument("TF_Output ", node->name(), ":", idx,
                               " appears more than once in the input list");
      }
      indices.push_back(idx);
    }
  }
  return Status::OK();
}

// Converts `noutputs` and `outputs` into `output_tensors` and performs
// various checks while doing so.
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
                      int noutputs, const TF_Output* outputs,
                      std::vector<OutputTensor>* output_tensors)
    EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
  output_tensors->reserve(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    Node* node = &outputs[i].oper->node;
    int idx = outputs[i].index;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        fn_body->graph.IsValidOutputTensor(node, idx),
        "Encountered while processing output ", i, " from function '", fn_name,
        "'");
    TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
                                    "Encountered while creating function '",
                                    fn_name, "'");
    output_tensors->emplace_back(node, idx);
  }
  return Status::OK();
}

// Populates `body_nodes` with the nodes that will become the function's body.
// Performs various checks.
Status ComputeBodyNodes(
    const TF_Graph* fn_body, const char* fn_name, int num_opers,
    const TF_Operation* const* opers,
    const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
    std::vector<const Node*>* body_nodes)
    EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
  if (num_opers == -1) {
    for (const Node* node : fn_body->graph.op_nodes()) {
      const auto& iter = input_nodes.find(node);
      if (iter == input_nodes.end()) {
        // This node is not referenced in inputs. Add it to the body.
        body_nodes->push_back(node);
      } else {
        // This node is referenced in inputs. Currently, we place an
        // artificial restriction and require that when num_opers=-1, such
        // nodes must have a single output.
        if (node->num_outputs() != 1) {
          return InvalidArgument(
              "When `num_opers` is set to -1, nodes referenced in `inputs` "
              "must have a single output. Node ",
              node->name(), " has ", node->num_outputs(),
              " outputs. Encountered while creating function '", fn_name, "'");
        }
      }
    }
  } else {
    body_nodes->reserve(num_opers);
    for (int i = 0; i < num_opers; ++i) {
      const Node* node = &opers[i]->node;
      body_nodes->push_back(node);
    }
  }
  return Status::OK();
}

}  // namespace
}  // namespace tensorflow

using tensorflow::Node;
using tensorflow::string;

TF_Function* TF_GraphToFunctionWithControlOutputs(
    const TF_Graph* fn_body, const char* fn_name,
    unsigned char append_hash_to_fn_name, int num_opers,
    const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
    int noutputs, const TF_Output* outputs, const char* const* output_names,
    int ncontrol_outputs, const TF_Operation* const* control_outputs,
    const char* const* control_output_names, const TF_FunctionOptions* opts,
    const char* description, TF_Status* status) {
  tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));

  // Process inputs.
  std::vector<tensorflow::OutputTensor> input_tensors;
  std::unordered_map<const Node*, std::vector<int>> input_nodes;
  status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
                                             &input_tensors, &input_nodes);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  // Process outputs.
  std::vector<tensorflow::OutputTensor> output_tensors;
  status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
                                              outputs, &output_tensors);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  // Process output names.
  std::vector<string> output_names_vec;
  if (output_names) {
    output_names_vec.reserve(noutputs);
    for (int i = 0; i < noutputs; ++i) {
      output_names_vec.push_back(string(output_names[i]));
    }
  }

  // Process control output names.
  std::vector<string> control_output_names_vec;
  if (control_output_names) {
    control_output_names_vec.reserve(ncontrol_outputs);
    for (int i = 0; i < ncontrol_outputs; ++i) {
      control_output_names_vec.push_back(string(control_output_names[i]));
    }
  }

  // Compute body nodes.
  std::vector<const Node*> body_nodes;
  status->status = tensorflow::ComputeBodyNodes(
      fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  // Collect control output nodes.
  std::vector<const Node*> control_output_nodes;
  for (int i = 0; i < ncontrol_outputs; ++i) {
    control_output_nodes.push_back(&control_outputs[i]->node);
  }

  // Do the actual function creation.
  TF_Function* tf_function = new TF_Function();
  DCHECK(append_hash_to_fn_name <= 1);
  status->status = tensorflow::GraphToFunctionDef(
      fn_body->graph, fn_name, append_hash_to_fn_name != 0,
      /*set_stateful_from_nodes=*/true,
      /*copy_placeholder_attrs_from_nodes=*/true, body_nodes, input_tensors,
      output_tensors, output_names_vec, control_output_nodes,
      control_output_names_vec, description, &tf_function->fdef);
  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteFunction(tf_function);
    return nullptr;
  }
  return tf_function;
}

TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
                                unsigned char append_hash_to_fn_name,
                                int num_opers, const TF_Operation* const* opers,
                                int ninputs, const TF_Output* inputs,
                                int noutputs, const TF_Output* outputs,
                                const char* const* output_names,
                                const TF_FunctionOptions* opts,
                                const char* description, TF_Status* status) {
  return TF_GraphToFunctionWithControlOutputs(
      fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs,
      inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts,
      description, status);
}
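
// Example usage (illustrative sketch only, not part of this file): a caller
// that has already built `graph` with a Placeholder op `ph_op`, a downstream
// op `result_op`, and a TF_Status `status` (all hypothetical names) might
// wrap them into a function like this:
//
//   TF_Output fn_inputs[] = {{ph_op, 0}};
//   TF_Output fn_outputs[] = {{result_op, 0}};
//   TF_Function* fn = TF_GraphToFunction(
//       graph, "my_func", /*append_hash_to_fn_name=*/0,
//       /*num_opers=*/-1, /*opers=*/nullptr,
//       /*ninputs=*/1, fn_inputs, /*noutputs=*/1, fn_outputs,
//       /*output_names=*/nullptr, /*opts=*/nullptr, "example fn", status);
//   if (TF_GetCode(status) == TF_OK) {
//     // Use `fn` (e.g. copy it into another graph), then release it.
//     TF_DeleteFunction(fn);
//   }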

const char* TF_FunctionName(TF_Function* func) {
  return func->fdef.signature().name().c_str();
}

void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
                          const TF_Function* grad, TF_Status* status) {
  if (func == nullptr) {
    status->status = InvalidArgument(
        "'func' argument to TF_GraphCopyFunction cannot be null");
    return;
  }

  // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
  // to avoid the extra copy here.
  tensorflow::FunctionDefLibrary fdef_lib;
  *fdef_lib.add_function() = func->fdef;
  if (grad) {
    *fdef_lib.add_function() = grad->fdef;
    tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
    gdef->set_function_name(func->fdef.signature().name());
    gdef->set_gradient_func(grad->fdef.signature().name());
  }

  tensorflow::mutex_lock l(g->mu);
  status->status = g->graph.AddFunctionLibrary(fdef_lib);
}
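
// Example usage (illustrative sketch only): registering a function and its
// gradient in a destination graph; `dst_graph`, `fn`, `grad_fn`, and `status`
// are hypothetical objects created elsewhere with the C API.
//
//   TF_GraphCopyFunction(dst_graph, fn, grad_fn, status);
//   if (TF_GetCode(status) != TF_OK) {
//     // Handle the error, e.g. log TF_Message(status).
//   }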

int TF_GraphNumFunctions(TF_Graph* g) {
  tensorflow::mutex_lock l(g->mu);
  return g->graph.flib_def().num_functions();
}

int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
                         TF_Status* status) {
  tensorflow::FunctionDefLibrary lib;
  {
    tensorflow::mutex_lock l(g->mu);
    lib = g->graph.flib_def().ToProto();
  }
  const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
  for (int i = 0; i < len; ++i) {
    TF_Function* func = new TF_Function();
    func->fdef = lib.function(i);
    funcs[i] = func;
  }
  status->status = tensorflow::Status::OK();
  return len;
}
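
// Example usage (illustrative sketch only): enumerating the functions in a
// graph. `graph` and `status` are hypothetical; note that the caller owns
// each returned TF_Function and must delete it.
//
//   int num = TF_GraphNumFunctions(graph);
//   std::vector<TF_Function*> fns(num);
//   int copied = TF_GraphGetFunctions(graph, fns.data(), num, status);
//   for (int i = 0; i < copied; ++i) {
//     // e.g. inspect TF_FunctionName(fns[i]) ...
//     TF_DeleteFunction(fns[i]);
//   }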

void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
                              TF_Status* status) {
  status->status = MessageToBuffer(func->fdef, output_func_def);
}

TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
                                          TF_Status* status) {
  TF_Function* func = new TF_Function();
  if (!func->fdef.ParseFromArray(proto, proto_len)) {
    status->status = InvalidArgument(
        "Invalid FunctionDef given to TF_FunctionImportFunctionDef");
    TF_DeleteFunction(func);
    return nullptr;
  }
  status->status = tensorflow::Status::OK();
  return func;
}
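
// Example usage (illustrative sketch only): serializing `fn` to a FunctionDef
// proto and importing it back. `fn` and `status` are hypothetical objects
// created elsewhere.
//
//   TF_Buffer* buf = TF_NewBuffer();
//   TF_FunctionToFunctionDef(fn, buf, status);
//   TF_Function* copy = nullptr;
//   if (TF_GetCode(status) == TF_OK) {
//     copy = TF_FunctionImportFunctionDef(buf->data, buf->length, status);
//   }
//   TF_DeleteBuffer(buf);
//   if (copy != nullptr) TF_DeleteFunction(copy);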

void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
                                  const void* proto, size_t proto_len,
                                  TF_Status* status) {
  tensorflow::AttrValue attr_value;
  if (!attr_value.ParseFromArray(proto, proto_len)) {
    status->status = InvalidArgument(
        "Unparseable AttrValue proto passed to "
        "TF_FunctionSetAttrValueProto");
    return;
  }
  (*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
  status->status = tensorflow::Status::OK();
}

void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
                                  TF_Buffer* output_attr_value,
                                  TF_Status* status) {
  const auto& it = func->fdef.attr().find(attr_name);
  if (it == func->fdef.attr().end()) {
    status->status =
        InvalidArgument("Function '", func->fdef.signature().name(),
                        "' has no attr named '", attr_name, "'.");
    return;
  }
  status->status = MessageToBuffer(it->second, output_attr_value);
}
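
// Example usage (illustrative sketch only): writing a serialized AttrValue
// onto `fn` and reading it back. `fn`, `status`, and the pre-serialized
// buffer `attr_proto`/`attr_proto_len` are hypothetical, as is the choice of
// attribute name.
//
//   TF_FunctionSetAttrValueProto(fn, "_noinline", attr_proto, attr_proto_len,
//                                status);
//   TF_Buffer* out = TF_NewBuffer();
//   if (TF_GetCode(status) == TF_OK) {
//     TF_FunctionGetAttrValueProto(fn, "_noinline", out, status);
//   }
//   TF_DeleteBuffer(out);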

void TF_DeleteFunction(TF_Function* func) { delete func; }
328