1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/client/tf_session_helper.h"
17
18 #include <cstring>
19
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/log_memory.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/graph/tensor_id.h"
29 #include "tensorflow/core/lib/core/coding.h"
30 #include "tensorflow/core/lib/strings/stringprintf.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/equal_graph_def.h"
33 #include "tensorflow/python/client/session_ref.h"
34 #include "tensorflow/python/lib/core/ndarray_tensor.h"
35 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
36 #include "tensorflow/python/lib/core/safe_ptr.h"
37
38 namespace tensorflow {
39
namespace {

// Error reported by TF_Run_wrapper_helper when `feed_dict` is not a dict or
// one of its keys cannot be read as a bytes string. A constexpr array (rather
// than a non-const `const char*`) cannot be re-pointed at runtime.
constexpr char kFeedDictErrorMsg[] =
    "feed_dict must be a dictionary mapping strings to NumPy arrays.";

}  // end namespace
45
TF_NewSessionRef(TF_Graph * graph,const TF_SessionOptions * opts,TF_Status * status)46 TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
47 TF_Status* status) {
48 TF_Session* tf_session = TF_NewSession(graph, opts, status);
49 if (tf_session == nullptr) {
50 return nullptr;
51 }
52
53 Session* session = reinterpret_cast<Session*>(tf_session->session);
54 SessionRef* session_ref = new SessionRef(session);
55 tf_session->session = session_ref;
56 return tf_session;
57 }
58
TF_Run_wrapper_helper(TF_DeprecatedSession * session,const char * handle,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)59 void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
60 const TF_Buffer* run_options, PyObject* feed_dict,
61 const NameVector& output_names,
62 const NameVector& target_nodes,
63 TF_Status* out_status, PyObjectVector* out_values,
64 TF_Buffer* run_outputs) {
65 // 1. Convert the feed inputs to the appropriate form for TF_Run.
66 if (!PyDict_Check(feed_dict)) {
67 Set_TF_Status_from_Status(out_status,
68 errors::InvalidArgument(kFeedDictErrorMsg));
69 return;
70 }
71
72 NameVector input_names;
73 std::vector<Safe_TF_TensorPtr> inputs_safe; // Used to delete tensors.
74 TF_TensorVector inputs_unsafe; // Used to contain the arg to TF_Run.
75
76 PyObject* key;
77 PyObject* value;
78 Py_ssize_t pos = 0;
79 int index = 0;
80 Status s;
81
82 while (PyDict_Next(feed_dict, &pos, &key, &value)) {
83 char* key_string = PyBytes_AsString(key);
84 if (!key_string) {
85 Set_TF_Status_from_Status(out_status,
86 errors::InvalidArgument(kFeedDictErrorMsg));
87 return;
88 }
89 input_names.push_back(key_string);
90
91 inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
92 s = PyArrayToTF_Tensor(value, &inputs_safe.back());
93 if (!s.ok()) {
94 Set_TF_Status_from_Status(out_status, s);
95 return;
96 }
97 inputs_unsafe.push_back(inputs_safe.back().get());
98 ++index;
99 }
100
101 // 2. Allocate a container for the output data.
102 TF_TensorVector outputs(output_names.size());
103
104 // In case any tensors were leftover from previous runs we might as well clear
105 // them here.
106 ClearDecrefCache();
107
108 // 3. Actually call TF_Run().
109 Py_BEGIN_ALLOW_THREADS;
110 if (handle == nullptr) {
111 TF_Run(session, run_options, input_names.data(), inputs_unsafe.data(),
112 input_names.size(), const_cast<const char**>(output_names.data()),
113 outputs.data(), output_names.size(),
114 const_cast<const char**>(target_nodes.data()), target_nodes.size(),
115 run_outputs, out_status);
116 } else {
117 TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(),
118 input_names.size(), const_cast<const char**>(output_names.data()),
119 outputs.data(), output_names.size(),
120 const_cast<const char**>(target_nodes.data()), target_nodes.size(),
121 out_status);
122 }
123
124 Py_END_ALLOW_THREADS;
125
126 // Decref any numpy arrays we are not using anymore.
127 ClearDecrefCache();
128
129 if (TF_GetCode(out_status) != TF_OK) {
130 return;
131 }
132
133 // 4. We now own the fetched tensors, so set up a safe container to
134 // delete them when we exit this scope.
135 std::vector<Safe_TF_TensorPtr> tf_outputs_safe;
136 for (const auto& output : outputs) {
137 tf_outputs_safe.emplace_back(make_safe(output));
138 }
139
140 // 5. Convert the fetched tensors into numpy ndarrays. Store them in a safe
141 // container so that we do not leak
142 std::vector<Safe_PyObjectPtr> py_outputs_safe;
143 for (size_t i = 0; i < output_names.size(); ++i) {
144 PyObject* py_array;
145 s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
146 if (!s.ok()) {
147 Set_TF_Status_from_Status(out_status, s);
148 return;
149 }
150 py_outputs_safe.emplace_back(make_safe(py_array));
151 }
152
153 // 6. If we reach this point, we have successfully built a list of objects
154 // so we can release them from the safe container.
155 for (auto& output : py_outputs_safe) {
156 out_values->push_back(output.release());
157 }
158 }
159
160 // Wrapper for TF_Run that converts the arguments to appropriate types.
161 // If *out_status is OK, the caller becomes the owner of the PyObjects
162 // in *out_values.
TF_Run_wrapper(TF_DeprecatedSession * session,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)163 void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
164 PyObject* feed_dict, const NameVector& output_names,
165 const NameVector& target_nodes, TF_Status* out_status,
166 PyObjectVector* out_values, TF_Buffer* run_outputs) {
167 TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names,
168 target_nodes, out_status, out_values, run_outputs);
169 ClearDecrefCache();
170 }
171
172 namespace {
MakeCallableHelper(tensorflow::Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * out_status)173 void MakeCallableHelper(tensorflow::Session* session,
174 const TF_Buffer* callable_options, int64_t* out_handle,
175 TF_Status* out_status) {
176 tensorflow::CallableOptions callable_options_proto;
177 if (callable_options != nullptr &&
178 !callable_options_proto.ParseFromArray(callable_options->data,
179 callable_options->length)) {
180 Set_TF_Status_from_Status(
181 out_status,
182 errors::InvalidArgument("Unparseable CallableOptions proto"));
183 return;
184 }
185 tensorflow::Session::CallableHandle handle;
186 Status s = session->MakeCallable(callable_options_proto, &handle);
187 if (!s.ok()) {
188 Set_TF_Status_from_Status(out_status, s);
189 return;
190 }
191 *out_handle = handle;
192 }
193 } // namespace
194
TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * out_status)195 void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
196 const TF_Buffer* callable_options,
197 int64_t* out_handle,
198 TF_Status* out_status) {
199 MakeCallableHelper(session->session, callable_options, out_handle,
200 out_status);
201 }
TF_SessionMakeCallable(TF_Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * out_status)202 void TF_SessionMakeCallable(TF_Session* session,
203 const TF_Buffer* callable_options,
204 int64_t* out_handle, TF_Status* out_status) {
205 MakeCallableHelper(session->session, callable_options, out_handle,
206 out_status);
207 }
208
209 namespace {
RunCallableHelper(tensorflow::Session * session,int64_t handle,PyObject * feed_values,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_metadata)210 void RunCallableHelper(tensorflow::Session* session, int64_t handle,
211 PyObject* feed_values, TF_Status* out_status,
212 PyObjectVector* out_values, TF_Buffer* run_metadata) {
213 // Convert feed values to a vector of tensorflow::Tensor objects.
214 std::vector<Tensor> input_tensors;
215 Status s;
216 {
217 feed_values =
218 PySequence_Fast(feed_values, "feed_values must be a sequence");
219 if (feed_values == nullptr) return;
220 Safe_PyObjectPtr feed_values_holder(make_safe(feed_values));
221 Py_ssize_t len = PySequence_Fast_GET_SIZE(feed_values);
222 input_tensors.reserve(len);
223 for (Py_ssize_t i = 0; i < len; ++i) {
224 PyObject* elem = PySequence_Fast_GET_ITEM(feed_values, i);
225 if (!elem) {
226 Set_TF_Status_from_Status(
227 out_status, errors::Internal("Could not get feed value ", i));
228 return;
229 }
230 Tensor t;
231 s = NdarrayToTensor(elem, &t);
232 if (!s.ok()) {
233 Set_TF_Status_from_Status(out_status, s);
234 return;
235 }
236 input_tensors.push_back(std::move(t));
237 }
238 }
239
240 // Allocate a RunMetadata protobuf object to receive the metadata,
241 // if the caller is expecting any.
242 std::unique_ptr<RunMetadata> run_metadata_proto;
243 if (run_metadata != nullptr) {
244 run_metadata_proto.reset(new RunMetadata);
245 }
246
247 // Run the callable.
248 std::vector<Tensor> output_tensors;
249 Py_BEGIN_ALLOW_THREADS;
250 s = session->RunCallable(handle, input_tensors, &output_tensors,
251 run_metadata_proto.get());
252 Py_END_ALLOW_THREADS;
253
254 if (!s.ok()) {
255 Set_TF_Status_from_Status(out_status, s);
256 return;
257 }
258
259 // If requested, serialize the RunMetadata to pass it back to the caller.
260 if (run_metadata != nullptr) {
261 s = MessageToBuffer(*run_metadata_proto, run_metadata);
262 if (!s.ok()) {
263 Set_TF_Status_from_Status(out_status, s);
264 return;
265 }
266 }
267
268 // Convert results to NumPy arrays. Since this can fail, stage the
269 // results via a safe container that takes care of decreasing the
270 // reference count on failure.
271 std::vector<Safe_PyObjectPtr> py_outputs_safe;
272 py_outputs_safe.reserve(output_tensors.size());
273 for (const Tensor& output : output_tensors) {
274 PyObject* py_array;
275 s = TensorToNdarray(output, &py_array);
276 if (!s.ok()) {
277 Set_TF_Status_from_Status(out_status, s);
278 return;
279 }
280 py_outputs_safe.push_back(make_safe(py_array));
281 }
282
283 // If we reach this point, we have successfully built a list of objects
284 // so we can release them from the safe container.
285 out_values->reserve(py_outputs_safe.size());
286 for (auto& output : py_outputs_safe) {
287 out_values->push_back(output.release());
288 }
289 }
290 } // namespace
291
TF_DeprecatedSessionRunCallable(TF_DeprecatedSession * session,int64_t handle,PyObject * feed_values,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_metadata)292 void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
293 int64_t handle, PyObject* feed_values,
294 TF_Status* out_status,
295 PyObjectVector* out_values,
296 TF_Buffer* run_metadata) {
297 RunCallableHelper(session->session, handle, feed_values, out_status,
298 out_values, run_metadata);
299 ClearDecrefCache();
300 }
TF_SessionRunCallable(TF_Session * session,int64_t handle,PyObject * feed_values,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_metadata)301 void TF_SessionRunCallable(TF_Session* session, int64_t handle,
302 PyObject* feed_values, TF_Status* out_status,
303 PyObjectVector* out_values,
304 TF_Buffer* run_metadata) {
305 RunCallableHelper(session->session, handle, feed_values, out_status,
306 out_values, run_metadata);
307 ClearDecrefCache();
308 }
309
TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession * session,int64_t handle,TF_Status * out_status)310 void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
311 int64_t handle,
312 TF_Status* out_status) {
313 Set_TF_Status_from_Status(out_status,
314 session->session->ReleaseCallable(handle));
315 }
TF_SessionReleaseCallable(TF_Session * session,int64_t handle,TF_Status * out_status)316 void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
317 TF_Status* out_status) {
318 Set_TF_Status_from_Status(out_status,
319 session->session->ReleaseCallable(handle));
320 }
321
322 // Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
323 // If *out_status is OK, the caller becomes the owner of *out_handle.
TF_PRunSetup_wrapper(TF_DeprecatedSession * session,const NameVector & input_names,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,const char ** out_handle)324 void TF_PRunSetup_wrapper(TF_DeprecatedSession* session,
325 const NameVector& input_names,
326 const NameVector& output_names,
327 const NameVector& target_nodes, TF_Status* out_status,
328 const char** out_handle) {
329 Py_BEGIN_ALLOW_THREADS;
330 TF_PRunSetup(
331 session, const_cast<const char**>(input_names.data()), input_names.size(),
332 const_cast<const char**>(output_names.data()), output_names.size(),
333 const_cast<const char**>(target_nodes.data()), target_nodes.size(),
334 out_handle, out_status);
335 Py_END_ALLOW_THREADS;
336 }
337
338 // Wrapper for TF_PRun that converts the arguments to appropriate types.
339 // If *out_status is OK, the caller becomes the owner of the PyObjects
340 // in *out_values.
TF_PRun_wrapper(TF_DeprecatedSession * session,const char * handle,PyObject * feed_dict,const NameVector & output_names,TF_Status * out_status,PyObjectVector * out_values)341 void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
342 PyObject* feed_dict, const NameVector& output_names,
343 TF_Status* out_status, PyObjectVector* out_values) {
344 TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names,
345 NameVector(), out_status, out_values, nullptr);
346 ClearDecrefCache();
347 }
348
349 // Wrapper for TF_Reset that converts the string vectors to character arrays.
TF_Reset_wrapper(const TF_SessionOptions * opt,const NameVector & containers,TF_Status * out_status)350 void TF_Reset_wrapper(const TF_SessionOptions* opt,
351 const NameVector& containers, TF_Status* out_status) {
352 TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
353 out_status);
354 }
355
TF_SessionRun_wrapper_helper(TF_Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)356 void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
357 const TF_Buffer* run_options,
358 const std::vector<TF_Output>& inputs,
359 const std::vector<PyObject*>& input_ndarrays,
360 const std::vector<TF_Output>& outputs,
361 const std::vector<TF_Operation*>& targets,
362 TF_Buffer* run_metadata,
363 TF_Status* out_status,
364 std::vector<PyObject*>* py_outputs) {
365 DCHECK_EQ(inputs.size(), input_ndarrays.size());
366 DCHECK(py_outputs != nullptr);
367 DCHECK(py_outputs->empty());
368 Status s;
369
370 // Convert input ndarray PyObjects to TF_Tensors. We maintain a continuous
371 // array of TF_Tensor*s as well as scoped containers to make sure they're
372 // cleaned up properly.
373 //
374 // Memory management:
375 // PyArrayToTF_Tensor() creates a new ndarray PyObject from the input
376 // ndarray. We manage the new ndarray's lifetime in order to keep the
377 // underlying data buffer alive (the new ndarray also guarantees a contiguous
378 // data buffer). The new ndarray's data buffer is used to create the
379 // corresponding TF_Tensor. The TF_Tensor's deallocator will queue the new
380 // ndarray to be decref'd by the next ClearDecrefCache() call (we can't call
381 // Py_DECREF in the deallocator directly because the GIL must be held).
382 //
383 // Note that TF_Tensor may directly delegate its data and deallocator to a
384 // TensorBuffer, which may outlive the TF_Tensor (e.g. if the tensor gets
385 // queued or assigned to a variable).
386 TF_TensorVector input_vals;
387 std::vector<Safe_TF_TensorPtr> input_vals_safe;
388 for (PyObject* ndarray : input_ndarrays) {
389 input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
390 s = PyArrayToTF_Tensor(ndarray, &input_vals_safe.back());
391 if (!s.ok()) {
392 Set_TF_Status_from_Status(out_status, s);
393 return;
394 }
395 input_vals.push_back(input_vals_safe.back().get());
396 }
397
398 // Allocate space for output TF_Tensor*s
399 TF_TensorVector output_vals(outputs.size());
400
401 // Clear up any unused memory leftover from previous runs
402 ClearDecrefCache();
403
404 // Call TF_SessionRun() (and release GIL during execution)
405 Py_BEGIN_ALLOW_THREADS;
406 if (handle == nullptr) {
407 TF_SessionRun(session, run_options, inputs.data(), input_vals.data(),
408 inputs.size(), outputs.data(), output_vals.data(),
409 outputs.size(), targets.data(), targets.size(), run_metadata,
410 out_status);
411 } else {
412 TF_SessionPRun(session, handle, inputs.data(), input_vals.data(),
413 inputs.size(), outputs.data(), output_vals.data(),
414 outputs.size(), targets.data(), targets.size(), out_status);
415 }
416 Py_END_ALLOW_THREADS;
417
418 // Create scoped containers for output tensors
419 std::vector<Safe_TF_TensorPtr> output_vals_safe;
420 for (TF_Tensor* output : output_vals) {
421 output_vals_safe.emplace_back(make_safe(output));
422 }
423
424 // Convert outputs to ndarrays (in scoped containers)
425 std::vector<Safe_PyObjectPtr> py_outputs_safe;
426 for (size_t i = 0; i < outputs.size(); ++i) {
427 PyObject* py_array;
428 s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array);
429 if (!s.ok()) {
430 Set_TF_Status_from_Status(out_status, s);
431 return;
432 }
433 py_outputs_safe.emplace_back(make_safe(py_array));
434 }
435
436 // If we reach this point, we have successfully built a list of objects so we
437 // can release them from the safe container into the return vector.
438 for (size_t i = 0; i < outputs.size(); ++i) {
439 py_outputs->push_back(py_outputs_safe[i].release());
440 }
441 }
442
TF_SessionRun_wrapper(TF_Session * session,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)443 void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
444 const std::vector<TF_Output>& inputs,
445 const std::vector<PyObject*>& input_ndarrays,
446 const std::vector<TF_Output>& outputs,
447 const std::vector<TF_Operation*>& targets,
448 TF_Buffer* run_metadata, TF_Status* out_status,
449 std::vector<PyObject*>* py_outputs) {
450 TF_SessionRun_wrapper_helper(session, nullptr, run_options, inputs,
451 input_ndarrays, outputs, targets, run_metadata,
452 out_status, py_outputs);
453 // Release any unused ndarray references (see memory management comment in
454 // TF_SessionRun_wrapper_helper)
455 ClearDecrefCache();
456 }
457
// Compares two serialized GraphDefs with EqualGraphDef(). Returns "" when they
// are equal and the difference description from EqualGraphDef() otherwise. If
// either input fails to parse, returns a message naming that input.
string EqualGraphDefWrapper(const string& actual, const string& expected) {
  GraphDef actual_def;
  GraphDef expected_def;
  if (!actual_def.ParseFromString(actual)) {
    return "actual is not a valid serialized GraphDef";
  }
  if (!expected_def.ParseFromString(expected)) {
    return "expected is not a valid serialized GraphDef";
  }
  string diff;
  if (EqualGraphDef(actual_def, expected_def, &diff)) {
    return "";
  }
  return diff;
}
470
// Compares two serialized AttrValue protos with AreAttrValuesEqual(). Returns
// "" when they match, and a message containing both summaries when they
// differ. If either input fails to parse, returns a message naming that input.
string EqualAttrValueWrapper(const string& actual, const string& expected) {
  AttrValue actual_attr_value;
  AttrValue expected_attr_value;
  if (!actual_attr_value.ParseFromString(actual)) {
    return "actual is not a valid serialized AttrValue";
  }
  if (!expected_attr_value.ParseFromString(expected)) {
    return "expected is not a valid serialized AttrValue";
  }
  if (AreAttrValuesEqual(actual_attr_value, expected_attr_value)) {
    return "";
  }
  return strings::Printf(
      "Actual AttrValue %s does not match Expected AttrValue %s.",
      SummarizeAttrValue(actual_attr_value).c_str(),
      SummarizeAttrValue(expected_attr_value).c_str());
}
491
492 // Return value set to 6 inlined elements so it fits in a 64-byte cache line.
TF_GraphGetTensorShapeHelper(TF_Graph * graph,TF_Output output,TF_Status * out_status,bool * unknown_shape)493 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
494 TF_Graph* graph, TF_Output output, TF_Status* out_status,
495 bool* unknown_shape) {
496 // Allocate a single variable for holding the result for RVO.
497 tensorflow::gtl::InlinedVector<int64_t, 6> result;
498 *unknown_shape = false;
499 int num_dims = TF_GraphGetTensorNumDims(graph, output, out_status);
500 if (TF_GetCode(out_status) != TF_OK) {
501 return result;
502 }
503 // If shape is unknown, set boolean and return.
504 if (num_dims == -1) {
505 *unknown_shape = true;
506 return result;
507 }
508
509 // If shape is a scalar, avoid another C call and just return {}.
510 if (num_dims == 0) {
511 return result;
512 }
513
514 result.resize(num_dims);
515 TF_GraphGetTensorShape(graph, output, result.data(), num_dims, out_status);
516 return result;
517 }
518
TF_SessionPRunSetup_wrapper(TF_Session * session,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,const char ** out_handle,TF_Status * out_status)519 void TF_SessionPRunSetup_wrapper(TF_Session* session,
520 const std::vector<TF_Output>& inputs,
521 const std::vector<TF_Output>& outputs,
522 const std::vector<TF_Operation*>& targets,
523 const char** out_handle,
524 TF_Status* out_status) {
525 // Call TF_SessionPRunSetup() (and release GIL during execution)
526 Py_BEGIN_ALLOW_THREADS;
527 TF_SessionPRunSetup(session, inputs.data(), inputs.size(), outputs.data(),
528 outputs.size(), targets.data(), targets.size(),
529 out_handle, out_status);
530 Py_END_ALLOW_THREADS;
531 }
532
TF_SessionPRun_wrapper(TF_Session * session,const char * handle,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,TF_Status * out_status,std::vector<PyObject * > * py_outputs)533 void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
534 const std::vector<TF_Output>& inputs,
535 const std::vector<PyObject*>& input_ndarrays,
536 const std::vector<TF_Output>& outputs,
537 TF_Status* out_status,
538 std::vector<PyObject*>* py_outputs) {
539 const std::vector<TF_Operation*> targets;
540 TF_SessionRun_wrapper_helper(session, handle,
541 nullptr, // run_options
542 inputs, input_ndarrays, outputs, targets,
543 nullptr, // run_metadata
544 out_status, py_outputs);
545 // Release any unused ndarray references (see memory management comment in
546 // TF_SessionRun_wrapper_helper)
547 ClearDecrefCache();
548 }
549
GetOperationInputs(TF_Operation * oper)550 std::vector<TF_Output> GetOperationInputs(TF_Operation* oper) {
551 int num_inputs = TF_OperationNumInputs(oper);
552 std::vector<TF_Output> inputs(num_inputs);
553 for (int i = 0; i < num_inputs; ++i) {
554 inputs[i] = TF_OperationInput({oper, i});
555 }
556 return inputs;
557 }
558
TF_OperationGetControlInputs_wrapper(TF_Operation * oper)559 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
560 TF_Operation* oper) {
561 std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper));
562 TF_OperationGetControlInputs(oper, control_inputs.data(),
563 control_inputs.size());
564 return control_inputs;
565 }
566
TF_OperationGetControlOutputs_wrapper(TF_Operation * oper)567 std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
568 TF_Operation* oper) {
569 std::vector<TF_Operation*> control_outputs(
570 TF_OperationNumControlOutputs(oper));
571 TF_OperationGetControlOutputs(oper, control_outputs.data(),
572 control_outputs.size());
573 return control_outputs;
574 }
575
TF_OperationOutputConsumers_wrapper(TF_Output oper_out)576 std::vector<const char*> TF_OperationOutputConsumers_wrapper(
577 TF_Output oper_out) {
578 int num_consumers = TF_OperationOutputNumConsumers(oper_out);
579 std::vector<TF_Input> consumers(num_consumers);
580 TF_OperationOutputConsumers(oper_out, consumers.data(), num_consumers);
581
582 std::vector<const char*> consumer_names(num_consumers);
583 for (int i = 0; i < num_consumers; ++i) {
584 consumer_names[i] = TF_OperationName(consumers[i].oper);
585 }
586 return consumer_names;
587 }
588
TF_GraphToFunction_wrapper(const TF_Graph * fn_body,const char * fn_name,bool append_hash_to_fn_name,const std::vector<TF_Operation * > * opers,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const NameVector & output_names,const std::vector<TF_Operation * > * control_outputs,const NameVector & control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * out_status)589 TF_Function* TF_GraphToFunction_wrapper(
590 const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
591 const std::vector<TF_Operation*>* opers,
592 const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
593 const NameVector& output_names,
594 const std::vector<TF_Operation*>* control_outputs,
595 const NameVector& control_output_names, const TF_FunctionOptions* opts,
596 const char* description, TF_Status* out_status) {
597 if (!output_names.empty() && output_names.size() != outputs.size()) {
598 Set_TF_Status_from_Status(
599 out_status,
600 errors::InvalidArgument(
601 "output names must be either empty or equal in size to outputs. ",
602 "output names size = ", output_names.size(),
603 " outputs size = ", outputs.size()));
604 return nullptr;
605 }
606
607 int nopers = -1;
608 const TF_Operation* const* opers_array = nullptr;
609 if (opers != nullptr) {
610 nopers = opers->size();
611 opers_array = opers->data();
612 }
613
614 const char** output_names_ptr =
615 output_names.empty() ? nullptr
616 : const_cast<const char**>(output_names.data());
617
618 const char** control_output_names_ptr =
619 control_output_names.empty()
620 ? nullptr
621 : const_cast<const char**>(control_output_names.data());
622
623 return TF_GraphToFunctionWithControlOutputs(
624 fn_body, fn_name, append_hash_to_fn_name, nopers, opers_array,
625 inputs.size(), inputs.data(), outputs.size(), outputs.data(),
626 output_names_ptr,
627 control_outputs == nullptr ? 0 : control_outputs->size(),
628 control_outputs == nullptr ? nullptr : control_outputs->data(),
629 control_output_names_ptr, opts, description, out_status);
630 }
631
TF_GraphSetOutputHandleShapesAndTypes_wrapper(TF_Graph * graph,TF_Output output,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int> & ranks,const std::vector<TF_DataType> & types,TF_Status * status)632 void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
633 TF_Graph* graph, TF_Output output,
634 const std::vector<std::vector<int64_t>>& shapes,
635 const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
636 TF_Status* status) {
637 std::vector<const int64_t*> shapes_pointers(shapes.size());
638 for (int i = 0; i < shapes.size(); ++i) {
639 shapes_pointers[i] = ranks[i] <= 0 ? nullptr : &shapes[i][0];
640 }
641 TF_GraphSetOutputHandleShapesAndTypes(graph, output, shapes.size(),
642 shapes_pointers.data(), ranks.data(),
643 types.data(), status);
644 }
645
TF_GraphSetTensorShape_wrapper(TF_Graph * graph,TF_Output output,const std::vector<int64_t> & dims,bool unknown_shape,TF_Status * status)646 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
647 const std::vector<int64_t>& dims,
648 bool unknown_shape, TF_Status* status) {
649 if (unknown_shape) {
650 TF_GraphSetTensorShape(graph, output, nullptr, -1, status);
651 return;
652 }
653 TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
654 }
655
TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(TF_ImportGraphDefResults * results)656 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
657 TF_ImportGraphDefResults* results) {
658 int num_missing_unused_input_mappings;
659 const char** src_names;
660 int* src_indexes;
661 TF_ImportGraphDefResultsMissingUnusedInputMappings(
662 results, &num_missing_unused_input_mappings, &src_names, &src_indexes);
663 std::vector<string> input_strs(num_missing_unused_input_mappings);
664 for (int i = 0; i < num_missing_unused_input_mappings; ++i) {
665 input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString();
666 }
667 return input_strs;
668 }
669
TF_TryEvaluateConstant_wrapper(TF_Graph * graph,TF_Output output,TF_Status * status)670 PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
671 TF_Status* status) {
672 TF_Tensor* result_tensor;
673 bool evaluated =
674 TF_TryEvaluateConstant(graph, output, &result_tensor, status);
675 if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE;
676
677 Safe_TF_TensorPtr safe_result_tensor(result_tensor);
678 PyObject* out;
679 Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out);
680 Set_TF_Status_from_Status(status, s);
681 if (!s.ok()) Py_RETURN_NONE;
682 return out;
683 }
684
685 } // namespace tensorflow
686