1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/client/tf_session_helper.h"
17
18 #include <cstring>
19
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/c/tf_buffer_internal.h"
23 #include "tensorflow/c/tf_status_helper.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/attr_value_util.h"
27 #include "tensorflow/core/framework/log_memory.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/graph/tensor_id.h"
30 #include "tensorflow/core/lib/core/coding.h"
31 #include "tensorflow/core/lib/strings/stringprintf.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/util/equal_graph_def.h"
34 #include "tensorflow/python/client/session_ref.h"
35 #include "tensorflow/python/lib/core/ndarray_tensor.h"
36 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
37 #include "tensorflow/python/lib/core/safe_ptr.h"
38
39 namespace tensorflow {
40
namespace {

// Error reported by TF_Run_wrapper_helper when `feed_dict` is not a dict or
// one of its keys cannot be read as a bytes string. A constexpr array (rather
// than a mutable `static const char*`) makes the constant truly immutable;
// `static` is redundant inside an anonymous namespace.
constexpr char kFeedDictErrorMsg[] =
    "feed_dict must be a dictionary mapping strings to NumPy arrays.";

}  // end namespace
46
TF_NewSessionRef(TF_Graph * graph,const TF_SessionOptions * opts,TF_Status * status)47 TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
48 TF_Status* status) {
49 TF_Session* tf_session = TF_NewSession(graph, opts, status);
50 if (tf_session == nullptr) {
51 return nullptr;
52 }
53
54 Session* session = reinterpret_cast<Session*>(tf_session->session);
55 SessionRef* session_ref = new SessionRef(session);
56 tf_session->session = session_ref;
57 return tf_session;
58 }
59
TF_Run_wrapper_helper(TF_DeprecatedSession * session,const char * handle,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)60 void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
61 const TF_Buffer* run_options, PyObject* feed_dict,
62 const NameVector& output_names,
63 const NameVector& target_nodes,
64 TF_Status* out_status, PyObjectVector* out_values,
65 TF_Buffer* run_outputs) {
66 // 1. Convert the feed inputs to the appropriate form for TF_Run.
67 if (!PyDict_Check(feed_dict)) {
68 Set_TF_Status_from_Status(out_status,
69 errors::InvalidArgument(kFeedDictErrorMsg));
70 return;
71 }
72
73 NameVector input_names;
74 std::vector<Safe_TF_TensorPtr> inputs_safe; // Used to delete tensors.
75 TF_TensorVector inputs_unsafe; // Used to contain the arg to TF_Run.
76
77 PyObject* key;
78 PyObject* value;
79 Py_ssize_t pos = 0;
80 int index = 0;
81 Status s;
82
83 while (PyDict_Next(feed_dict, &pos, &key, &value)) {
84 char* key_string = PyBytes_AsString(key);
85 if (!key_string) {
86 Set_TF_Status_from_Status(out_status,
87 errors::InvalidArgument(kFeedDictErrorMsg));
88 return;
89 }
90 input_names.push_back(key_string);
91
92 inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
93 s = NdarrayToTensor(nullptr /*ctx*/, value, &inputs_safe.back());
94 if (!s.ok()) {
95 Set_TF_Status_from_Status(out_status, s);
96 return;
97 }
98 inputs_unsafe.push_back(inputs_safe.back().get());
99 ++index;
100 }
101
102 // 2. Allocate a container for the output data.
103 TF_TensorVector outputs(output_names.size());
104
105 // In case any tensors were leftover from previous runs we might as well clear
106 // them here.
107 ClearDecrefCache();
108
109 // 3. Actually call TF_Run().
110 Py_BEGIN_ALLOW_THREADS;
111 if (handle == nullptr) {
112 TF_Run(session, run_options, input_names.data(), inputs_unsafe.data(),
113 input_names.size(), const_cast<const char**>(output_names.data()),
114 outputs.data(), output_names.size(),
115 const_cast<const char**>(target_nodes.data()), target_nodes.size(),
116 run_outputs, out_status);
117 } else {
118 TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(),
119 input_names.size(), const_cast<const char**>(output_names.data()),
120 outputs.data(), output_names.size(),
121 const_cast<const char**>(target_nodes.data()), target_nodes.size(),
122 out_status);
123 }
124
125 Py_END_ALLOW_THREADS;
126
127 // Decref any numpy arrays we are not using anymore.
128 ClearDecrefCache();
129
130 if (TF_GetCode(out_status) != TF_OK) {
131 return;
132 }
133
134 // 4. We now own the fetched tensors, so set up a safe container to
135 // delete them when we exit this scope.
136 std::vector<Safe_TF_TensorPtr> tf_outputs_safe;
137 for (const auto& output : outputs) {
138 tf_outputs_safe.emplace_back(make_safe(output));
139 }
140
141 // 5. Convert the fetched tensors into numpy ndarrays. Store them in a safe
142 // container so that we do not leak
143 std::vector<Safe_PyObjectPtr> py_outputs_safe;
144 for (size_t i = 0; i < output_names.size(); ++i) {
145 PyObject* py_array;
146 s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
147 if (!s.ok()) {
148 Set_TF_Status_from_Status(out_status, s);
149 return;
150 }
151 py_outputs_safe.emplace_back(
152 make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
153 }
154
155 // 6. If we reach this point, we have successfully built a list of objects
156 // so we can release them from the safe container.
157 for (auto& output : py_outputs_safe) {
158 out_values->push_back(output.release());
159 }
160 }
161
162 // Wrapper for TF_Run that converts the arguments to appropriate types.
163 // If *out_status is OK, the caller becomes the owner of the PyObjects
164 // in *out_values.
TF_Run_wrapper(TF_DeprecatedSession * session,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)165 void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
166 PyObject* feed_dict, const NameVector& output_names,
167 const NameVector& target_nodes, TF_Status* out_status,
168 PyObjectVector* out_values, TF_Buffer* run_outputs) {
169 TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names,
170 target_nodes, out_status, out_values, run_outputs);
171 ClearDecrefCache();
172 }
173
174 namespace {
MakeCallableHelper(tensorflow::Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * out_status)175 void MakeCallableHelper(tensorflow::Session* session,
176 const TF_Buffer* callable_options, int64_t* out_handle,
177 TF_Status* out_status) {
178 tensorflow::CallableOptions callable_options_proto;
179 if (callable_options != nullptr &&
180 !callable_options_proto.ParseFromArray(callable_options->data,
181 callable_options->length)) {
182 Set_TF_Status_from_Status(
183 out_status,
184 errors::InvalidArgument("Unparseable CallableOptions proto"));
185 return;
186 }
187 tensorflow::Session::CallableHandle handle;
188 Status s = session->MakeCallable(callable_options_proto, &handle);
189 if (!s.ok()) {
190 Set_TF_Status_from_Status(out_status, s);
191 return;
192 }
193 *out_handle = handle;
194 }
195 } // namespace
196
TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * status)197 void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
198 const TF_Buffer* callable_options,
199 int64_t* out_handle, TF_Status* status) {
200 MakeCallableHelper(session->session, callable_options, out_handle, status);
201 }
TF_SessionMakeCallable(TF_Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * status)202 void TF_SessionMakeCallable(TF_Session* session,
203 const TF_Buffer* callable_options,
204 int64_t* out_handle, TF_Status* status) {
205 MakeCallableHelper(session->session, callable_options, out_handle, status);
206 }
207
208 namespace {
RunCallableHelper(tensorflow::Session * session,int64_t handle,PyObject * feed_values,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_metadata)209 void RunCallableHelper(tensorflow::Session* session, int64_t handle,
210 PyObject* feed_values, TF_Status* out_status,
211 PyObjectVector* out_values, TF_Buffer* run_metadata) {
212 // Convert feed values to a vector of tensorflow::Tensor objects.
213 std::vector<Tensor> input_tensors;
214 Status s;
215 {
216 feed_values =
217 PySequence_Fast(feed_values, "feed_values must be a sequence");
218 if (feed_values == nullptr) return;
219 Safe_PyObjectPtr feed_values_holder(make_safe(feed_values));
220 Py_ssize_t len = PySequence_Fast_GET_SIZE(feed_values);
221 input_tensors.reserve(len);
222 for (Py_ssize_t i = 0; i < len; ++i) {
223 PyObject* elem = PySequence_Fast_GET_ITEM(feed_values, i);
224 if (!elem) {
225 Set_TF_Status_from_Status(
226 out_status, errors::Internal("Could not get feed value ", i));
227 return;
228 }
229 Tensor t;
230 s = NdarrayToTensor(elem, &t);
231 if (!s.ok()) {
232 Set_TF_Status_from_Status(out_status, s);
233 return;
234 }
235 input_tensors.push_back(std::move(t));
236 }
237 }
238
239 RunMetadata run_metadata_proto;
240
241 // Run the callable.
242 std::vector<Tensor> output_tensors;
243 Py_BEGIN_ALLOW_THREADS;
244 s = session->RunCallable(handle, input_tensors, &output_tensors,
245 &run_metadata_proto);
246 Py_END_ALLOW_THREADS;
247
248 if (!s.ok()) {
249 Set_TF_Status_from_Status(out_status, s);
250 return;
251 }
252
253 // If requested, serialize the RunMetadata to pass it back to the caller.
254 if (run_metadata != nullptr) {
255 s = MessageToBuffer(run_metadata_proto, run_metadata);
256 if (!s.ok()) {
257 Set_TF_Status_from_Status(out_status, s);
258 return;
259 }
260 }
261
262 // Convert results to NumPy arrays. Since this can fail, stage the
263 // results via a safe container that takes care of decreasing the
264 // reference count on failure.
265 std::vector<Safe_PyObjectPtr> py_outputs_safe;
266 py_outputs_safe.reserve(output_tensors.size());
267 for (const Tensor& output : output_tensors) {
268 PyObject* py_array;
269 s = TensorToNdarray(output, &py_array);
270 if (!s.ok()) {
271 Set_TF_Status_from_Status(out_status, s);
272 return;
273 }
274 py_outputs_safe.push_back(
275 make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
276 }
277
278 // If we reach this point, we have successfully built a list of objects
279 // so we can release them from the safe container.
280 out_values->reserve(py_outputs_safe.size());
281 for (auto& output : py_outputs_safe) {
282 out_values->push_back(output.release());
283 }
284 }
285 } // namespace
286
TF_DeprecatedSessionRunCallable(TF_DeprecatedSession * session,int64_t handle,PyObject * feed_values,PyObjectVector * out_values,TF_Buffer * run_metadata,TF_Status * status)287 void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
288 int64_t handle, PyObject* feed_values,
289 PyObjectVector* out_values,
290 TF_Buffer* run_metadata,
291 TF_Status* status) {
292 RunCallableHelper(session->session, handle, feed_values, status, out_values,
293 run_metadata);
294 ClearDecrefCache();
295 }
TF_SessionRunCallable(TF_Session * session,int64_t handle,PyObject * feed_values,PyObjectVector * out_values,TF_Buffer * run_metadata,TF_Status * status)296 void TF_SessionRunCallable(TF_Session* session, int64_t handle,
297 PyObject* feed_values, PyObjectVector* out_values,
298 TF_Buffer* run_metadata, TF_Status* status) {
299 RunCallableHelper(session->session, handle, feed_values, status, out_values,
300 run_metadata);
301 ClearDecrefCache();
302 }
303
TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession * session,int64_t handle,TF_Status * status)304 void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
305 int64_t handle, TF_Status* status) {
306 Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
307 }
TF_SessionReleaseCallable(TF_Session * session,int64_t handle,TF_Status * status)308 void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
309 TF_Status* status) {
310 Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
311 }
312
313 // Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
314 // If *out_status is OK, the caller becomes the owner of *out_handle.
TF_PRunSetup_wrapper(TF_DeprecatedSession * session,const NameVector & input_names,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,const char ** out_handle)315 void TF_PRunSetup_wrapper(TF_DeprecatedSession* session,
316 const NameVector& input_names,
317 const NameVector& output_names,
318 const NameVector& target_nodes, TF_Status* out_status,
319 const char** out_handle) {
320 Py_BEGIN_ALLOW_THREADS;
321 TF_PRunSetup(
322 session, const_cast<const char**>(input_names.data()), input_names.size(),
323 const_cast<const char**>(output_names.data()), output_names.size(),
324 const_cast<const char**>(target_nodes.data()), target_nodes.size(),
325 out_handle, out_status);
326 Py_END_ALLOW_THREADS;
327 }
328
329 // Wrapper for TF_PRun that converts the arguments to appropriate types.
330 // If *out_status is OK, the caller becomes the owner of the PyObjects
331 // in *out_values.
TF_PRun_wrapper(TF_DeprecatedSession * session,const char * handle,PyObject * feed_dict,const NameVector & output_names,TF_Status * out_status,PyObjectVector * out_values)332 void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
333 PyObject* feed_dict, const NameVector& output_names,
334 TF_Status* out_status, PyObjectVector* out_values) {
335 TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names,
336 NameVector(), out_status, out_values, nullptr);
337 ClearDecrefCache();
338 }
339
340 // Wrapper for TF_Reset that converts the string vectors to character arrays.
TF_Reset_wrapper(const TF_SessionOptions * opt,const NameVector & containers,TF_Status * status)341 void TF_Reset_wrapper(const TF_SessionOptions* opt,
342 const NameVector& containers, TF_Status* status) {
343 TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
344 status);
345 }
346
TF_SessionRun_wrapper_helper(TF_Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)347 void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
348 const TF_Buffer* run_options,
349 const std::vector<TF_Output>& inputs,
350 const std::vector<PyObject*>& input_ndarrays,
351 const std::vector<TF_Output>& outputs,
352 const std::vector<TF_Operation*>& targets,
353 TF_Buffer* run_metadata,
354 TF_Status* out_status,
355 std::vector<PyObject*>* py_outputs) {
356 DCHECK_EQ(inputs.size(), input_ndarrays.size());
357 DCHECK(py_outputs != nullptr);
358 DCHECK(py_outputs->empty());
359 Status s;
360
361 // Convert input ndarray PyObjects to TF_Tensors. We maintain a continuous
362 // array of TF_Tensor*s as well as scoped containers to make sure they're
363 // cleaned up properly.
364 //
365 // Memory management:
366 // NdarrayToTensor() creates a new ndarray PyObject from the input
367 // ndarray. We manage the new ndarray's lifetime in order to keep the
368 // underlying data buffer alive (the new ndarray also guarantees a contiguous
369 // data buffer). The new ndarray's data buffer is used to create the
370 // corresponding TF_Tensor. The TF_Tensor's deallocator will queue the new
371 // ndarray to be decref'd by the next ClearDecrefCache() call (we can't call
372 // Py_DECREF in the deallocator directly because the GIL must be held).
373 //
374 // Note that TF_Tensor may directly delegate its data and deallocator to a
375 // TensorBuffer, which may outlive the TF_Tensor (e.g. if the tensor gets
376 // queued or assigned to a variable).
377 TF_TensorVector input_vals;
378 std::vector<Safe_TF_TensorPtr> input_vals_safe;
379 for (PyObject* ndarray : input_ndarrays) {
380 input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
381 s = NdarrayToTensor(nullptr, ndarray, &input_vals_safe.back());
382 if (!s.ok()) {
383 Set_TF_Status_from_Status(out_status, s);
384 return;
385 }
386 input_vals.push_back(input_vals_safe.back().get());
387 }
388
389 // Allocate space for output TF_Tensor*s
390 TF_TensorVector output_vals(outputs.size());
391
392 // Clear up any unused memory leftover from previous runs
393 ClearDecrefCache();
394
395 // Call TF_SessionRun() (and release GIL during execution)
396 Py_BEGIN_ALLOW_THREADS;
397 if (handle == nullptr) {
398 TF_SessionRun(session, run_options, inputs.data(), input_vals.data(),
399 inputs.size(), outputs.data(), output_vals.data(),
400 outputs.size(), targets.data(), targets.size(), run_metadata,
401 out_status);
402 } else {
403 TF_SessionPRun(session, handle, inputs.data(), input_vals.data(),
404 inputs.size(), outputs.data(), output_vals.data(),
405 outputs.size(), targets.data(), targets.size(), out_status);
406 }
407 Py_END_ALLOW_THREADS;
408
409 // Create scoped containers for output tensors
410 std::vector<Safe_TF_TensorPtr> output_vals_safe;
411 for (TF_Tensor* output : output_vals) {
412 output_vals_safe.emplace_back(make_safe(output));
413 }
414
415 // Convert outputs to ndarrays (in scoped containers)
416 std::vector<Safe_PyObjectPtr> py_outputs_safe;
417 for (size_t i = 0; i < outputs.size(); ++i) {
418 PyObject* py_array;
419 s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array);
420 if (!s.ok()) {
421 Set_TF_Status_from_Status(out_status, s);
422 return;
423 }
424 py_outputs_safe.emplace_back(
425 make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
426 }
427
428 // If we reach this point, we have successfully built a list of objects so we
429 // can release them from the safe container into the return vector.
430 for (size_t i = 0; i < outputs.size(); ++i) {
431 py_outputs->push_back(py_outputs_safe[i].release());
432 }
433 }
434
TF_SessionRun_wrapper(TF_Session * session,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)435 void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
436 const std::vector<TF_Output>& inputs,
437 const std::vector<PyObject*>& input_ndarrays,
438 const std::vector<TF_Output>& outputs,
439 const std::vector<TF_Operation*>& targets,
440 TF_Buffer* run_metadata, TF_Status* out_status,
441 std::vector<PyObject*>* py_outputs) {
442 TF_SessionRun_wrapper_helper(session, nullptr, run_options, inputs,
443 input_ndarrays, outputs, targets, run_metadata,
444 out_status, py_outputs);
445 // Release any unused ndarray references (see memory management comment in
446 // TF_SessionRun_wrapper_helper)
447 ClearDecrefCache();
448 }
449
// Compares two serialized GraphDefs with EqualGraphDef.
// Returns "" when they are equal and EqualGraphDef's diff text otherwise.
// If either argument fails to parse, returns a fixed message naming it
// (`actual` is checked first).
string EqualGraphDefWrapper(const string& actual, const string& expected) {
  GraphDef actual_def;
  GraphDef expected_def;
  if (!actual_def.ParseFromString(actual)) {
    return "actual is not a valid serialized GraphDef";
  }
  if (!expected_def.ParseFromString(expected)) {
    return "expected is not a valid serialized GraphDef";
  }
  string diff;
  if (EqualGraphDef(actual_def, expected_def, &diff)) {
    return "";
  }
  return diff;
}
462
// Compares two serialized AttrValues with AreAttrValuesEqual.
// Returns "" when they are equal, or a message embedding both values'
// summaries when they differ. If either argument fails to parse, returns a
// fixed message naming it (`actual` is checked first).
string EqualAttrValueWrapper(const string& actual, const string& expected) {
  AttrValue actual_attr_value;
  if (!actual_attr_value.ParseFromString(actual)) {
    return "actual is not a valid serialized AttrValue";
  }

  AttrValue expected_attr_value;
  if (!expected_attr_value.ParseFromString(expected)) {
    return "expected is not a valid serialized AttrValue";
  }

  if (AreAttrValuesEqual(actual_attr_value, expected_attr_value)) {
    return "";
  }
  return strings::Printf(
      "Actual AttrValue %s does not match Expected AttrValue %s.",
      SummarizeAttrValue(actual_attr_value).c_str(),
      SummarizeAttrValue(expected_attr_value).c_str());
}
483
484 // Return value set to 6 inlined elements so it fits in a 64-byte cache line.
TF_GraphGetTensorShapeHelper(TF_Graph * graph,TF_Output output,TF_Status * out_status,bool * unknown_shape)485 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
486 TF_Graph* graph, TF_Output output, TF_Status* out_status,
487 bool* unknown_shape) {
488 // Allocate a single variable for holding the result for RVO.
489 tensorflow::gtl::InlinedVector<int64_t, 6> result;
490 *unknown_shape = false;
491 int num_dims = TF_GraphGetTensorNumDims(graph, output, out_status);
492 if (TF_GetCode(out_status) != TF_OK) {
493 return result;
494 }
495 // If shape is unknown, set boolean and return.
496 if (num_dims == -1) {
497 *unknown_shape = true;
498 return result;
499 }
500
501 // If shape is a scalar, avoid another C call and just return {}.
502 if (num_dims == 0) {
503 return result;
504 }
505
506 result.resize(num_dims);
507 TF_GraphGetTensorShape(graph, output, result.data(), num_dims, out_status);
508 return result;
509 }
510
TF_SessionPRunSetup_wrapper(TF_Session * session,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,const char ** out_handle,TF_Status * out_status)511 void TF_SessionPRunSetup_wrapper(TF_Session* session,
512 const std::vector<TF_Output>& inputs,
513 const std::vector<TF_Output>& outputs,
514 const std::vector<TF_Operation*>& targets,
515 const char** out_handle,
516 TF_Status* out_status) {
517 // Call TF_SessionPRunSetup() (and release GIL during execution)
518 Py_BEGIN_ALLOW_THREADS;
519 TF_SessionPRunSetup(session, inputs.data(), inputs.size(), outputs.data(),
520 outputs.size(), targets.data(), targets.size(),
521 out_handle, out_status);
522 Py_END_ALLOW_THREADS;
523 }
524
TF_SessionPRun_wrapper(TF_Session * session,const char * handle,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,TF_Status * out_status,std::vector<PyObject * > * py_outputs)525 void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
526 const std::vector<TF_Output>& inputs,
527 const std::vector<PyObject*>& input_ndarrays,
528 const std::vector<TF_Output>& outputs,
529 TF_Status* out_status,
530 std::vector<PyObject*>* py_outputs) {
531 const std::vector<TF_Operation*> targets;
532 TF_SessionRun_wrapper_helper(session, handle,
533 nullptr, // run_options
534 inputs, input_ndarrays, outputs, targets,
535 nullptr, // run_metadata
536 out_status, py_outputs);
537 // Release any unused ndarray references (see memory management comment in
538 // TF_SessionRun_wrapper_helper)
539 ClearDecrefCache();
540 }
541
GetOperationInputs(TF_Operation * oper)542 std::vector<TF_Output> GetOperationInputs(TF_Operation* oper) {
543 int num_inputs = TF_OperationNumInputs(oper);
544 std::vector<TF_Output> inputs(num_inputs);
545 TF_OperationAllInputs(oper, inputs.data(), inputs.size());
546 return inputs;
547 }
548
TF_OperationGetControlInputs_wrapper(TF_Operation * oper)549 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
550 TF_Operation* oper) {
551 std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper));
552 TF_OperationGetControlInputs(oper, control_inputs.data(),
553 control_inputs.size());
554 return control_inputs;
555 }
556
TF_OperationGetControlOutputs_wrapper(TF_Operation * oper)557 std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
558 TF_Operation* oper) {
559 std::vector<TF_Operation*> control_outputs(
560 TF_OperationNumControlOutputs(oper));
561 TF_OperationGetControlOutputs(oper, control_outputs.data(),
562 control_outputs.size());
563 return control_outputs;
564 }
565
TF_OperationOutputConsumers_wrapper(TF_Output oper_out)566 std::vector<const char*> TF_OperationOutputConsumers_wrapper(
567 TF_Output oper_out) {
568 int num_consumers = TF_OperationOutputNumConsumers(oper_out);
569 std::vector<TF_Input> consumers(num_consumers);
570 TF_OperationOutputConsumers(oper_out, consumers.data(), num_consumers);
571
572 std::vector<const char*> consumer_names(num_consumers);
573 for (int i = 0; i < num_consumers; ++i) {
574 consumer_names[i] = TF_OperationName(consumers[i].oper);
575 }
576 return consumer_names;
577 }
578
TF_GraphToFunction_wrapper(const TF_Graph * fn_body,const char * fn_name,bool append_hash_to_fn_name,const std::vector<TF_Operation * > * opers,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const NameVector & output_names,const std::vector<TF_Operation * > * control_outputs,const NameVector & control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * out_status)579 TF_Function* TF_GraphToFunction_wrapper(
580 const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
581 const std::vector<TF_Operation*>* opers,
582 const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
583 const NameVector& output_names,
584 const std::vector<TF_Operation*>* control_outputs,
585 const NameVector& control_output_names, const TF_FunctionOptions* opts,
586 const char* description, TF_Status* out_status) {
587 if (!output_names.empty() && output_names.size() != outputs.size()) {
588 Set_TF_Status_from_Status(
589 out_status,
590 errors::InvalidArgument(
591 "output names must be either empty or equal in size to outputs. ",
592 "output names size = ", output_names.size(),
593 " outputs size = ", outputs.size()));
594 return nullptr;
595 }
596
597 int nopers = -1;
598 const TF_Operation* const* opers_array = nullptr;
599 if (opers != nullptr) {
600 nopers = opers->size();
601 opers_array = opers->data();
602 }
603
604 const char** output_names_ptr =
605 output_names.empty() ? nullptr
606 : const_cast<const char**>(output_names.data());
607
608 const char** control_output_names_ptr =
609 control_output_names.empty()
610 ? nullptr
611 : const_cast<const char**>(control_output_names.data());
612
613 return TF_GraphToFunctionWithControlOutputs(
614 fn_body, fn_name, append_hash_to_fn_name, nopers, opers_array,
615 inputs.size(), inputs.data(), outputs.size(), outputs.data(),
616 output_names_ptr,
617 control_outputs == nullptr ? 0 : control_outputs->size(),
618 control_outputs == nullptr ? nullptr : control_outputs->data(),
619 control_output_names_ptr, opts, description, out_status);
620 }
621
TF_GraphSetOutputHandleShapesAndTypes_wrapper(TF_Graph * graph,TF_Output output,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int> & ranks,const std::vector<TF_DataType> & types,TF_Status * status)622 void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
623 TF_Graph* graph, TF_Output output,
624 const std::vector<std::vector<int64_t>>& shapes,
625 const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
626 TF_Status* status) {
627 std::vector<const int64_t*> shapes_pointers(shapes.size());
628 for (int i = 0; i < shapes.size(); ++i) {
629 shapes_pointers[i] = ranks[i] <= 0 ? nullptr : &shapes[i][0];
630 }
631 TF_GraphSetOutputHandleShapesAndTypes(graph, output, shapes.size(),
632 shapes_pointers.data(), ranks.data(),
633 types.data(), status);
634 }
635
CreatePlaceholder(TF_Graph * graph,TF_Status * s,string && name,TF_DataType dtype,TF_Output * output)636 void CreatePlaceholder(TF_Graph* graph, TF_Status* s, string&& name,
637 TF_DataType dtype, TF_Output* output) {
638 TF_OperationDescription* desc =
639 TF_NewOperation(graph, "Placeholder", name.data());
640 TF_SetAttrType(desc, "dtype", dtype);
641 TF_Operation* op = TF_FinishOperation(desc, s);
642 output->oper = op;
643 output->index = 0;
644 }
645
TF_CreatePlaceholders(TF_Graph * graph,PyObject * dtypes,const char * prefix,TF_Status * status)646 std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
647 const char* prefix,
648 TF_Status* status) {
649 std::vector<TF_Output> outputs;
650 dtypes = PySequence_Fast(dtypes, "dtypes must be a sequence");
651 if (dtypes == nullptr) {
652 Set_TF_Status_from_Status(status, errors::Internal("dtypes is nullptr"));
653 return outputs;
654 }
655 Safe_PyObjectPtr dtypes_holder(make_safe(dtypes));
656 Py_ssize_t len = PySequence_Fast_GET_SIZE(dtypes);
657 outputs.reserve(len);
658 for (size_t i = 0; i < len; i++) {
659 PyObject* dtype = PySequence_Fast_GET_ITEM(dtypes, i);
660 if (!dtype) {
661 Set_TF_Status_from_Status(status,
662 errors::Internal("Could not get dtype ", i));
663 return outputs;
664 }
665 #if PY_MAJOR_VERSION >= 3
666 TF_DataType tf_datatype = static_cast<TF_DataType>(PyLong_AsLong(dtype));
667 #else
668 TF_DataType tf_datatype = static_cast<TF_DataType>(PyInt_AsLong(dtype));
669 #endif
670 outputs.push_back(TF_Output());
671 CreatePlaceholder(graph, status, strings::StrCat(prefix, i), tf_datatype,
672 &outputs.back());
673 if (!status->status.ok()) break;
674 }
675 return outputs;
676 }
677
TF_GraphSetTensorShape_wrapper(TF_Graph * graph,TF_Output output,const std::vector<int64_t> & dims,bool unknown_shape,TF_Status * status)678 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
679 const std::vector<int64_t>& dims,
680 bool unknown_shape, TF_Status* status) {
681 if (unknown_shape) {
682 TF_GraphSetTensorShape(graph, output, nullptr, -1, status);
683 return;
684 }
685 TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
686 }
687
TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(TF_ImportGraphDefResults * results)688 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
689 TF_ImportGraphDefResults* results) {
690 int num_missing_unused_input_mappings;
691 const char** src_names;
692 int* src_indexes;
693 TF_ImportGraphDefResultsMissingUnusedInputMappings(
694 results, &num_missing_unused_input_mappings, &src_names, &src_indexes);
695 std::vector<string> input_strs(num_missing_unused_input_mappings);
696 for (int i = 0; i < num_missing_unused_input_mappings; ++i) {
697 input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString();
698 }
699 return input_strs;
700 }
701
TF_TryEvaluateConstant_wrapper(TF_Graph * graph,TF_Output output,TF_Status * status)702 PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
703 TF_Status* status) {
704 TF_Tensor* result_tensor;
705 bool evaluated =
706 TF_TryEvaluateConstant(graph, output, &result_tensor, status);
707 if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE;
708
709 Safe_TF_TensorPtr safe_result_tensor(result_tensor);
710 PyObject* out;
711 Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out);
712 Set_TF_Status_from_Status(status, s);
713 if (!s.ok()) Py_RETURN_NONE;
714 return PyArray_Return(reinterpret_cast<PyArrayObject*>(out));
715 }
716
717 } // namespace tensorflow
718