1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/client/tf_session_helper.h"
17
18 #include <cstring>
19
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/log_memory.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/graph/tensor_id.h"
29 #include "tensorflow/core/lib/core/coding.h"
30 #include "tensorflow/core/lib/strings/stringprintf.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/equal_graph_def.h"
33 #include "tensorflow/python/client/session_ref.h"
34 #include "tensorflow/python/lib/core/ndarray_tensor.h"
35 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
36 #include "tensorflow/python/lib/core/safe_ptr.h"
37
38 namespace tensorflow {
39
namespace {

// Error message reported whenever `feed_dict` (or one of its keys) does not
// have the expected Python type.
static const char* kFeedDictErrorMsg =
    "feed_dict must be a dictionary mapping strings to NumPy arrays.";
}  // end namespace
45
TF_NewSessionRef(TF_Graph * graph,const TF_SessionOptions * opts,TF_Status * status)46 TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
47 TF_Status* status) {
48 TF_Session* tf_session = TF_NewSession(graph, opts, status);
49 if (tf_session == nullptr) {
50 return nullptr;
51 }
52
53 Session* session = reinterpret_cast<Session*>(tf_session->session);
54 SessionRef* session_ref = new SessionRef(session);
55 tf_session->session = session_ref;
56 return tf_session;
57 }
58
TF_Run_wrapper_helper(TF_DeprecatedSession * session,const char * handle,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)59 void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
60 const TF_Buffer* run_options, PyObject* feed_dict,
61 const NameVector& output_names,
62 const NameVector& target_nodes,
63 TF_Status* out_status, PyObjectVector* out_values,
64 TF_Buffer* run_outputs) {
65 // 1. Convert the feed inputs to the appropriate form for TF_Run.
66 if (!PyDict_Check(feed_dict)) {
67 Set_TF_Status_from_Status(out_status,
68 errors::InvalidArgument(kFeedDictErrorMsg));
69 return;
70 }
71
72 NameVector input_names;
73 std::vector<Safe_TF_TensorPtr> inputs_safe; // Used to delete tensors.
74 TF_TensorVector inputs_unsafe; // Used to contain the arg to TF_Run.
75
76 PyObject* key;
77 PyObject* value;
78 Py_ssize_t pos = 0;
79 int index = 0;
80 Status s;
81
82 while (PyDict_Next(feed_dict, &pos, &key, &value)) {
83 char* key_string = PyBytes_AsString(key);
84 if (!key_string) {
85 Set_TF_Status_from_Status(out_status,
86 errors::InvalidArgument(kFeedDictErrorMsg));
87 return;
88 }
89 input_names.push_back(key_string);
90
91 inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
92 s = PyArrayToTF_Tensor(value, &inputs_safe.back());
93 if (!s.ok()) {
94 Set_TF_Status_from_Status(out_status, s);
95 return;
96 }
97 inputs_unsafe.push_back(inputs_safe.back().get());
98 ++index;
99 }
100
101 // 2. Allocate a container for the output data.
102 TF_TensorVector outputs(output_names.size());
103
104 // In case any tensors were leftover from previous runs we might as well clear
105 // them here.
106 ClearDecrefCache();
107
108 // 3. Actually call TF_Run().
109 Py_BEGIN_ALLOW_THREADS;
110 if (handle == nullptr) {
111 TF_Run(session, run_options, input_names.data(), inputs_unsafe.data(),
112 input_names.size(), const_cast<const char**>(output_names.data()),
113 outputs.data(), output_names.size(),
114 const_cast<const char**>(target_nodes.data()), target_nodes.size(),
115 run_outputs, out_status);
116 } else {
117 TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(),
118 input_names.size(), const_cast<const char**>(output_names.data()),
119 outputs.data(), output_names.size(),
120 const_cast<const char**>(target_nodes.data()), target_nodes.size(),
121 out_status);
122 }
123
124 Py_END_ALLOW_THREADS;
125
126 // Decref any numpy arrays we are not using anymore.
127 ClearDecrefCache();
128
129 if (TF_GetCode(out_status) != TF_OK) {
130 return;
131 }
132
133 // 4. We now own the fetched tensors, so set up a safe container to
134 // delete them when we exit this scope.
135 std::vector<Safe_TF_TensorPtr> tf_outputs_safe;
136 for (const auto& output : outputs) {
137 tf_outputs_safe.emplace_back(make_safe(output));
138 }
139
140 // 5. Convert the fetched tensors into numpy ndarrays. Store them in a safe
141 // container so that we do not leak
142 std::vector<Safe_PyObjectPtr> py_outputs_safe;
143 for (size_t i = 0; i < output_names.size(); ++i) {
144 PyObject* py_array;
145 s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
146 if (!s.ok()) {
147 Set_TF_Status_from_Status(out_status, s);
148 return;
149 }
150 py_outputs_safe.emplace_back(
151 make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
152 }
153
154 // 6. If we reach this point, we have successfully built a list of objects
155 // so we can release them from the safe container.
156 for (auto& output : py_outputs_safe) {
157 out_values->push_back(output.release());
158 }
159 }
160
161 // Wrapper for TF_Run that converts the arguments to appropriate types.
162 // If *out_status is OK, the caller becomes the owner of the PyObjects
163 // in *out_values.
void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
                    PyObject* feed_dict, const NameVector& output_names,
                    const NameVector& target_nodes, TF_Status* out_status,
                    PyObjectVector* out_values, TF_Buffer* run_outputs) {
  // A null handle selects a full TF_Run (as opposed to a partial TF_PRun).
  TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names,
                        target_nodes, out_status, out_values, run_outputs);
  // Release any numpy arrays queued for decref by the tensor deallocators.
  ClearDecrefCache();
}
172
173 namespace {
MakeCallableHelper(tensorflow::Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * out_status)174 void MakeCallableHelper(tensorflow::Session* session,
175 const TF_Buffer* callable_options, int64_t* out_handle,
176 TF_Status* out_status) {
177 tensorflow::CallableOptions callable_options_proto;
178 if (callable_options != nullptr &&
179 !callable_options_proto.ParseFromArray(callable_options->data,
180 callable_options->length)) {
181 Set_TF_Status_from_Status(
182 out_status,
183 errors::InvalidArgument("Unparseable CallableOptions proto"));
184 return;
185 }
186 tensorflow::Session::CallableHandle handle;
187 Status s = session->MakeCallable(callable_options_proto, &handle);
188 if (!s.ok()) {
189 Set_TF_Status_from_Status(out_status, s);
190 return;
191 }
192 *out_handle = handle;
193 }
194 } // namespace
195
void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
                                      const TF_Buffer* callable_options,
                                      int64_t* out_handle, TF_Status* status) {
  // Forward to the shared helper using the underlying tensorflow::Session.
  MakeCallableHelper(session->session, callable_options, out_handle, status);
}
void TF_SessionMakeCallable(TF_Session* session,
                            const TF_Buffer* callable_options,
                            int64_t* out_handle, TF_Status* status) {
  // Forward to the shared helper using the underlying tensorflow::Session.
  MakeCallableHelper(session->session, callable_options, out_handle, status);
}
206
207 namespace {
// Runs the callable identified by `handle` on `session`. `feed_values` must
// be a Python sequence of ndarrays; fetched results are appended to
// *out_values as ndarrays (Python scalars for rank-0 results). When
// `run_metadata` is non-null, the serialized RunMetadata proto is written
// into it. On failure *out_status is set and *out_values is left untouched.
void RunCallableHelper(tensorflow::Session* session, int64_t handle,
                       PyObject* feed_values, TF_Status* out_status,
                       PyObjectVector* out_values, TF_Buffer* run_metadata) {
  // Convert feed values to a vector of tensorflow::Tensor objects.
  std::vector<Tensor> input_tensors;
  Status s;
  {
    feed_values =
        PySequence_Fast(feed_values, "feed_values must be a sequence");
    // PySequence_Fast reports failure via a Python exception; *out_status is
    // deliberately left unset in that case.
    if (feed_values == nullptr) return;
    // Owns the reference returned by PySequence_Fast for this scope.
    Safe_PyObjectPtr feed_values_holder(make_safe(feed_values));
    Py_ssize_t len = PySequence_Fast_GET_SIZE(feed_values);
    input_tensors.reserve(len);
    for (Py_ssize_t i = 0; i < len; ++i) {
      PyObject* elem = PySequence_Fast_GET_ITEM(feed_values, i);
      if (!elem) {
        Set_TF_Status_from_Status(
            out_status, errors::Internal("Could not get feed value ", i));
        return;
      }
      Tensor t;
      s = NdarrayToTensor(elem, &t);
      if (!s.ok()) {
        Set_TF_Status_from_Status(out_status, s);
        return;
      }
      input_tensors.push_back(std::move(t));
    }
  }

  // Allocate a RunMetadata protobuf object to receive the metadata,
  // if the caller is expecting any.
  std::unique_ptr<RunMetadata> run_metadata_proto;
  if (run_metadata != nullptr) {
    run_metadata_proto.reset(new RunMetadata);
  }

  // Run the callable with the GIL released.
  std::vector<Tensor> output_tensors;
  Py_BEGIN_ALLOW_THREADS;
  s = session->RunCallable(handle, input_tensors, &output_tensors,
                           run_metadata_proto.get());
  Py_END_ALLOW_THREADS;

  if (!s.ok()) {
    Set_TF_Status_from_Status(out_status, s);
    return;
  }

  // If requested, serialize the RunMetadata to pass it back to the caller.
  if (run_metadata != nullptr) {
    s = MessageToBuffer(*run_metadata_proto, run_metadata);
    if (!s.ok()) {
      Set_TF_Status_from_Status(out_status, s);
      return;
    }
  }

  // Convert results to NumPy arrays. Since this can fail, stage the
  // results via a safe container that takes care of decreasing the
  // reference count on failure.
  std::vector<Safe_PyObjectPtr> py_outputs_safe;
  py_outputs_safe.reserve(output_tensors.size());
  for (const Tensor& output : output_tensors) {
    PyObject* py_array;
    s = TensorToNdarray(output, &py_array);
    if (!s.ok()) {
      Set_TF_Status_from_Status(out_status, s);
      return;
    }
    // PyArray_Return converts rank-0 arrays to the equivalent Python scalar.
    py_outputs_safe.push_back(
        make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
  }

  // If we reach this point, we have successfully built a list of objects
  // so we can release them from the safe container.
  out_values->reserve(py_outputs_safe.size());
  for (auto& output : py_outputs_safe) {
    out_values->push_back(output.release());
  }
}
289 } // namespace
290
void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
                                     int64_t handle, PyObject* feed_values,
                                     PyObjectVector* out_values,
                                     TF_Buffer* run_metadata,
                                     TF_Status* status) {
  // Forward to the shared helper using the underlying tensorflow::Session.
  RunCallableHelper(session->session, handle, feed_values, status, out_values,
                    run_metadata);
  // Release any numpy arrays queued for decref by tensor deallocators.
  ClearDecrefCache();
}
void TF_SessionRunCallable(TF_Session* session, int64_t handle,
                           PyObject* feed_values, PyObjectVector* out_values,
                           TF_Buffer* run_metadata, TF_Status* status) {
  // Forward to the shared helper using the underlying tensorflow::Session.
  RunCallableHelper(session->session, handle, feed_values, status, out_values,
                    run_metadata);
  // Release any numpy arrays queued for decref by tensor deallocators.
  ClearDecrefCache();
}
307
void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
                                         int64_t handle, TF_Status* status) {
  // Release the callable and propagate the resulting Status into *status.
  Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
}
void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
                               TF_Status* status) {
  // Release the callable and propagate the resulting Status into *status.
  Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
}
316
317 // Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
318 // If *out_status is OK, the caller becomes the owner of *out_handle.
void TF_PRunSetup_wrapper(TF_DeprecatedSession* session,
                          const NameVector& input_names,
                          const NameVector& output_names,
                          const NameVector& target_nodes, TF_Status* out_status,
                          const char** out_handle) {
  // Release the GIL while the C API call runs.
  Py_BEGIN_ALLOW_THREADS;
  TF_PRunSetup(
      session, const_cast<const char**>(input_names.data()), input_names.size(),
      const_cast<const char**>(output_names.data()), output_names.size(),
      const_cast<const char**>(target_nodes.data()), target_nodes.size(),
      out_handle, out_status);
  Py_END_ALLOW_THREADS;
}
332
333 // Wrapper for TF_PRun that converts the arguments to appropriate types.
334 // If *out_status is OK, the caller becomes the owner of the PyObjects
335 // in *out_values.
void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
                     PyObject* feed_dict, const NameVector& output_names,
                     TF_Status* out_status, PyObjectVector* out_values) {
  // Partial runs take no run_options, target nodes, or run outputs.
  TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names,
                        NameVector(), out_status, out_values, nullptr);
  // Release any numpy arrays queued for decref by tensor deallocators.
  ClearDecrefCache();
}
343
344 // Wrapper for TF_Reset that converts the string vectors to character arrays.
void TF_Reset_wrapper(const TF_SessionOptions* opt,
                      const NameVector& containers, TF_Status* status) {
  // Forward to the TF_Reset C API with the container names as a raw array.
  TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
           status);
}
350
TF_SessionRun_wrapper_helper(TF_Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)351 void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
352 const TF_Buffer* run_options,
353 const std::vector<TF_Output>& inputs,
354 const std::vector<PyObject*>& input_ndarrays,
355 const std::vector<TF_Output>& outputs,
356 const std::vector<TF_Operation*>& targets,
357 TF_Buffer* run_metadata,
358 TF_Status* out_status,
359 std::vector<PyObject*>* py_outputs) {
360 DCHECK_EQ(inputs.size(), input_ndarrays.size());
361 DCHECK(py_outputs != nullptr);
362 DCHECK(py_outputs->empty());
363 Status s;
364
365 // Convert input ndarray PyObjects to TF_Tensors. We maintain a continuous
366 // array of TF_Tensor*s as well as scoped containers to make sure they're
367 // cleaned up properly.
368 //
369 // Memory management:
370 // PyArrayToTF_Tensor() creates a new ndarray PyObject from the input
371 // ndarray. We manage the new ndarray's lifetime in order to keep the
372 // underlying data buffer alive (the new ndarray also guarantees a contiguous
373 // data buffer). The new ndarray's data buffer is used to create the
374 // corresponding TF_Tensor. The TF_Tensor's deallocator will queue the new
375 // ndarray to be decref'd by the next ClearDecrefCache() call (we can't call
376 // Py_DECREF in the deallocator directly because the GIL must be held).
377 //
378 // Note that TF_Tensor may directly delegate its data and deallocator to a
379 // TensorBuffer, which may outlive the TF_Tensor (e.g. if the tensor gets
380 // queued or assigned to a variable).
381 TF_TensorVector input_vals;
382 std::vector<Safe_TF_TensorPtr> input_vals_safe;
383 for (PyObject* ndarray : input_ndarrays) {
384 input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
385 s = PyArrayToTF_Tensor(ndarray, &input_vals_safe.back());
386 if (!s.ok()) {
387 Set_TF_Status_from_Status(out_status, s);
388 return;
389 }
390 input_vals.push_back(input_vals_safe.back().get());
391 }
392
393 // Allocate space for output TF_Tensor*s
394 TF_TensorVector output_vals(outputs.size());
395
396 // Clear up any unused memory leftover from previous runs
397 ClearDecrefCache();
398
399 // Call TF_SessionRun() (and release GIL during execution)
400 Py_BEGIN_ALLOW_THREADS;
401 if (handle == nullptr) {
402 TF_SessionRun(session, run_options, inputs.data(), input_vals.data(),
403 inputs.size(), outputs.data(), output_vals.data(),
404 outputs.size(), targets.data(), targets.size(), run_metadata,
405 out_status);
406 } else {
407 TF_SessionPRun(session, handle, inputs.data(), input_vals.data(),
408 inputs.size(), outputs.data(), output_vals.data(),
409 outputs.size(), targets.data(), targets.size(), out_status);
410 }
411 Py_END_ALLOW_THREADS;
412
413 // Create scoped containers for output tensors
414 std::vector<Safe_TF_TensorPtr> output_vals_safe;
415 for (TF_Tensor* output : output_vals) {
416 output_vals_safe.emplace_back(make_safe(output));
417 }
418
419 // Convert outputs to ndarrays (in scoped containers)
420 std::vector<Safe_PyObjectPtr> py_outputs_safe;
421 for (size_t i = 0; i < outputs.size(); ++i) {
422 PyObject* py_array;
423 s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array);
424 if (!s.ok()) {
425 Set_TF_Status_from_Status(out_status, s);
426 return;
427 }
428 py_outputs_safe.emplace_back(
429 make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
430 }
431
432 // If we reach this point, we have successfully built a list of objects so we
433 // can release them from the safe container into the return vector.
434 for (size_t i = 0; i < outputs.size(); ++i) {
435 py_outputs->push_back(py_outputs_safe[i].release());
436 }
437 }
438
// Wrapper for TF_SessionRun: feeds `input_ndarrays` to `inputs`, fetches
// `outputs` into *py_outputs (as ndarrays), and runs `targets`.
void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
                           const std::vector<TF_Output>& inputs,
                           const std::vector<PyObject*>& input_ndarrays,
                           const std::vector<TF_Output>& outputs,
                           const std::vector<TF_Operation*>& targets,
                           TF_Buffer* run_metadata, TF_Status* out_status,
                           std::vector<PyObject*>* py_outputs) {
  TF_SessionRun_wrapper_helper(session, nullptr, run_options, inputs,
                               input_ndarrays, outputs, targets, run_metadata,
                               out_status, py_outputs);
  // Release any unused ndarray references (see memory management comment in
  // TF_SessionRun_wrapper_helper)
  ClearDecrefCache();
}
453
// Parses `actual` and `expected` as serialized GraphDefs and compares them.
// Returns the empty string when they are equal, otherwise a human-readable
// description of the difference (or of the parse failure).
string EqualGraphDefWrapper(const string& actual, const string& expected) {
  GraphDef actual_def;
  if (!actual_def.ParseFromString(actual)) {
    return "actual is not a valid serialized GraphDef";
  }
  GraphDef expected_def;
  if (!expected_def.ParseFromString(expected)) {
    return "expected is not a valid serialized GraphDef";
  }
  string diff;
  if (EqualGraphDef(actual_def, expected_def, &diff)) {
    return "";
  }
  return diff;
}
466
// Parses `actual` and `expected` as serialized AttrValues and compares them.
// Returns the empty string when they are equal, otherwise a human-readable
// description of the mismatch (or of the parse failure).
string EqualAttrValueWrapper(const string& actual, const string& expected) {
  AttrValue actual_attr_value;
  if (!actual_attr_value.ParseFromString(actual)) {
    return "actual is not a valid serialized AttrValue";
  }
  AttrValue expected_attr_value;
  if (!expected_attr_value.ParseFromString(expected)) {
    return "expected is not a valid serialized AttrValue";
  }
  if (AreAttrValuesEqual(actual_attr_value, expected_attr_value)) {
    return "";
  }
  return strings::Printf(
      "Actual AttrValue %s does not match Expected AttrValue %s.",
      SummarizeAttrValue(actual_attr_value).c_str(),
      SummarizeAttrValue(expected_attr_value).c_str());
}
487
488 // Return value set to 6 inlined elements so it fits in a 64-byte cache line.
TF_GraphGetTensorShapeHelper(TF_Graph * graph,TF_Output output,TF_Status * out_status,bool * unknown_shape)489 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
490 TF_Graph* graph, TF_Output output, TF_Status* out_status,
491 bool* unknown_shape) {
492 // Allocate a single variable for holding the result for RVO.
493 tensorflow::gtl::InlinedVector<int64_t, 6> result;
494 *unknown_shape = false;
495 int num_dims = TF_GraphGetTensorNumDims(graph, output, out_status);
496 if (TF_GetCode(out_status) != TF_OK) {
497 return result;
498 }
499 // If shape is unknown, set boolean and return.
500 if (num_dims == -1) {
501 *unknown_shape = true;
502 return result;
503 }
504
505 // If shape is a scalar, avoid another C call and just return {}.
506 if (num_dims == 0) {
507 return result;
508 }
509
510 result.resize(num_dims);
511 TF_GraphGetTensorShape(graph, output, result.data(), num_dims, out_status);
512 return result;
513 }
514
// Wrapper for TF_SessionPRunSetup that takes vectors instead of raw arrays.
// If *out_status is OK, the caller becomes the owner of *out_handle.
void TF_SessionPRunSetup_wrapper(TF_Session* session,
                                 const std::vector<TF_Output>& inputs,
                                 const std::vector<TF_Output>& outputs,
                                 const std::vector<TF_Operation*>& targets,
                                 const char** out_handle,
                                 TF_Status* out_status) {
  // Call TF_SessionPRunSetup() (and release GIL during execution)
  Py_BEGIN_ALLOW_THREADS;
  TF_SessionPRunSetup(session, inputs.data(), inputs.size(), outputs.data(),
                      outputs.size(), targets.data(), targets.size(),
                      out_handle, out_status);
  Py_END_ALLOW_THREADS;
}
528
// Continues a partial run identified by `handle` (obtained from
// TF_SessionPRunSetup_wrapper), feeding `input_ndarrays` and fetching
// `outputs` into *py_outputs as ndarrays.
void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
                            const std::vector<TF_Output>& inputs,
                            const std::vector<PyObject*>& input_ndarrays,
                            const std::vector<TF_Output>& outputs,
                            TF_Status* out_status,
                            std::vector<PyObject*>* py_outputs) {
  // Partial runs take no target operations.
  const std::vector<TF_Operation*> targets;
  TF_SessionRun_wrapper_helper(session, handle,
                               nullptr,  // run_options
                               inputs, input_ndarrays, outputs, targets,
                               nullptr,  // run_metadata
                               out_status, py_outputs);
  // Release any unused ndarray references (see memory management comment in
  // TF_SessionRun_wrapper_helper)
  ClearDecrefCache();
}
545
GetOperationInputs(TF_Operation * oper)546 std::vector<TF_Output> GetOperationInputs(TF_Operation* oper) {
547 int num_inputs = TF_OperationNumInputs(oper);
548 std::vector<TF_Output> inputs(num_inputs);
549 TF_OperationAllInputs(oper, inputs.data(), inputs.size());
550 return inputs;
551 }
552
TF_OperationGetControlInputs_wrapper(TF_Operation * oper)553 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
554 TF_Operation* oper) {
555 std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper));
556 TF_OperationGetControlInputs(oper, control_inputs.data(),
557 control_inputs.size());
558 return control_inputs;
559 }
560
TF_OperationGetControlOutputs_wrapper(TF_Operation * oper)561 std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
562 TF_Operation* oper) {
563 std::vector<TF_Operation*> control_outputs(
564 TF_OperationNumControlOutputs(oper));
565 TF_OperationGetControlOutputs(oper, control_outputs.data(),
566 control_outputs.size());
567 return control_outputs;
568 }
569
TF_OperationOutputConsumers_wrapper(TF_Output oper_out)570 std::vector<const char*> TF_OperationOutputConsumers_wrapper(
571 TF_Output oper_out) {
572 int num_consumers = TF_OperationOutputNumConsumers(oper_out);
573 std::vector<TF_Input> consumers(num_consumers);
574 TF_OperationOutputConsumers(oper_out, consumers.data(), num_consumers);
575
576 std::vector<const char*> consumer_names(num_consumers);
577 for (int i = 0; i < num_consumers; ++i) {
578 consumer_names[i] = TF_OperationName(consumers[i].oper);
579 }
580 return consumer_names;
581 }
582
// Converts a subgraph of `fn_body` into a TF_Function named `fn_name`.
// `opers` selects which operations participate (null means all); `inputs`
// and `outputs` define the function signature; `output_names`, if non-empty,
// must have exactly one name per output. Returns nullptr with *out_status
// set on failure.
TF_Function* TF_GraphToFunction_wrapper(
    const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
    const std::vector<TF_Operation*>* opers,
    const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
    const NameVector& output_names,
    const std::vector<TF_Operation*>* control_outputs,
    const NameVector& control_output_names, const TF_FunctionOptions* opts,
    const char* description, TF_Status* out_status) {
  if (!output_names.empty() && output_names.size() != outputs.size()) {
    Set_TF_Status_from_Status(
        out_status,
        errors::InvalidArgument(
            "output names must be either empty or equal in size to outputs. ",
            "output names size = ", output_names.size(),
            " outputs size = ", outputs.size()));
    return nullptr;
  }

  // A null `opers` maps to nopers == -1, which the C API treats differently
  // from an empty (zero-length) operation list.
  int nopers = -1;
  const TF_Operation* const* opers_array = nullptr;
  if (opers != nullptr) {
    nopers = opers->size();
    opers_array = opers->data();
  }

  // Empty name vectors are passed to the C API as null pointers.
  const char** output_names_ptr =
      output_names.empty() ? nullptr
                           : const_cast<const char**>(output_names.data());

  const char** control_output_names_ptr =
      control_output_names.empty()
          ? nullptr
          : const_cast<const char**>(control_output_names.data());

  return TF_GraphToFunctionWithControlOutputs(
      fn_body, fn_name, append_hash_to_fn_name, nopers, opers_array,
      inputs.size(), inputs.data(), outputs.size(), outputs.data(),
      output_names_ptr,
      control_outputs == nullptr ? 0 : control_outputs->size(),
      control_outputs == nullptr ? nullptr : control_outputs->data(),
      control_output_names_ptr, opts, description, out_status);
}
625
TF_GraphSetOutputHandleShapesAndTypes_wrapper(TF_Graph * graph,TF_Output output,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int> & ranks,const std::vector<TF_DataType> & types,TF_Status * status)626 void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
627 TF_Graph* graph, TF_Output output,
628 const std::vector<std::vector<int64_t>>& shapes,
629 const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
630 TF_Status* status) {
631 std::vector<const int64_t*> shapes_pointers(shapes.size());
632 for (int i = 0; i < shapes.size(); ++i) {
633 shapes_pointers[i] = ranks[i] <= 0 ? nullptr : &shapes[i][0];
634 }
635 TF_GraphSetOutputHandleShapesAndTypes(graph, output, shapes.size(),
636 shapes_pointers.data(), ranks.data(),
637 types.data(), status);
638 }
639
CreatePlaceholder(TF_Graph * graph,TF_Status * s,string && name,TF_DataType dtype,TF_Output * output)640 void CreatePlaceholder(TF_Graph* graph, TF_Status* s, string&& name,
641 TF_DataType dtype, TF_Output* output) {
642 TF_OperationDescription* desc =
643 TF_NewOperation(graph, "Placeholder", name.data());
644 TF_SetAttrType(desc, "dtype", dtype);
645 TF_Operation* op = TF_FinishOperation(desc, s);
646 output->oper = op;
647 output->index = 0;
648 }
649
TF_CreatePlaceholders(TF_Graph * graph,PyObject * dtypes,const char * prefix,TF_Status * status)650 std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
651 const char* prefix,
652 TF_Status* status) {
653 std::vector<TF_Output> outputs;
654 dtypes = PySequence_Fast(dtypes, "dtypes must be a sequence");
655 if (dtypes == nullptr) {
656 Set_TF_Status_from_Status(status, errors::Internal("dtypes is nullptr"));
657 return outputs;
658 }
659 Safe_PyObjectPtr dtypes_holder(make_safe(dtypes));
660 Py_ssize_t len = PySequence_Fast_GET_SIZE(dtypes);
661 outputs.reserve(len);
662 for (size_t i = 0; i < len; i++) {
663 PyObject* dtype = PySequence_Fast_GET_ITEM(dtypes, i);
664 if (!dtype) {
665 Set_TF_Status_from_Status(status,
666 errors::Internal("Could not get dtype ", i));
667 return outputs;
668 }
669 #if PY_MAJOR_VERSION >= 3
670 TF_DataType tf_datatype = static_cast<TF_DataType>(PyLong_AsLong(dtype));
671 #else
672 TF_DataType tf_datatype = static_cast<TF_DataType>(PyInt_AsLong(dtype));
673 #endif
674 outputs.push_back(TF_Output());
675 CreatePlaceholder(graph, status, strings::StrCat(prefix, i), tf_datatype,
676 &outputs.back());
677 if (!status->status.ok()) break;
678 }
679 return outputs;
680 }
681
TF_GraphSetTensorShape_wrapper(TF_Graph * graph,TF_Output output,const std::vector<int64_t> & dims,bool unknown_shape,TF_Status * status)682 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
683 const std::vector<int64_t>& dims,
684 bool unknown_shape, TF_Status* status) {
685 if (unknown_shape) {
686 TF_GraphSetTensorShape(graph, output, nullptr, -1, status);
687 return;
688 }
689 TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
690 }
691
TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(TF_ImportGraphDefResults * results)692 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
693 TF_ImportGraphDefResults* results) {
694 int num_missing_unused_input_mappings;
695 const char** src_names;
696 int* src_indexes;
697 TF_ImportGraphDefResultsMissingUnusedInputMappings(
698 results, &num_missing_unused_input_mappings, &src_names, &src_indexes);
699 std::vector<string> input_strs(num_missing_unused_input_mappings);
700 for (int i = 0; i < num_missing_unused_input_mappings; ++i) {
701 input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString();
702 }
703 return input_strs;
704 }
705
// Attempts to constant-evaluate `output` and return its value as a numpy
// ndarray (a Python scalar for rank-0 results). Returns None both when the
// tensor could not be evaluated and on error; callers must inspect `status`
// to distinguish the two.
PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
                                         TF_Status* status) {
  TF_Tensor* result_tensor;
  bool evaluated =
      TF_TryEvaluateConstant(graph, output, &result_tensor, status);
  if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE;

  // Take ownership of the evaluated tensor so it is deleted on every path.
  Safe_TF_TensorPtr safe_result_tensor(result_tensor);
  PyObject* out;
  Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out);
  Set_TF_Status_from_Status(status, s);
  if (!s.ok()) Py_RETURN_NONE;
  // PyArray_Return converts rank-0 arrays to the equivalent Python scalar.
  return PyArray_Return(reinterpret_cast<PyArrayObject*>(out));
}
720
721 } // namespace tensorflow
722