1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <unordered_map>
18 #include <unordered_set>
19
20 #include "absl/strings/match.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/graph_to_functiondef.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/node_def_util.h"
27 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/platform/base64.h"
31 #include "tensorflow/core/platform/strcat.h"
32
33 using tensorflow::errors::InvalidArgument;
34
35 namespace tensorflow {
36 namespace {
37
ValidateNonRefOutput(const Node * node,int idx)38 Status ValidateNonRefOutput(const Node* node, int idx) {
39 const DataType& dt = node->output_type(idx);
40 return IsRefType(dt)
41 ? InvalidArgument("Output ", idx, " of node '", node->name(),
42 "' has a reference type ", DataTypeString(dt))
43 : Status::OK();
44 }
45
46 // Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
47 // does various checks while doing so. `input_nodes` will contain the same
48 // information as input_tensors just in a different structure to make
49 // following processing easier. TODO(iga): Simplify this nested structure.
ProcessInputs(const TF_Graph * fn_body,const char * fn_name,int ninputs,const TF_Output * inputs,std::vector<OutputTensor> * input_tensors,std::unordered_map<const Node *,std::vector<int>> * input_nodes)50 Status ProcessInputs(
51 const TF_Graph* fn_body, const char* fn_name, int ninputs,
52 const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
53 std::unordered_map<const Node*, std::vector<int>>* input_nodes)
54 TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
55 input_tensors->reserve(ninputs);
56 for (int i = 0; i < ninputs; ++i) {
57 Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
58 int idx = inputs[i].index;
59
60 TF_RETURN_WITH_CONTEXT_IF_ERROR(
61 fn_body->graph.IsValidOutputTensor(node, idx),
62 "Encountered while processing input ", i, " into function '", fn_name,
63 "'");
64 TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
65 "Encountered while processing input ", i,
66 " into function '", fn_name, "'");
67
68 input_tensors->emplace_back(node, idx);
69
70 const auto& iter = input_nodes->find(node);
71 if (iter == input_nodes->end()) {
72 input_nodes->insert({node, {idx}});
73 } else {
74 auto& indices = iter->second;
75 if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
76 return InvalidArgument("TF_Output ", node->name(), ":", idx,
77 " appears more than once in the input list");
78 }
79 indices.push_back(idx);
80 }
81 }
82 return Status::OK();
83 }
84
85 // Converts `noutputs` and `outputs` into `outputs_tensors` and does various
86 // checks while doing so.
ProcessOutputs(const TF_Graph * fn_body,const char * fn_name,int noutputs,const TF_Output * outputs,std::vector<OutputTensor> * output_tensors)87 Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
88 int noutputs, const TF_Output* outputs,
89 std::vector<OutputTensor>* output_tensors)
90 TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
91 output_tensors->reserve(noutputs);
92 for (int i = 0; i < noutputs; ++i) {
93 Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
94 int idx = outputs[i].index;
95 TF_RETURN_WITH_CONTEXT_IF_ERROR(
96 fn_body->graph.IsValidOutputTensor(node, idx),
97 "Encountered while processing output ", i, " from function '", fn_name,
98 "'");
99 TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
100 "Encountered while creating function '",
101 fn_name, "'");
102 output_tensors->emplace_back(node, idx);
103 }
104 return Status::OK();
105 }
106
107 // Populates `body_nodes` with the nodes that will become function's body.
108 // Performs various checks.
ComputeBodyNodes(const TF_Graph * fn_body,const char * fn_name,int num_opers,const TF_Operation * const * opers,const std::unordered_map<const Node *,std::vector<int>> & input_nodes,std::vector<const Node * > * body_nodes)109 Status ComputeBodyNodes(
110 const TF_Graph* fn_body, const char* fn_name, int num_opers,
111 const TF_Operation* const* opers,
112 const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
113 std::vector<const Node*>* body_nodes)
114 TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
115 if (num_opers == -1) {
116 for (const Node* node : fn_body->graph.op_nodes()) {
117 const auto& iter = input_nodes.find(node);
118 if (iter == input_nodes.end()) {
119 // This node is not referenced in inputs. Add it to the body.
120 body_nodes->push_back(node);
121 } else {
122 // This node is referenced in inputs. Currently, we place an
123 // artificial restriction and require that when num_opers=-1, such
124 // nodes must have a single output.
125 if (node->num_outputs() != 1) {
126 return InvalidArgument(
127 "When `num_opers` is set to -1, nodes referenced in `inputs` "
128 "must have a single output. Node ",
129 node->name(), " has ", node->num_outputs(),
130 " outputs. Encountered while creating function '", fn_name, "'");
131 }
132 }
133 }
134 } else {
135 body_nodes->reserve(num_opers);
136 for (int i = 0; i < num_opers; ++i) {
137 const Node* node = &opers[i]->node;
138 body_nodes->push_back(node);
139 }
140 }
141 return Status::OK();
142 }
143
144 } // namespace
145 } // namespace tensorflow
146
147 using tensorflow::Node;
148 using tensorflow::string;
149
TF_GraphToFunctionWithControlOutputs(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,int ncontrol_outputs,const TF_Operation * const * control_outputs,const char * const * control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)150 TF_Function* TF_GraphToFunctionWithControlOutputs(
151 const TF_Graph* fn_body, const char* fn_name,
152 unsigned char append_hash_to_fn_name, int num_opers,
153 const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
154 int noutputs, const TF_Output* outputs, const char* const* output_names,
155 int ncontrol_outputs, const TF_Operation* const* control_outputs,
156 const char* const* control_output_names, const TF_FunctionOptions* opts,
157 const char* description, TF_Status* status) {
158 tensorflow::mutex_lock l(fn_body->mu);
159
160 // Process inputs.
161 std::vector<tensorflow::OutputTensor> input_tensors;
162 std::unordered_map<const Node*, std::vector<int>> input_nodes;
163 status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
164 &input_tensors, &input_nodes);
165 if (TF_GetCode(status) != TF_OK) return nullptr;
166
167 // Process outputs.
168 std::vector<tensorflow::OutputTensor> output_tensors;
169 status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
170 outputs, &output_tensors);
171 if (TF_GetCode(status) != TF_OK) return nullptr;
172
173 // Process output names.
174 std::vector<string> output_names_vec;
175 if (output_names) {
176 output_names_vec.reserve(noutputs);
177 for (int i = 0; i < noutputs; ++i) {
178 output_names_vec.push_back(string(output_names[i]));
179 }
180 }
181
182 // Process control output names.
183 std::vector<string> control_output_names_vec;
184 if (control_output_names) {
185 control_output_names_vec.reserve(ncontrol_outputs);
186 for (int i = 0; i < ncontrol_outputs; ++i) {
187 control_output_names_vec.push_back(string(output_names[i]));
188 }
189 }
190
191 // Compute body nodes.
192 std::vector<const Node*> body_nodes;
193 status->status = tensorflow::ComputeBodyNodes(
194 fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
195 if (TF_GetCode(status) != TF_OK) return nullptr;
196
197 // Compute body nodes.
198 std::vector<const Node*> control_output_nodes;
199 control_output_nodes.reserve(ncontrol_outputs);
200 for (int i = 0; i < ncontrol_outputs; ++i) {
201 control_output_nodes.push_back(&control_outputs[i]->node);
202 }
203
204 // Do the actual function creation.
205 TF_Function* tf_function = new TF_Function();
206 DCHECK(append_hash_to_fn_name <= 1);
207 status->status = tensorflow::GraphToFunctionDef(
208 fn_body->graph, fn_name, append_hash_to_fn_name != 0,
209 /*set_stateful_from_nodes=*/true,
210 /*copy_placeholder_attrs_from_nodes=*/true, body_nodes, input_tensors,
211 output_tensors, output_names_vec, control_output_nodes,
212 control_output_names_vec, description, &tf_function->fdef);
213 if (TF_GetCode(status) != TF_OK) {
214 TF_DeleteFunction(tf_function);
215 return nullptr;
216 }
217
218 for (const Node* n : fn_body->graph.nodes()) {
219 tf_function->stack_traces[n->name()] = n->GetStackTrace();
220 }
221
222 return tf_function;
223 }
224
TF_GraphToFunction(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)225 TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
226 unsigned char append_hash_to_fn_name,
227 int num_opers, const TF_Operation* const* opers,
228 int ninputs, const TF_Output* inputs,
229 int noutputs, const TF_Output* outputs,
230 const char* const* output_names,
231 const TF_FunctionOptions* opts,
232 const char* description, TF_Status* status) {
233 return TF_GraphToFunctionWithControlOutputs(
234 fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs,
235 inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts,
236 description, status);
237 }
238
TF_FunctionName(TF_Function * func)239 const char* TF_FunctionName(TF_Function* func) {
240 return func->fdef.signature().name().c_str();
241 }
242
TF_GraphCopyFunction(TF_Graph * g,const TF_Function * func,const TF_Function * grad,TF_Status * status)243 void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
244 const TF_Function* grad, TF_Status* status) {
245 if (func == nullptr) {
246 status->status = InvalidArgument(
247 "'func' argument to TF_GraphCopyFunction cannot be null");
248 return;
249 }
250
251 // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
252 // to avoid the extra copy here.
253 tensorflow::FunctionDefLibrary fdef_lib;
254 *fdef_lib.add_function() = func->fdef;
255 if (grad) {
256 *fdef_lib.add_function() = grad->fdef;
257 tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
258 gdef->set_function_name(func->fdef.signature().name());
259 gdef->set_gradient_func(grad->fdef.signature().name());
260 }
261
262 tensorflow::mutex_lock l(g->mu);
263 status->status = g->graph.AddFunctionLibrary(fdef_lib);
264 }
265
TF_GraphNumFunctions(TF_Graph * g)266 int TF_GraphNumFunctions(TF_Graph* g) {
267 tensorflow::mutex_lock l(g->mu);
268 return g->graph.flib_def().num_functions();
269 }
270
TF_GraphGetFunctions(TF_Graph * g,TF_Function ** funcs,int max_func,TF_Status * status)271 int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
272 TF_Status* status) {
273 tensorflow::FunctionDefLibrary lib;
274 {
275 tensorflow::mutex_lock l(g->mu);
276 lib = g->graph.flib_def().ToProto();
277 }
278 const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
279 for (int i = 0; i < len; ++i) {
280 TF_Function* func = new TF_Function();
281 func->fdef = lib.function(i);
282 funcs[i] = func;
283 }
284 status->status = tensorflow::Status::OK();
285 return len;
286 }
287
TF_FunctionToFunctionDef(TF_Function * func,TF_Buffer * output_func_def,TF_Status * status)288 void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
289 TF_Status* status) {
290 status->status = MessageToBuffer(func->fdef, output_func_def);
291 }
292
TF_FunctionImportFunctionDef(const void * proto,size_t proto_len,TF_Status * status)293 TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
294 TF_Status* status) {
295 TF_Function* func = new TF_Function();
296 if (!func->fdef.ParseFromArray(proto, proto_len)) {
297 status->status = InvalidArgument(
298 "Invalid FunctionDef given to TF_FunctionImportFunctionDef");
299 TF_DeleteFunction(func);
300 return nullptr;
301 }
302 status->status = tensorflow::Status::OK();
303 return func;
304 }
305
TF_FunctionSetAttrValueProto(TF_Function * func,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)306 void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
307 const void* proto, size_t proto_len,
308 TF_Status* status) {
309 tensorflow::AttrValue attr_value;
310 if (!attr_value.ParseFromArray(proto, proto_len)) {
311 status->status = InvalidArgument(
312 "Unparseable AttrValue proto passed to "
313 "TF_FunctionSetAttrValueProto");
314 return;
315 }
316 (*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
317 status->status = tensorflow::Status::OK();
318 }
319
TF_FunctionGetAttrValueProto(TF_Function * func,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)320 void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
321 TF_Buffer* output_attr_value,
322 TF_Status* status) {
323 const auto& it = func->fdef.attr().find(attr_name);
324 if (it == func->fdef.attr().end()) {
325 status->status =
326 InvalidArgument("Function '", func->fdef.signature().name(),
327 "' has no attr named '", attr_name, "'.");
328 return;
329 }
330 status->status = MessageToBuffer(it->second, output_attr_value);
331 }
332
TF_DeleteFunction(TF_Function * func)333 void TF_DeleteFunction(TF_Function* func) { delete func; }
334