1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <algorithm>
#include <limits>
#include <unordered_map>
#include <unordered_set>

#include "absl/strings/match.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/base64.h"
#include "tensorflow/core/lib/strings/strcat.h"
32
33 using tensorflow::errors::InvalidArgument;
34
35 namespace tensorflow {
36 namespace {
37
ValidateNonRefOutput(const Node * node,int idx)38 Status ValidateNonRefOutput(const Node* node, int idx) {
39 const DataType& dt = node->output_type(idx);
40 return IsRefType(dt)
41 ? InvalidArgument("Output ", idx, " of node '", node->name(),
42 "' has a reference type ", DataTypeString(dt))
43 : Status::OK();
44 }
45
46 // Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
47 // does various checks while doing so. `input_nodes` will contain the same
48 // information as input_tensors just in a different structure to make
49 // following processing easier. TODO(iga): Simplify this nested structure.
ProcessInputs(const TF_Graph * fn_body,const char * fn_name,int ninputs,const TF_Output * inputs,std::vector<OutputTensor> * input_tensors,std::unordered_map<const Node *,std::vector<int>> * input_nodes)50 Status ProcessInputs(
51 const TF_Graph* fn_body, const char* fn_name, int ninputs,
52 const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
53 std::unordered_map<const Node*, std::vector<int>>* input_nodes)
54 EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
55 input_tensors->reserve(ninputs);
56 for (int i = 0; i < ninputs; ++i) {
57 Node* node = &inputs[i].oper->node;
58 int idx = inputs[i].index;
59
60 TF_RETURN_WITH_CONTEXT_IF_ERROR(
61 fn_body->graph.IsValidOutputTensor(node, idx),
62 "Encountered while processing input ", i, " into function '", fn_name,
63 "'");
64 TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
65 "Encountered while processing input ", i,
66 " into function '", fn_name, "'");
67
68 input_tensors->emplace_back(node, idx);
69
70 const auto& iter = input_nodes->find(node);
71 if (iter == input_nodes->end()) {
72 input_nodes->insert({node, {idx}});
73 } else {
74 auto& indices = iter->second;
75 if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
76 return InvalidArgument("TF_Output ", node->name(), ":", idx,
77 " appears more than once in the input list");
78 }
79 indices.push_back(idx);
80 }
81 }
82 return Status::OK();
83 }
84
85 // Converts `noutputs` and `outputs` into `outputs_tensors` and does various
86 // checks while doing so.
ProcessOutputs(const TF_Graph * fn_body,const char * fn_name,int noutputs,const TF_Output * outputs,std::vector<OutputTensor> * output_tensors)87 Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
88 int noutputs, const TF_Output* outputs,
89 std::vector<OutputTensor>* output_tensors)
90 EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
91 output_tensors->reserve(noutputs);
92 for (int i = 0; i < noutputs; ++i) {
93 Node* node = &outputs[i].oper->node;
94 int idx = outputs[i].index;
95 TF_RETURN_WITH_CONTEXT_IF_ERROR(
96 fn_body->graph.IsValidOutputTensor(node, idx),
97 "Encountered while processing output ", i, " from function '", fn_name,
98 "'");
99 TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
100 "Encountered while creating function '",
101 fn_name, "'");
102 output_tensors->emplace_back(node, idx);
103 }
104 return Status::OK();
105 }
106
107 // Populates `body_nodes` with the nodes that will become function's body.
108 // Performs various checks.
ComputeBodyNodes(const TF_Graph * fn_body,const char * fn_name,int num_opers,const TF_Operation * const * opers,const std::unordered_map<const Node *,std::vector<int>> & input_nodes,std::vector<const Node * > * body_nodes)109 Status ComputeBodyNodes(
110 const TF_Graph* fn_body, const char* fn_name, int num_opers,
111 const TF_Operation* const* opers,
112 const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
113 std::vector<const Node*>* body_nodes)
114 EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
115 if (num_opers == -1) {
116 for (const Node* node : fn_body->graph.op_nodes()) {
117 const auto& iter = input_nodes.find(node);
118 if (iter == input_nodes.end()) {
119 // This node is not referenced in inputs. Add it to the body.
120 body_nodes->push_back(node);
121 } else {
122 // This node is referenced in inputs. Currently, we place an
123 // artificial restriction and require that when num_opers=-1, such
124 // nodes must have a single output.
125 if (node->num_outputs() != 1) {
126 return InvalidArgument(
127 "When `num_opers` is set to -1, nodes referenced in `inputs` "
128 "must have a single output. Node ",
129 node->name(), " has ", node->num_outputs(),
130 " outputs. Encountered while creating function '", fn_name, "'");
131 }
132 }
133 }
134 } else {
135 body_nodes->reserve(num_opers);
136 for (int i = 0; i < num_opers; ++i) {
137 const Node* node = &opers[i]->node;
138 body_nodes->push_back(node);
139 }
140 }
141 return Status::OK();
142 }
143
144 } // namespace
145 } // namespace tensorflow
146
147 using tensorflow::Node;
148 using tensorflow::string;
149
TF_GraphToFunctionWithControlOutputs(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,int ncontrol_outputs,const TF_Operation * const * control_outputs,const char * const * control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)150 TF_Function* TF_GraphToFunctionWithControlOutputs(
151 const TF_Graph* fn_body, const char* fn_name,
152 unsigned char append_hash_to_fn_name, int num_opers,
153 const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
154 int noutputs, const TF_Output* outputs, const char* const* output_names,
155 int ncontrol_outputs, const TF_Operation* const* control_outputs,
156 const char* const* control_output_names, const TF_FunctionOptions* opts,
157 const char* description, TF_Status* status) {
158 tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));
159
160 // Process inputs.
161 std::vector<tensorflow::OutputTensor> input_tensors;
162 std::unordered_map<const Node*, std::vector<int>> input_nodes;
163 status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
164 &input_tensors, &input_nodes);
165 if (TF_GetCode(status) != TF_OK) return nullptr;
166
167 // Process outputs.
168 std::vector<tensorflow::OutputTensor> output_tensors;
169 status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
170 outputs, &output_tensors);
171 if (TF_GetCode(status) != TF_OK) return nullptr;
172
173 // Process output names.
174 std::vector<string> output_names_vec;
175 if (output_names) {
176 output_names_vec.reserve(noutputs);
177 for (int i = 0; i < noutputs; ++i) {
178 output_names_vec.push_back(string(output_names[i]));
179 }
180 }
181
182 // Process control output names.
183 std::vector<string> control_output_names_vec;
184 if (control_output_names) {
185 control_output_names_vec.reserve(ncontrol_outputs);
186 for (int i = 0; i < ncontrol_outputs; ++i) {
187 control_output_names_vec.push_back(string(output_names[i]));
188 }
189 }
190
191 // Compute body nodes.
192 std::vector<const Node*> body_nodes;
193 status->status = tensorflow::ComputeBodyNodes(
194 fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
195 if (TF_GetCode(status) != TF_OK) return nullptr;
196
197 // Compute body nodes.
198 std::vector<const Node*> control_output_nodes;
199 for (int i = 0; i < ncontrol_outputs; ++i) {
200 control_output_nodes.push_back(&control_outputs[i]->node);
201 }
202
203 // Do the actual function creation.
204 TF_Function* tf_function = new TF_Function();
205 DCHECK(append_hash_to_fn_name <= 1);
206 status->status = tensorflow::GraphToFunctionDef(
207 fn_body->graph, fn_name, append_hash_to_fn_name != 0,
208 /*set_stateful_from_nodes=*/true,
209 /*copy_placeholder_attrs_from_nodes=*/true, body_nodes, input_tensors,
210 output_tensors, output_names_vec, control_output_nodes,
211 control_output_names_vec, description, &tf_function->fdef);
212 if (TF_GetCode(status) != TF_OK) {
213 TF_DeleteFunction(tf_function);
214 return nullptr;
215 }
216 return tf_function;
217 }
218
TF_GraphToFunction(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)219 TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
220 unsigned char append_hash_to_fn_name,
221 int num_opers, const TF_Operation* const* opers,
222 int ninputs, const TF_Output* inputs,
223 int noutputs, const TF_Output* outputs,
224 const char* const* output_names,
225 const TF_FunctionOptions* opts,
226 const char* description, TF_Status* status) {
227 return TF_GraphToFunctionWithControlOutputs(
228 fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs,
229 inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts,
230 description, status);
231 }
232
TF_FunctionName(TF_Function * func)233 const char* TF_FunctionName(TF_Function* func) {
234 return func->fdef.signature().name().c_str();
235 }
236
TF_GraphCopyFunction(TF_Graph * g,const TF_Function * func,const TF_Function * grad,TF_Status * status)237 void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
238 const TF_Function* grad, TF_Status* status) {
239 if (func == nullptr) {
240 status->status = InvalidArgument(
241 "'func' argument to TF_GraphCopyFunction cannot be null");
242 return;
243 }
244
245 // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
246 // to avoid the extra copy here.
247 tensorflow::FunctionDefLibrary fdef_lib;
248 *fdef_lib.add_function() = func->fdef;
249 if (grad) {
250 *fdef_lib.add_function() = grad->fdef;
251 tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
252 gdef->set_function_name(func->fdef.signature().name());
253 gdef->set_gradient_func(grad->fdef.signature().name());
254 }
255
256 tensorflow::mutex_lock l(g->mu);
257 status->status = g->graph.AddFunctionLibrary(fdef_lib);
258 }
259
TF_GraphNumFunctions(TF_Graph * g)260 int TF_GraphNumFunctions(TF_Graph* g) {
261 tensorflow::mutex_lock l(g->mu);
262 return g->graph.flib_def().num_functions();
263 }
264
TF_GraphGetFunctions(TF_Graph * g,TF_Function ** funcs,int max_func,TF_Status * status)265 int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
266 TF_Status* status) {
267 tensorflow::FunctionDefLibrary lib;
268 {
269 tensorflow::mutex_lock l(g->mu);
270 lib = g->graph.flib_def().ToProto();
271 }
272 const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
273 for (int i = 0; i < len; ++i) {
274 TF_Function* func = new TF_Function();
275 func->fdef = lib.function(i);
276 funcs[i] = func;
277 }
278 status->status = tensorflow::Status::OK();
279 return len;
280 }
281
TF_FunctionToFunctionDef(TF_Function * func,TF_Buffer * output_func_def,TF_Status * status)282 void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
283 TF_Status* status) {
284 status->status = MessageToBuffer(func->fdef, output_func_def);
285 }
286
TF_FunctionImportFunctionDef(const void * proto,size_t proto_len,TF_Status * status)287 TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
288 TF_Status* status) {
289 TF_Function* func = new TF_Function();
290 if (!func->fdef.ParseFromArray(proto, proto_len)) {
291 status->status = InvalidArgument(
292 "Invalid FunctionDef given to TF_FunctionImportFunctionDef");
293 TF_DeleteFunction(func);
294 return nullptr;
295 }
296 status->status = tensorflow::Status::OK();
297 return func;
298 }
299
TF_FunctionSetAttrValueProto(TF_Function * func,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)300 void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
301 const void* proto, size_t proto_len,
302 TF_Status* status) {
303 tensorflow::AttrValue attr_value;
304 if (!attr_value.ParseFromArray(proto, proto_len)) {
305 status->status = InvalidArgument(
306 "Unparseable AttrValue proto passed to "
307 "TF_FunctionSetAttrValueProto");
308 return;
309 }
310 (*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
311 status->status = tensorflow::Status::OK();
312 }
313
TF_FunctionGetAttrValueProto(TF_Function * func,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)314 void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
315 TF_Buffer* output_attr_value,
316 TF_Status* status) {
317 const auto& it = func->fdef.attr().find(attr_name);
318 if (it == func->fdef.attr().end()) {
319 status->status =
320 InvalidArgument("Function '", func->fdef.signature().name(),
321 "' has no attr named '", attr_name, "'.");
322 return;
323 }
324 status->status = MessageToBuffer(it->second, output_attr_value);
325 }
326
TF_DeleteFunction(TF_Function * func)327 void TF_DeleteFunction(TF_Function* func) { delete func; }
328