/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/python/client/tf_session_helper.h"
17
18 #include <cstring>
19
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/log_memory.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/graph/tensor_id.h"
29 #include "tensorflow/core/lib/core/coding.h"
30 #include "tensorflow/core/lib/strings/stringprintf.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/equal_graph_def.h"
33 #include "tensorflow/python/client/session_ref.h"
34 #include "tensorflow/python/lib/core/ndarray_tensor.h"
35 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
36 #include "tensorflow/python/lib/core/safe_ptr.h"
37
38 namespace tensorflow {
39
namespace {

// Error message reported whenever the Python `feed_dict` argument is not a
// dict with byte-string keys and ndarray values.
constexpr char kFeedDictErrorMsg[] =
    "feed_dict must be a dictionary mapping strings to NumPy arrays.";

}  // end namespace
45
TF_NewSessionRef(TF_Graph * graph,const TF_SessionOptions * opts,TF_Status * status)46 TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
47 TF_Status* status) {
48 TF_Session* tf_session = TF_NewSession(graph, opts, status);
49 if (tf_session == nullptr) {
50 return nullptr;
51 }
52
53 Session* session = reinterpret_cast<Session*>(tf_session->session);
54 SessionRef* session_ref = new SessionRef(session);
55 tf_session->session = session_ref;
56 return tf_session;
57 }
58
TF_Run_wrapper_helper(TF_DeprecatedSession * session,const char * handle,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)59 void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
60 const TF_Buffer* run_options, PyObject* feed_dict,
61 const NameVector& output_names,
62 const NameVector& target_nodes,
63 TF_Status* out_status, PyObjectVector* out_values,
64 TF_Buffer* run_outputs) {
65 // 1. Convert the feed inputs to the appropriate form for TF_Run.
66 if (!PyDict_Check(feed_dict)) {
67 Set_TF_Status_from_Status(out_status,
68 errors::InvalidArgument(kFeedDictErrorMsg));
69 return;
70 }
71
72 NameVector input_names;
73 std::vector<Safe_TF_TensorPtr> inputs_safe; // Used to delete tensors.
74 TF_TensorVector inputs_unsafe; // Used to contain the arg to TF_Run.
75
76 PyObject* key;
77 PyObject* value;
78 Py_ssize_t pos = 0;
79 int index = 0;
80 Status s;
81
82 while (PyDict_Next(feed_dict, &pos, &key, &value)) {
83 char* key_string = PyBytes_AsString(key);
84 if (!key_string) {
85 Set_TF_Status_from_Status(out_status,
86 errors::InvalidArgument(kFeedDictErrorMsg));
87 return;
88 }
89 input_names.push_back(key_string);
90
91 inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
92 s = NdarrayToTensor(nullptr /*ctx*/, value, &inputs_safe.back());
93 if (!s.ok()) {
94 Set_TF_Status_from_Status(out_status, s);
95 return;
96 }
97 inputs_unsafe.push_back(inputs_safe.back().get());
98 ++index;
99 }
100
101 // 2. Allocate a container for the output data.
102 TF_TensorVector outputs(output_names.size());
103
104 // In case any tensors were leftover from previous runs we might as well clear
105 // them here.
106 ClearDecrefCache();
107
108 // 3. Actually call TF_Run().
109 Py_BEGIN_ALLOW_THREADS;
110 if (handle == nullptr) {
111 TF_Run(session, run_options, input_names.data(), inputs_unsafe.data(),
112 input_names.size(), const_cast<const char**>(output_names.data()),
113 outputs.data(), output_names.size(),
114 const_cast<const char**>(target_nodes.data()), target_nodes.size(),
115 run_outputs, out_status);
116 } else {
117 TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(),
118 input_names.size(), const_cast<const char**>(output_names.data()),
119 outputs.data(), output_names.size(),
120 const_cast<const char**>(target_nodes.data()), target_nodes.size(),
121 out_status);
122 }
123
124 Py_END_ALLOW_THREADS;
125
126 // Decref any numpy arrays we are not using anymore.
127 ClearDecrefCache();
128
129 if (TF_GetCode(out_status) != TF_OK) {
130 return;
131 }
132
133 // 4. We now own the fetched tensors, so set up a safe container to
134 // delete them when we exit this scope.
135 std::vector<Safe_TF_TensorPtr> tf_outputs_safe;
136 for (const auto& output : outputs) {
137 tf_outputs_safe.emplace_back(make_safe(output));
138 }
139
140 // 5. Convert the fetched tensors into numpy ndarrays. Store them in a safe
141 // container so that we do not leak
142 std::vector<Safe_PyObjectPtr> py_outputs_safe;
143 for (size_t i = 0; i < output_names.size(); ++i) {
144 PyObject* py_array;
145 s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
146 if (!s.ok()) {
147 Set_TF_Status_from_Status(out_status, s);
148 return;
149 }
150 py_outputs_safe.emplace_back(
151 make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
152 }
153
154 // 6. If we reach this point, we have successfully built a list of objects
155 // so we can release them from the safe container.
156 for (auto& output : py_outputs_safe) {
157 out_values->push_back(output.release());
158 }
159 }
160
161 // Wrapper for TF_Run that converts the arguments to appropriate types.
162 // If *out_status is OK, the caller becomes the owner of the PyObjects
163 // in *out_values.
TF_Run_wrapper(TF_DeprecatedSession * session,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)164 void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
165 PyObject* feed_dict, const NameVector& output_names,
166 const NameVector& target_nodes, TF_Status* out_status,
167 PyObjectVector* out_values, TF_Buffer* run_outputs) {
168 TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names,
169 target_nodes, out_status, out_values, run_outputs);
170 ClearDecrefCache();
171 }
172
173 namespace {
MakeCallableHelper(tensorflow::Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * out_status)174 void MakeCallableHelper(tensorflow::Session* session,
175 const TF_Buffer* callable_options, int64_t* out_handle,
176 TF_Status* out_status) {
177 tensorflow::CallableOptions callable_options_proto;
178 if (callable_options != nullptr &&
179 !callable_options_proto.ParseFromArray(callable_options->data,
180 callable_options->length)) {
181 Set_TF_Status_from_Status(
182 out_status,
183 errors::InvalidArgument("Unparseable CallableOptions proto"));
184 return;
185 }
186 tensorflow::Session::CallableHandle handle;
187 Status s = session->MakeCallable(callable_options_proto, &handle);
188 if (!s.ok()) {
189 Set_TF_Status_from_Status(out_status, s);
190 return;
191 }
192 *out_handle = handle;
193 }
194 } // namespace
195
TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * status)196 void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
197 const TF_Buffer* callable_options,
198 int64_t* out_handle, TF_Status* status) {
199 MakeCallableHelper(session->session, callable_options, out_handle, status);
200 }
TF_SessionMakeCallable(TF_Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * status)201 void TF_SessionMakeCallable(TF_Session* session,
202 const TF_Buffer* callable_options,
203 int64_t* out_handle, TF_Status* status) {
204 MakeCallableHelper(session->session, callable_options, out_handle, status);
205 }
206
namespace {
// Runs the callable identified by `handle` on `session`.
//
// `feed_values` must be a Python sequence of objects convertible by
// NdarrayToTensor (one per declared feed). On success, one NumPy object per
// fetched tensor is appended to *out_values and the caller owns the new
// references. If `run_metadata` is non-null, the serialized RunMetadata
// proto from the run is written into it.
//
// NOTE(review): if PySequence_Fast fails, this returns with a Python
// exception set but without writing *out_status — presumably callers also
// check the Python error state; confirm.
void RunCallableHelper(tensorflow::Session* session, int64_t handle,
                       PyObject* feed_values, TF_Status* out_status,
                       PyObjectVector* out_values, TF_Buffer* run_metadata) {
  // Convert feed values to a vector of tensorflow::Tensor objects.
  std::vector<Tensor> input_tensors;
  Status s;
  {
    feed_values =
        PySequence_Fast(feed_values, "feed_values must be a sequence");
    if (feed_values == nullptr) return;
    // PySequence_Fast returned a new reference; release it when this scope
    // ends.
    Safe_PyObjectPtr feed_values_holder(make_safe(feed_values));
    Py_ssize_t len = PySequence_Fast_GET_SIZE(feed_values);
    input_tensors.reserve(len);
    for (Py_ssize_t i = 0; i < len; ++i) {
      // Borrowed reference; no decref needed.
      PyObject* elem = PySequence_Fast_GET_ITEM(feed_values, i);
      if (!elem) {
        Set_TF_Status_from_Status(
            out_status, errors::Internal("Could not get feed value ", i));
        return;
      }
      Tensor t;
      s = NdarrayToTensor(elem, &t);
      if (!s.ok()) {
        Set_TF_Status_from_Status(out_status, s);
        return;
      }
      input_tensors.push_back(std::move(t));
    }
  }

  RunMetadata run_metadata_proto;

  // Run the callable (with the GIL released for the duration of the call).
  std::vector<Tensor> output_tensors;
  Py_BEGIN_ALLOW_THREADS;
  s = session->RunCallable(handle, input_tensors, &output_tensors,
                           &run_metadata_proto);
  Py_END_ALLOW_THREADS;

  if (!s.ok()) {
    Set_TF_Status_from_Status(out_status, s);
    return;
  }

  // If requested, serialize the RunMetadata to pass it back to the caller.
  if (run_metadata != nullptr) {
    s = MessageToBuffer(run_metadata_proto, run_metadata);
    if (!s.ok()) {
      Set_TF_Status_from_Status(out_status, s);
      return;
    }
  }

  // Convert results to NumPy arrays. Since this can fail, stage the
  // results via a safe container that takes care of decreasing the
  // reference count on failure.
  std::vector<Safe_PyObjectPtr> py_outputs_safe;
  py_outputs_safe.reserve(output_tensors.size());
  for (const Tensor& output : output_tensors) {
    PyObject* py_array;
    s = TensorToNdarray(output, &py_array);
    if (!s.ok()) {
      Set_TF_Status_from_Status(out_status, s);
      return;
    }
    // PyArray_Return converts rank-0 arrays into Python scalars.
    py_outputs_safe.push_back(
        make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
  }

  // If we reach this point, we have successfully built a list of objects
  // so we can release them from the safe container.
  out_values->reserve(py_outputs_safe.size());
  for (auto& output : py_outputs_safe) {
    out_values->push_back(output.release());
  }
}
}  // namespace
285
TF_DeprecatedSessionRunCallable(TF_DeprecatedSession * session,int64_t handle,PyObject * feed_values,PyObjectVector * out_values,TF_Buffer * run_metadata,TF_Status * status)286 void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
287 int64_t handle, PyObject* feed_values,
288 PyObjectVector* out_values,
289 TF_Buffer* run_metadata,
290 TF_Status* status) {
291 RunCallableHelper(session->session, handle, feed_values, status, out_values,
292 run_metadata);
293 ClearDecrefCache();
294 }
TF_SessionRunCallable(TF_Session * session,int64_t handle,PyObject * feed_values,PyObjectVector * out_values,TF_Buffer * run_metadata,TF_Status * status)295 void TF_SessionRunCallable(TF_Session* session, int64_t handle,
296 PyObject* feed_values, PyObjectVector* out_values,
297 TF_Buffer* run_metadata, TF_Status* status) {
298 RunCallableHelper(session->session, handle, feed_values, status, out_values,
299 run_metadata);
300 ClearDecrefCache();
301 }
302
TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession * session,int64_t handle,TF_Status * status)303 void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
304 int64_t handle, TF_Status* status) {
305 Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
306 }
TF_SessionReleaseCallable(TF_Session * session,int64_t handle,TF_Status * status)307 void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
308 TF_Status* status) {
309 Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
310 }
311
312 // Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
313 // If *out_status is OK, the caller becomes the owner of *out_handle.
TF_PRunSetup_wrapper(TF_DeprecatedSession * session,const NameVector & input_names,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,const char ** out_handle)314 void TF_PRunSetup_wrapper(TF_DeprecatedSession* session,
315 const NameVector& input_names,
316 const NameVector& output_names,
317 const NameVector& target_nodes, TF_Status* out_status,
318 const char** out_handle) {
319 Py_BEGIN_ALLOW_THREADS;
320 TF_PRunSetup(
321 session, const_cast<const char**>(input_names.data()), input_names.size(),
322 const_cast<const char**>(output_names.data()), output_names.size(),
323 const_cast<const char**>(target_nodes.data()), target_nodes.size(),
324 out_handle, out_status);
325 Py_END_ALLOW_THREADS;
326 }
327
328 // Wrapper for TF_PRun that converts the arguments to appropriate types.
329 // If *out_status is OK, the caller becomes the owner of the PyObjects
330 // in *out_values.
TF_PRun_wrapper(TF_DeprecatedSession * session,const char * handle,PyObject * feed_dict,const NameVector & output_names,TF_Status * out_status,PyObjectVector * out_values)331 void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
332 PyObject* feed_dict, const NameVector& output_names,
333 TF_Status* out_status, PyObjectVector* out_values) {
334 TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names,
335 NameVector(), out_status, out_values, nullptr);
336 ClearDecrefCache();
337 }
338
339 // Wrapper for TF_Reset that converts the string vectors to character arrays.
TF_Reset_wrapper(const TF_SessionOptions * opt,const NameVector & containers,TF_Status * status)340 void TF_Reset_wrapper(const TF_SessionOptions* opt,
341 const NameVector& containers, TF_Status* status) {
342 TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
343 status);
344 }
345
// Helper shared by TF_SessionRun_wrapper and TF_SessionPRun_wrapper.
//
// Runs `session` with feeds `inputs` (fed from the parallel `input_ndarrays`
// vector), fetches `outputs`, and `targets`, using TF_SessionRun when
// `handle` is null and TF_SessionPRun otherwise. On success, one NumPy
// object per fetched output is appended to *py_outputs and the caller owns
// the new references; on failure *out_status holds the error and *py_outputs
// is left unchanged. `run_metadata` is only consulted on the TF_SessionRun
// path (TF_SessionPRun takes none).
void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
                                  const TF_Buffer* run_options,
                                  const std::vector<TF_Output>& inputs,
                                  const std::vector<PyObject*>& input_ndarrays,
                                  const std::vector<TF_Output>& outputs,
                                  const std::vector<TF_Operation*>& targets,
                                  TF_Buffer* run_metadata,
                                  TF_Status* out_status,
                                  std::vector<PyObject*>* py_outputs) {
  DCHECK_EQ(inputs.size(), input_ndarrays.size());
  DCHECK(py_outputs != nullptr);
  DCHECK(py_outputs->empty());
  Status s;

  // Convert input ndarray PyObjects to TF_Tensors. We maintain a continuous
  // array of TF_Tensor*s as well as scoped containers to make sure they're
  // cleaned up properly.
  //
  // Memory management:
  // NdarrayToTensor() creates a new ndarray PyObject from the input
  // ndarray. We manage the new ndarray's lifetime in order to keep the
  // underlying data buffer alive (the new ndarray also guarantees a contiguous
  // data buffer). The new ndarray's data buffer is used to create the
  // corresponding TF_Tensor. The TF_Tensor's deallocator will queue the new
  // ndarray to be decref'd by the next ClearDecrefCache() call (we can't call
  // Py_DECREF in the deallocator directly because the GIL must be held).
  //
  // Note that TF_Tensor may directly delegate its data and deallocator to a
  // TensorBuffer, which may outlive the TF_Tensor (e.g. if the tensor gets
  // queued or assigned to a variable).
  TF_TensorVector input_vals;
  std::vector<Safe_TF_TensorPtr> input_vals_safe;
  for (PyObject* ndarray : input_ndarrays) {
    input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
    s = NdarrayToTensor(nullptr, ndarray, &input_vals_safe.back());
    if (!s.ok()) {
      Set_TF_Status_from_Status(out_status, s);
      return;
    }
    input_vals.push_back(input_vals_safe.back().get());
  }

  // Allocate space for output TF_Tensor*s
  TF_TensorVector output_vals(outputs.size());

  // Clear up any unused memory leftover from previous runs
  ClearDecrefCache();

  // Call TF_SessionRun() (and release GIL during execution)
  Py_BEGIN_ALLOW_THREADS;
  if (handle == nullptr) {
    TF_SessionRun(session, run_options, inputs.data(), input_vals.data(),
                  inputs.size(), outputs.data(), output_vals.data(),
                  outputs.size(), targets.data(), targets.size(), run_metadata,
                  out_status);
  } else {
    TF_SessionPRun(session, handle, inputs.data(), input_vals.data(),
                   inputs.size(), outputs.data(), output_vals.data(),
                   outputs.size(), targets.data(), targets.size(), out_status);
  }
  Py_END_ALLOW_THREADS;

  // Create scoped containers for output tensors. Note this happens even on
  // error, so any tensors that were produced still get freed.
  std::vector<Safe_TF_TensorPtr> output_vals_safe;
  for (TF_Tensor* output : output_vals) {
    output_vals_safe.emplace_back(make_safe(output));
  }

  // Convert outputs to ndarrays (in scoped containers)
  std::vector<Safe_PyObjectPtr> py_outputs_safe;
  for (size_t i = 0; i < outputs.size(); ++i) {
    PyObject* py_array;
    s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array);
    if (!s.ok()) {
      Set_TF_Status_from_Status(out_status, s);
      return;
    }
    // PyArray_Return converts rank-0 arrays into Python scalars.
    py_outputs_safe.emplace_back(
        make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
  }

  // If we reach this point, we have successfully built a list of objects so we
  // can release them from the safe container into the return vector.
  for (size_t i = 0; i < outputs.size(); ++i) {
    py_outputs->push_back(py_outputs_safe[i].release());
  }
}
433
TF_SessionRun_wrapper(TF_Session * session,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)434 void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
435 const std::vector<TF_Output>& inputs,
436 const std::vector<PyObject*>& input_ndarrays,
437 const std::vector<TF_Output>& outputs,
438 const std::vector<TF_Operation*>& targets,
439 TF_Buffer* run_metadata, TF_Status* out_status,
440 std::vector<PyObject*>* py_outputs) {
441 TF_SessionRun_wrapper_helper(session, nullptr, run_options, inputs,
442 input_ndarrays, outputs, targets, run_metadata,
443 out_status, py_outputs);
444 // Release any unused ndarray references (see memory management comment in
445 // TF_SessionRun_wrapper_helper)
446 ClearDecrefCache();
447 }
448
// Compares two serialized GraphDefs. Returns the empty string when they are
// equal, otherwise a human-readable description of the first difference (or
// of the parse failure).
string EqualGraphDefWrapper(const string& actual, const string& expected) {
  GraphDef actual_def;
  GraphDef expected_def;
  if (!actual_def.ParseFromString(actual)) {
    return "actual is not a valid serialized GraphDef";
  }
  if (!expected_def.ParseFromString(expected)) {
    return "expected is not a valid serialized GraphDef";
  }
  string diff;
  if (EqualGraphDef(actual_def, expected_def, &diff)) return "";
  return diff;
}
461
// Compares two serialized AttrValues. Returns the empty string when they are
// equal, otherwise a human-readable description of the mismatch (or of the
// parse failure).
string EqualAttrValueWrapper(const string& actual, const string& expected) {
  AttrValue actual_attr_value;
  AttrValue expected_attr_value;
  if (!actual_attr_value.ParseFromString(actual)) {
    return "actual is not a valid serialized AttrValue";
  }
  if (!expected_attr_value.ParseFromString(expected)) {
    return "expected is not a valid serialized AttrValue";
  }

  if (AreAttrValuesEqual(actual_attr_value, expected_attr_value)) {
    return "";
  }
  return strings::Printf(
      "Actual AttrValue %s does not match Expected AttrValue %s.",
      SummarizeAttrValue(actual_attr_value).c_str(),
      SummarizeAttrValue(expected_attr_value).c_str());
}
482
483 // Return value set to 6 inlined elements so it fits in a 64-byte cache line.
TF_GraphGetTensorShapeHelper(TF_Graph * graph,TF_Output output,TF_Status * out_status,bool * unknown_shape)484 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
485 TF_Graph* graph, TF_Output output, TF_Status* out_status,
486 bool* unknown_shape) {
487 // Allocate a single variable for holding the result for RVO.
488 tensorflow::gtl::InlinedVector<int64_t, 6> result;
489 *unknown_shape = false;
490 int num_dims = TF_GraphGetTensorNumDims(graph, output, out_status);
491 if (TF_GetCode(out_status) != TF_OK) {
492 return result;
493 }
494 // If shape is unknown, set boolean and return.
495 if (num_dims == -1) {
496 *unknown_shape = true;
497 return result;
498 }
499
500 // If shape is a scalar, avoid another C call and just return {}.
501 if (num_dims == 0) {
502 return result;
503 }
504
505 result.resize(num_dims);
506 TF_GraphGetTensorShape(graph, output, result.data(), num_dims, out_status);
507 return result;
508 }
509
TF_SessionPRunSetup_wrapper(TF_Session * session,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,const char ** out_handle,TF_Status * out_status)510 void TF_SessionPRunSetup_wrapper(TF_Session* session,
511 const std::vector<TF_Output>& inputs,
512 const std::vector<TF_Output>& outputs,
513 const std::vector<TF_Operation*>& targets,
514 const char** out_handle,
515 TF_Status* out_status) {
516 // Call TF_SessionPRunSetup() (and release GIL during execution)
517 Py_BEGIN_ALLOW_THREADS;
518 TF_SessionPRunSetup(session, inputs.data(), inputs.size(), outputs.data(),
519 outputs.size(), targets.data(), targets.size(),
520 out_handle, out_status);
521 Py_END_ALLOW_THREADS;
522 }
523
TF_SessionPRun_wrapper(TF_Session * session,const char * handle,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,TF_Status * out_status,std::vector<PyObject * > * py_outputs)524 void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
525 const std::vector<TF_Output>& inputs,
526 const std::vector<PyObject*>& input_ndarrays,
527 const std::vector<TF_Output>& outputs,
528 TF_Status* out_status,
529 std::vector<PyObject*>* py_outputs) {
530 const std::vector<TF_Operation*> targets;
531 TF_SessionRun_wrapper_helper(session, handle,
532 nullptr, // run_options
533 inputs, input_ndarrays, outputs, targets,
534 nullptr, // run_metadata
535 out_status, py_outputs);
536 // Release any unused ndarray references (see memory management comment in
537 // TF_SessionRun_wrapper_helper)
538 ClearDecrefCache();
539 }
540
GetOperationInputs(TF_Operation * oper)541 std::vector<TF_Output> GetOperationInputs(TF_Operation* oper) {
542 int num_inputs = TF_OperationNumInputs(oper);
543 std::vector<TF_Output> inputs(num_inputs);
544 TF_OperationAllInputs(oper, inputs.data(), inputs.size());
545 return inputs;
546 }
547
TF_OperationGetControlInputs_wrapper(TF_Operation * oper)548 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
549 TF_Operation* oper) {
550 std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper));
551 TF_OperationGetControlInputs(oper, control_inputs.data(),
552 control_inputs.size());
553 return control_inputs;
554 }
555
TF_OperationGetControlOutputs_wrapper(TF_Operation * oper)556 std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
557 TF_Operation* oper) {
558 std::vector<TF_Operation*> control_outputs(
559 TF_OperationNumControlOutputs(oper));
560 TF_OperationGetControlOutputs(oper, control_outputs.data(),
561 control_outputs.size());
562 return control_outputs;
563 }
564
TF_OperationOutputConsumers_wrapper(TF_Output oper_out)565 std::vector<const char*> TF_OperationOutputConsumers_wrapper(
566 TF_Output oper_out) {
567 int num_consumers = TF_OperationOutputNumConsumers(oper_out);
568 std::vector<TF_Input> consumers(num_consumers);
569 TF_OperationOutputConsumers(oper_out, consumers.data(), num_consumers);
570
571 std::vector<const char*> consumer_names(num_consumers);
572 for (int i = 0; i < num_consumers; ++i) {
573 consumer_names[i] = TF_OperationName(consumers[i].oper);
574 }
575 return consumer_names;
576 }
577
TF_GraphToFunction_wrapper(const TF_Graph * fn_body,const char * fn_name,bool append_hash_to_fn_name,const std::vector<TF_Operation * > * opers,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const NameVector & output_names,const std::vector<TF_Operation * > * control_outputs,const NameVector & control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * out_status)578 TF_Function* TF_GraphToFunction_wrapper(
579 const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
580 const std::vector<TF_Operation*>* opers,
581 const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
582 const NameVector& output_names,
583 const std::vector<TF_Operation*>* control_outputs,
584 const NameVector& control_output_names, const TF_FunctionOptions* opts,
585 const char* description, TF_Status* out_status) {
586 if (!output_names.empty() && output_names.size() != outputs.size()) {
587 Set_TF_Status_from_Status(
588 out_status,
589 errors::InvalidArgument(
590 "output names must be either empty or equal in size to outputs. ",
591 "output names size = ", output_names.size(),
592 " outputs size = ", outputs.size()));
593 return nullptr;
594 }
595
596 int nopers = -1;
597 const TF_Operation* const* opers_array = nullptr;
598 if (opers != nullptr) {
599 nopers = opers->size();
600 opers_array = opers->data();
601 }
602
603 const char** output_names_ptr =
604 output_names.empty() ? nullptr
605 : const_cast<const char**>(output_names.data());
606
607 const char** control_output_names_ptr =
608 control_output_names.empty()
609 ? nullptr
610 : const_cast<const char**>(control_output_names.data());
611
612 return TF_GraphToFunctionWithControlOutputs(
613 fn_body, fn_name, append_hash_to_fn_name, nopers, opers_array,
614 inputs.size(), inputs.data(), outputs.size(), outputs.data(),
615 output_names_ptr,
616 control_outputs == nullptr ? 0 : control_outputs->size(),
617 control_outputs == nullptr ? nullptr : control_outputs->data(),
618 control_output_names_ptr, opts, description, out_status);
619 }
620
TF_GraphSetOutputHandleShapesAndTypes_wrapper(TF_Graph * graph,TF_Output output,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int> & ranks,const std::vector<TF_DataType> & types,TF_Status * status)621 void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
622 TF_Graph* graph, TF_Output output,
623 const std::vector<std::vector<int64_t>>& shapes,
624 const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
625 TF_Status* status) {
626 std::vector<const int64_t*> shapes_pointers(shapes.size());
627 for (int i = 0; i < shapes.size(); ++i) {
628 shapes_pointers[i] = ranks[i] <= 0 ? nullptr : &shapes[i][0];
629 }
630 TF_GraphSetOutputHandleShapesAndTypes(graph, output, shapes.size(),
631 shapes_pointers.data(), ranks.data(),
632 types.data(), status);
633 }
634
CreatePlaceholder(TF_Graph * graph,TF_Status * s,string && name,TF_DataType dtype,TF_Output * output)635 void CreatePlaceholder(TF_Graph* graph, TF_Status* s, string&& name,
636 TF_DataType dtype, TF_Output* output) {
637 TF_OperationDescription* desc =
638 TF_NewOperation(graph, "Placeholder", name.data());
639 TF_SetAttrType(desc, "dtype", dtype);
640 TF_Operation* op = TF_FinishOperation(desc, s);
641 output->oper = op;
642 output->index = 0;
643 }
644
TF_CreatePlaceholders(TF_Graph * graph,PyObject * dtypes,const char * prefix,TF_Status * status)645 std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
646 const char* prefix,
647 TF_Status* status) {
648 std::vector<TF_Output> outputs;
649 dtypes = PySequence_Fast(dtypes, "dtypes must be a sequence");
650 if (dtypes == nullptr) {
651 Set_TF_Status_from_Status(status, errors::Internal("dtypes is nullptr"));
652 return outputs;
653 }
654 Safe_PyObjectPtr dtypes_holder(make_safe(dtypes));
655 Py_ssize_t len = PySequence_Fast_GET_SIZE(dtypes);
656 outputs.reserve(len);
657 for (size_t i = 0; i < len; i++) {
658 PyObject* dtype = PySequence_Fast_GET_ITEM(dtypes, i);
659 if (!dtype) {
660 Set_TF_Status_from_Status(status,
661 errors::Internal("Could not get dtype ", i));
662 return outputs;
663 }
664 #if PY_MAJOR_VERSION >= 3
665 TF_DataType tf_datatype = static_cast<TF_DataType>(PyLong_AsLong(dtype));
666 #else
667 TF_DataType tf_datatype = static_cast<TF_DataType>(PyInt_AsLong(dtype));
668 #endif
669 outputs.push_back(TF_Output());
670 CreatePlaceholder(graph, status, strings::StrCat(prefix, i), tf_datatype,
671 &outputs.back());
672 if (!status->status.ok()) break;
673 }
674 return outputs;
675 }
676
TF_GraphSetTensorShape_wrapper(TF_Graph * graph,TF_Output output,const std::vector<int64_t> & dims,bool unknown_shape,TF_Status * status)677 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
678 const std::vector<int64_t>& dims,
679 bool unknown_shape, TF_Status* status) {
680 if (unknown_shape) {
681 TF_GraphSetTensorShape(graph, output, nullptr, -1, status);
682 return;
683 }
684 TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
685 }
686
TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(TF_ImportGraphDefResults * results)687 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
688 TF_ImportGraphDefResults* results) {
689 int num_missing_unused_input_mappings;
690 const char** src_names;
691 int* src_indexes;
692 TF_ImportGraphDefResultsMissingUnusedInputMappings(
693 results, &num_missing_unused_input_mappings, &src_names, &src_indexes);
694 std::vector<string> input_strs(num_missing_unused_input_mappings);
695 for (int i = 0; i < num_missing_unused_input_mappings; ++i) {
696 input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString();
697 }
698 return input_strs;
699 }
700
TF_TryEvaluateConstant_wrapper(TF_Graph * graph,TF_Output output,TF_Status * status)701 PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
702 TF_Status* status) {
703 TF_Tensor* result_tensor;
704 bool evaluated =
705 TF_TryEvaluateConstant(graph, output, &result_tensor, status);
706 if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE;
707
708 Safe_TF_TensorPtr safe_result_tensor(result_tensor);
709 PyObject* out;
710 Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out);
711 Set_TF_Status_from_Status(status, s);
712 if (!s.ok()) Py_RETURN_NONE;
713 return PyArray_Return(reinterpret_cast<PyArrayObject*>(out));
714 }
715
716 } // namespace tensorflow
717