• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/client/tf_session_helper.h"
17 
18 #include <cstring>
19 
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/log_memory.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/graph/tensor_id.h"
29 #include "tensorflow/core/lib/core/coding.h"
30 #include "tensorflow/core/lib/strings/stringprintf.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/equal_graph_def.h"
33 #include "tensorflow/python/client/session_ref.h"
34 #include "tensorflow/python/lib/core/ndarray_tensor.h"
35 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
36 #include "tensorflow/python/lib/core/safe_ptr.h"
37 
38 namespace tensorflow {
39 
namespace {

// Error message used whenever `feed_dict` is not a dict mapping string keys
// to NumPy arrays. (`static` would be redundant inside an unnamed namespace,
// so a plain definition is used.)
const char* kFeedDictErrorMsg =
    "feed_dict must be a dictionary mapping strings to NumPy arrays.";
}  // end namespace
45 
TF_NewSessionRef(TF_Graph * graph,const TF_SessionOptions * opts,TF_Status * status)46 TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
47                              TF_Status* status) {
48   TF_Session* tf_session = TF_NewSession(graph, opts, status);
49   if (tf_session == nullptr) {
50     return nullptr;
51   }
52 
53   Session* session = reinterpret_cast<Session*>(tf_session->session);
54   SessionRef* session_ref = new SessionRef(session);
55   tf_session->session = session_ref;
56   return tf_session;
57 }
58 
TF_Run_wrapper_helper(TF_DeprecatedSession * session,const char * handle,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)59 void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
60                            const TF_Buffer* run_options, PyObject* feed_dict,
61                            const NameVector& output_names,
62                            const NameVector& target_nodes,
63                            TF_Status* out_status, PyObjectVector* out_values,
64                            TF_Buffer* run_outputs) {
65   // 1. Convert the feed inputs to the appropriate form for TF_Run.
66   if (!PyDict_Check(feed_dict)) {
67     Set_TF_Status_from_Status(out_status,
68                               errors::InvalidArgument(kFeedDictErrorMsg));
69     return;
70   }
71 
72   NameVector input_names;
73   std::vector<Safe_TF_TensorPtr> inputs_safe;  // Used to delete tensors.
74   TF_TensorVector inputs_unsafe;  // Used to contain the arg to TF_Run.
75 
76   PyObject* key;
77   PyObject* value;
78   Py_ssize_t pos = 0;
79   int index = 0;
80   Status s;
81 
82   while (PyDict_Next(feed_dict, &pos, &key, &value)) {
83     char* key_string = PyBytes_AsString(key);
84     if (!key_string) {
85       Set_TF_Status_from_Status(out_status,
86                                 errors::InvalidArgument(kFeedDictErrorMsg));
87       return;
88     }
89     input_names.push_back(key_string);
90 
91     inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
92     s = PyArrayToTF_Tensor(value, &inputs_safe.back());
93     if (!s.ok()) {
94       Set_TF_Status_from_Status(out_status, s);
95       return;
96     }
97     inputs_unsafe.push_back(inputs_safe.back().get());
98     ++index;
99   }
100 
101   // 2. Allocate a container for the output data.
102   TF_TensorVector outputs(output_names.size());
103 
104   // In case any tensors were leftover from previous runs we might as well clear
105   // them here.
106   ClearDecrefCache();
107 
108   // 3. Actually call TF_Run().
109   Py_BEGIN_ALLOW_THREADS;
110   if (handle == nullptr) {
111     TF_Run(session, run_options, input_names.data(), inputs_unsafe.data(),
112            input_names.size(), const_cast<const char**>(output_names.data()),
113            outputs.data(), output_names.size(),
114            const_cast<const char**>(target_nodes.data()), target_nodes.size(),
115            run_outputs, out_status);
116   } else {
117     TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(),
118             input_names.size(), const_cast<const char**>(output_names.data()),
119             outputs.data(), output_names.size(),
120             const_cast<const char**>(target_nodes.data()), target_nodes.size(),
121             out_status);
122   }
123 
124   Py_END_ALLOW_THREADS;
125 
126   // Decref any numpy arrays we are not using anymore.
127   ClearDecrefCache();
128 
129   if (TF_GetCode(out_status) != TF_OK) {
130     return;
131   }
132 
133   // 4. We now own the fetched tensors, so set up a safe container to
134   // delete them when we exit this scope.
135   std::vector<Safe_TF_TensorPtr> tf_outputs_safe;
136   for (const auto& output : outputs) {
137     tf_outputs_safe.emplace_back(make_safe(output));
138   }
139 
140   // 5. Convert the fetched tensors into numpy ndarrays. Store them in a safe
141   // container so that we do not leak
142   std::vector<Safe_PyObjectPtr> py_outputs_safe;
143   for (size_t i = 0; i < output_names.size(); ++i) {
144     PyObject* py_array;
145     s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
146     if (!s.ok()) {
147       Set_TF_Status_from_Status(out_status, s);
148       return;
149     }
150     py_outputs_safe.emplace_back(make_safe(py_array));
151   }
152 
153   // 6. If we reach this point, we have successfully built a list of objects
154   // so we can release them from the safe container.
155   for (auto& output : py_outputs_safe) {
156     out_values->push_back(output.release());
157   }
158 }
159 
// Wrapper for TF_Run that converts the arguments to appropriate types.
// If *out_status is OK, the caller becomes the owner of the PyObjects
// in *out_values.
void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
                    PyObject* feed_dict, const NameVector& output_names,
                    const NameVector& target_nodes, TF_Status* out_status,
                    PyObjectVector* out_values, TF_Buffer* run_outputs) {
  // A null handle selects the full-run (TF_Run) path in the helper.
  TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names,
                        target_nodes, out_status, out_values, run_outputs);
  // Drop references to ndarrays queued for decref by tensor deallocators.
  ClearDecrefCache();
}
171 
172 namespace {
MakeCallableHelper(tensorflow::Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * out_status)173 void MakeCallableHelper(tensorflow::Session* session,
174                         const TF_Buffer* callable_options, int64_t* out_handle,
175                         TF_Status* out_status) {
176   tensorflow::CallableOptions callable_options_proto;
177   if (callable_options != nullptr &&
178       !callable_options_proto.ParseFromArray(callable_options->data,
179                                              callable_options->length)) {
180     Set_TF_Status_from_Status(
181         out_status,
182         errors::InvalidArgument("Unparseable CallableOptions proto"));
183     return;
184   }
185   tensorflow::Session::CallableHandle handle;
186   Status s = session->MakeCallable(callable_options_proto, &handle);
187   if (!s.ok()) {
188     Set_TF_Status_from_Status(out_status, s);
189     return;
190   }
191   *out_handle = handle;
192 }
193 }  // namespace
194 
// Creates a callable handle for the deprecated-session API; see
// MakeCallableHelper for semantics.
void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
                                      const TF_Buffer* callable_options,
                                      int64_t* out_handle,
                                      TF_Status* out_status) {
  MakeCallableHelper(session->session, callable_options, out_handle,
                     out_status);
}
// Creates a callable handle for the TF_Session API; see MakeCallableHelper
// for semantics.
void TF_SessionMakeCallable(TF_Session* session,
                            const TF_Buffer* callable_options,
                            int64_t* out_handle, TF_Status* out_status) {
  MakeCallableHelper(session->session, callable_options, out_handle,
                     out_status);
}
208 
namespace {
// Shared implementation of TF_DeprecatedSessionRunCallable and
// TF_SessionRunCallable.
//
// Converts `feed_values` (a Python sequence of ndarrays) into Tensors, runs
// the callable identified by `handle`, and appends the fetched results to
// `*out_values` as NumPy ndarrays owned by the caller. When `run_metadata`
// is non-null, the serialized RunMetadata of the run is written into it.
void RunCallableHelper(tensorflow::Session* session, int64_t handle,
                       PyObject* feed_values, TF_Status* out_status,
                       PyObjectVector* out_values, TF_Buffer* run_metadata) {
  // Convert feed values to a vector of tensorflow::Tensor objects.
  std::vector<Tensor> input_tensors;
  Status s;
  {
    feed_values =
        PySequence_Fast(feed_values, "feed_values must be a sequence");
    // PySequence_Fast returns nullptr with a Python exception set; in that
    // case the Python error (not *out_status) reports the failure.
    if (feed_values == nullptr) return;
    // Holder decrefs the new fast-sequence reference when this scope exits.
    Safe_PyObjectPtr feed_values_holder(make_safe(feed_values));
    Py_ssize_t len = PySequence_Fast_GET_SIZE(feed_values);
    input_tensors.reserve(len);
    for (Py_ssize_t i = 0; i < len; ++i) {
      // Borrowed reference; no decref needed.
      PyObject* elem = PySequence_Fast_GET_ITEM(feed_values, i);
      if (!elem) {
        Set_TF_Status_from_Status(
            out_status, errors::Internal("Could not get feed value ", i));
        return;
      }
      Tensor t;
      s = NdarrayToTensor(elem, &t);
      if (!s.ok()) {
        Set_TF_Status_from_Status(out_status, s);
        return;
      }
      input_tensors.push_back(std::move(t));
    }
  }

  // Allocate a RunMetadata protobuf object to receive the metadata,
  // if the caller is expecting any.
  std::unique_ptr<RunMetadata> run_metadata_proto;
  if (run_metadata != nullptr) {
    run_metadata_proto.reset(new RunMetadata);
  }

  // Run the callable. Release the GIL since the run may block.
  std::vector<Tensor> output_tensors;
  Py_BEGIN_ALLOW_THREADS;
  s = session->RunCallable(handle, input_tensors, &output_tensors,
                           run_metadata_proto.get());
  Py_END_ALLOW_THREADS;

  if (!s.ok()) {
    Set_TF_Status_from_Status(out_status, s);
    return;
  }

  // If requested, serialize the RunMetadata to pass it back to the caller.
  if (run_metadata != nullptr) {
    s = MessageToBuffer(*run_metadata_proto, run_metadata);
    if (!s.ok()) {
      Set_TF_Status_from_Status(out_status, s);
      return;
    }
  }

  // Convert results to NumPy arrays. Since this can fail, stage the
  // results via a safe container that takes care of decreasing the
  // reference count on failure.
  std::vector<Safe_PyObjectPtr> py_outputs_safe;
  py_outputs_safe.reserve(output_tensors.size());
  for (const Tensor& output : output_tensors) {
    PyObject* py_array;
    s = TensorToNdarray(output, &py_array);
    if (!s.ok()) {
      Set_TF_Status_from_Status(out_status, s);
      return;
    }
    py_outputs_safe.push_back(make_safe(py_array));
  }

  // If we reach this point, we have successfully built a list of objects
  // so we can release them from the safe container.
  out_values->reserve(py_outputs_safe.size());
  for (auto& output : py_outputs_safe) {
    out_values->push_back(output.release());
  }
}
}  // namespace
291 
// Runs a previously created callable on the deprecated-session API; see
// RunCallableHelper for semantics.
void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
                                     int64_t handle, PyObject* feed_values,
                                     TF_Status* out_status,
                                     PyObjectVector* out_values,
                                     TF_Buffer* run_metadata) {
  RunCallableHelper(session->session, handle, feed_values, out_status,
                    out_values, run_metadata);
  // Drop references to ndarrays queued for decref by tensor deallocators.
  ClearDecrefCache();
}
// Runs a previously created callable on the TF_Session API; see
// RunCallableHelper for semantics.
void TF_SessionRunCallable(TF_Session* session, int64_t handle,
                           PyObject* feed_values, TF_Status* out_status,
                           PyObjectVector* out_values,
                           TF_Buffer* run_metadata) {
  RunCallableHelper(session->session, handle, feed_values, out_status,
                    out_values, run_metadata);
  // Drop references to ndarrays queued for decref by tensor deallocators.
  ClearDecrefCache();
}
309 
// Releases the callable identified by `handle`, reporting the result via
// *out_status.
void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
                                         int64_t handle,
                                         TF_Status* out_status) {
  Set_TF_Status_from_Status(out_status,
                            session->session->ReleaseCallable(handle));
}
// Releases the callable identified by `handle`, reporting the result via
// *out_status.
void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
                               TF_Status* out_status) {
  Set_TF_Status_from_Status(out_status,
                            session->session->ReleaseCallable(handle));
}
321 
// Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
// If *out_status is OK, the caller becomes the owner of *out_handle.
void TF_PRunSetup_wrapper(TF_DeprecatedSession* session,
                          const NameVector& input_names,
                          const NameVector& output_names,
                          const NameVector& target_nodes, TF_Status* out_status,
                          const char** out_handle) {
  // Release the GIL: TF_PRunSetup may block while the setup runs.
  Py_BEGIN_ALLOW_THREADS;
  TF_PRunSetup(
      session, const_cast<const char**>(input_names.data()), input_names.size(),
      const_cast<const char**>(output_names.data()), output_names.size(),
      const_cast<const char**>(target_nodes.data()), target_nodes.size(),
      out_handle, out_status);
  Py_END_ALLOW_THREADS;
}
337 
// Wrapper for TF_PRun that converts the arguments to appropriate types.
// If *out_status is OK, the caller becomes the owner of the PyObjects
// in *out_values.
void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
                     PyObject* feed_dict, const NameVector& output_names,
                     TF_Status* out_status, PyObjectVector* out_values) {
  // A non-null handle selects the partial-run (TF_PRun) path in the helper;
  // partial runs take no run_options, target nodes, or run_outputs.
  TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names,
                        NameVector(), out_status, out_values, nullptr);
  // Drop references to ndarrays queued for decref by tensor deallocators.
  ClearDecrefCache();
}
348 
// Wrapper for TF_Reset that converts the string vectors to character arrays.
void TF_Reset_wrapper(const TF_SessionOptions* opt,
                      const NameVector& containers, TF_Status* out_status) {
  // const_cast only adjusts the element type (char* -> const char*);
  // nothing is mutated.
  TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
           out_status);
}
355 
TF_SessionRun_wrapper_helper(TF_Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)356 void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
357                                   const TF_Buffer* run_options,
358                                   const std::vector<TF_Output>& inputs,
359                                   const std::vector<PyObject*>& input_ndarrays,
360                                   const std::vector<TF_Output>& outputs,
361                                   const std::vector<TF_Operation*>& targets,
362                                   TF_Buffer* run_metadata,
363                                   TF_Status* out_status,
364                                   std::vector<PyObject*>* py_outputs) {
365   DCHECK_EQ(inputs.size(), input_ndarrays.size());
366   DCHECK(py_outputs != nullptr);
367   DCHECK(py_outputs->empty());
368   Status s;
369 
370   // Convert input ndarray PyObjects to TF_Tensors. We maintain a continuous
371   // array of TF_Tensor*s as well as scoped containers to make sure they're
372   // cleaned up properly.
373   //
374   // Memory management:
375   // PyArrayToTF_Tensor() creates a new ndarray PyObject from the input
376   // ndarray. We manage the new ndarray's lifetime in order to keep the
377   // underlying data buffer alive (the new ndarray also guarantees a contiguous
378   // data buffer). The new ndarray's data buffer is used to create the
379   // corresponding TF_Tensor. The TF_Tensor's deallocator will queue the new
380   // ndarray to be decref'd by the next ClearDecrefCache() call (we can't call
381   // Py_DECREF in the deallocator directly because the GIL must be held).
382   //
383   // Note that TF_Tensor may directly delegate its data and deallocator to a
384   // TensorBuffer, which may outlive the TF_Tensor (e.g. if the tensor gets
385   // queued or assigned to a variable).
386   TF_TensorVector input_vals;
387   std::vector<Safe_TF_TensorPtr> input_vals_safe;
388   for (PyObject* ndarray : input_ndarrays) {
389     input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
390     s = PyArrayToTF_Tensor(ndarray, &input_vals_safe.back());
391     if (!s.ok()) {
392       Set_TF_Status_from_Status(out_status, s);
393       return;
394     }
395     input_vals.push_back(input_vals_safe.back().get());
396   }
397 
398   // Allocate space for output TF_Tensor*s
399   TF_TensorVector output_vals(outputs.size());
400 
401   // Clear up any unused memory leftover from previous runs
402   ClearDecrefCache();
403 
404   // Call TF_SessionRun() (and release GIL during execution)
405   Py_BEGIN_ALLOW_THREADS;
406   if (handle == nullptr) {
407     TF_SessionRun(session, run_options, inputs.data(), input_vals.data(),
408                   inputs.size(), outputs.data(), output_vals.data(),
409                   outputs.size(), targets.data(), targets.size(), run_metadata,
410                   out_status);
411   } else {
412     TF_SessionPRun(session, handle, inputs.data(), input_vals.data(),
413                    inputs.size(), outputs.data(), output_vals.data(),
414                    outputs.size(), targets.data(), targets.size(), out_status);
415   }
416   Py_END_ALLOW_THREADS;
417 
418   // Create scoped containers for output tensors
419   std::vector<Safe_TF_TensorPtr> output_vals_safe;
420   for (TF_Tensor* output : output_vals) {
421     output_vals_safe.emplace_back(make_safe(output));
422   }
423 
424   // Convert outputs to ndarrays (in scoped containers)
425   std::vector<Safe_PyObjectPtr> py_outputs_safe;
426   for (size_t i = 0; i < outputs.size(); ++i) {
427     PyObject* py_array;
428     s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array);
429     if (!s.ok()) {
430       Set_TF_Status_from_Status(out_status, s);
431       return;
432     }
433     py_outputs_safe.emplace_back(make_safe(py_array));
434   }
435 
436   // If we reach this point, we have successfully built a list of objects so we
437   // can release them from the safe container into the return vector.
438   for (size_t i = 0; i < outputs.size(); ++i) {
439     py_outputs->push_back(py_outputs_safe[i].release());
440   }
441 }
442 
// Wrapper for TF_SessionRun that converts the arguments to appropriate
// types. If *out_status is OK, the caller becomes the owner of the PyObjects
// appended to *py_outputs.
void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
                           const std::vector<TF_Output>& inputs,
                           const std::vector<PyObject*>& input_ndarrays,
                           const std::vector<TF_Output>& outputs,
                           const std::vector<TF_Operation*>& targets,
                           TF_Buffer* run_metadata, TF_Status* out_status,
                           std::vector<PyObject*>* py_outputs) {
  // A null handle selects the full-run (TF_SessionRun) path in the helper.
  TF_SessionRun_wrapper_helper(session, nullptr, run_options, inputs,
                               input_ndarrays, outputs, targets, run_metadata,
                               out_status, py_outputs);
  // Release any unused ndarray references (see memory management comment in
  // TF_SessionRun_wrapper_helper)
  ClearDecrefCache();
}
457 
// Deserializes `actual` and `expected` as GraphDefs and compares them.
// Returns the empty string when they are equal, otherwise a human-readable
// description of the problem (a parse error or the first difference found).
string EqualGraphDefWrapper(const string& actual, const string& expected) {
  GraphDef parsed_actual;
  GraphDef parsed_expected;
  if (!parsed_actual.ParseFromString(actual)) {
    return "actual is not a valid serialized GraphDef";
  }
  if (!parsed_expected.ParseFromString(expected)) {
    return "expected is not a valid serialized GraphDef";
  }
  string diff;
  if (EqualGraphDef(parsed_actual, parsed_expected, &diff)) {
    return "";
  }
  return diff;
}
470 
// Deserializes `actual` and `expected` as AttrValues and compares them.
// Returns the empty string when they are equal, otherwise a human-readable
// description of the problem (a parse error or a summary of the mismatch).
string EqualAttrValueWrapper(const string& actual, const string& expected) {
  AttrValue parsed_actual;
  if (!parsed_actual.ParseFromString(actual)) {
    return "actual is not a valid serialized AttrValue";
  }

  AttrValue parsed_expected;
  if (!parsed_expected.ParseFromString(expected)) {
    return "expected is not a valid serialized AttrValue";
  }

  if (AreAttrValuesEqual(parsed_actual, parsed_expected)) {
    return string();
  }
  return strings::Printf(
      "Actual AttrValue %s does not match Expected AttrValue %s.",
      SummarizeAttrValue(parsed_actual).c_str(),
      SummarizeAttrValue(parsed_expected).c_str());
}
491 
492 // Return value set to 6 inlined elements so it fits in a 64-byte cache line.
TF_GraphGetTensorShapeHelper(TF_Graph * graph,TF_Output output,TF_Status * out_status,bool * unknown_shape)493 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
494     TF_Graph* graph, TF_Output output, TF_Status* out_status,
495     bool* unknown_shape) {
496   // Allocate a single variable for holding the result for RVO.
497   tensorflow::gtl::InlinedVector<int64_t, 6> result;
498   *unknown_shape = false;
499   int num_dims = TF_GraphGetTensorNumDims(graph, output, out_status);
500   if (TF_GetCode(out_status) != TF_OK) {
501     return result;
502   }
503   // If shape is unknown, set boolean and return.
504   if (num_dims == -1) {
505     *unknown_shape = true;
506     return result;
507   }
508 
509   // If shape is a scalar, avoid another C call and just return {}.
510   if (num_dims == 0) {
511     return result;
512   }
513 
514   result.resize(num_dims);
515   TF_GraphGetTensorShape(graph, output, result.data(), num_dims, out_status);
516   return result;
517 }
518 
// Wrapper for TF_SessionPRunSetup that converts the arguments to appropriate
// types. If *out_status is OK, the caller becomes the owner of *out_handle.
void TF_SessionPRunSetup_wrapper(TF_Session* session,
                                 const std::vector<TF_Output>& inputs,
                                 const std::vector<TF_Output>& outputs,
                                 const std::vector<TF_Operation*>& targets,
                                 const char** out_handle,
                                 TF_Status* out_status) {
  // Call TF_SessionPRunSetup() (and release GIL during execution)
  Py_BEGIN_ALLOW_THREADS;
  TF_SessionPRunSetup(session, inputs.data(), inputs.size(), outputs.data(),
                      outputs.size(), targets.data(), targets.size(),
                      out_handle, out_status);
  Py_END_ALLOW_THREADS;
}
532 
// Wrapper for TF_SessionPRun that converts the arguments to appropriate
// types. If *out_status is OK, the caller becomes the owner of the PyObjects
// appended to *py_outputs.
void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
                            const std::vector<TF_Output>& inputs,
                            const std::vector<PyObject*>& input_ndarrays,
                            const std::vector<TF_Output>& outputs,
                            TF_Status* out_status,
                            std::vector<PyObject*>* py_outputs) {
  // Partial runs take no target ops; pass an empty vector.
  const std::vector<TF_Operation*> targets;
  TF_SessionRun_wrapper_helper(session, handle,
                               nullptr,  // run_options
                               inputs, input_ndarrays, outputs, targets,
                               nullptr,  // run_metadata
                               out_status, py_outputs);
  // Release any unused ndarray references (see memory management comment in
  // TF_SessionRun_wrapper_helper)
  ClearDecrefCache();
}
549 
GetOperationInputs(TF_Operation * oper)550 std::vector<TF_Output> GetOperationInputs(TF_Operation* oper) {
551   int num_inputs = TF_OperationNumInputs(oper);
552   std::vector<TF_Output> inputs(num_inputs);
553   for (int i = 0; i < num_inputs; ++i) {
554     inputs[i] = TF_OperationInput({oper, i});
555   }
556   return inputs;
557 }
558 
TF_OperationGetControlInputs_wrapper(TF_Operation * oper)559 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
560     TF_Operation* oper) {
561   std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper));
562   TF_OperationGetControlInputs(oper, control_inputs.data(),
563                                control_inputs.size());
564   return control_inputs;
565 }
566 
TF_OperationGetControlOutputs_wrapper(TF_Operation * oper)567 std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
568     TF_Operation* oper) {
569   std::vector<TF_Operation*> control_outputs(
570       TF_OperationNumControlOutputs(oper));
571   TF_OperationGetControlOutputs(oper, control_outputs.data(),
572                                 control_outputs.size());
573   return control_outputs;
574 }
575 
TF_OperationOutputConsumers_wrapper(TF_Output oper_out)576 std::vector<const char*> TF_OperationOutputConsumers_wrapper(
577     TF_Output oper_out) {
578   int num_consumers = TF_OperationOutputNumConsumers(oper_out);
579   std::vector<TF_Input> consumers(num_consumers);
580   TF_OperationOutputConsumers(oper_out, consumers.data(), num_consumers);
581 
582   std::vector<const char*> consumer_names(num_consumers);
583   for (int i = 0; i < num_consumers; ++i) {
584     consumer_names[i] = TF_OperationName(consumers[i].oper);
585   }
586   return consumer_names;
587 }
588 
// Wrapper for TF_GraphToFunctionWithControlOutputs that converts the
// vector/pointer arguments into the raw arrays the C API expects. Returns
// the new TF_Function, or nullptr with *out_status set on failure.
// `opers == nullptr` means "use all operations in fn_body".
TF_Function* TF_GraphToFunction_wrapper(
    const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
    const std::vector<TF_Operation*>* opers,
    const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
    const NameVector& output_names,
    const std::vector<TF_Operation*>* control_outputs,
    const NameVector& control_output_names, const TF_FunctionOptions* opts,
    const char* description, TF_Status* out_status) {
  // Validate early: output names, when given, must pair 1:1 with outputs.
  if (!output_names.empty() && output_names.size() != outputs.size()) {
    Set_TF_Status_from_Status(
        out_status,
        errors::InvalidArgument(
            "output names must be either empty or equal in size to outputs. ",
            "output names size = ", output_names.size(),
            " outputs size = ", outputs.size()));
    return nullptr;
  }

  // nopers == -1 tells the C API to include every operation in fn_body.
  int nopers = -1;
  const TF_Operation* const* opers_array = nullptr;
  if (opers != nullptr) {
    nopers = opers->size();
    opers_array = opers->data();
  }

  // Null pointer means "no explicit names"; the C API picks defaults.
  const char** output_names_ptr =
      output_names.empty() ? nullptr
                           : const_cast<const char**>(output_names.data());

  const char** control_output_names_ptr =
      control_output_names.empty()
          ? nullptr
          : const_cast<const char**>(control_output_names.data());

  return TF_GraphToFunctionWithControlOutputs(
      fn_body, fn_name, append_hash_to_fn_name, nopers, opers_array,
      inputs.size(), inputs.data(), outputs.size(), outputs.data(),
      output_names_ptr,
      control_outputs == nullptr ? 0 : control_outputs->size(),
      control_outputs == nullptr ? nullptr : control_outputs->data(),
      control_output_names_ptr, opts, description, out_status);
}
631 
TF_GraphSetOutputHandleShapesAndTypes_wrapper(TF_Graph * graph,TF_Output output,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int> & ranks,const std::vector<TF_DataType> & types,TF_Status * status)632 void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
633     TF_Graph* graph, TF_Output output,
634     const std::vector<std::vector<int64_t>>& shapes,
635     const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
636     TF_Status* status) {
637   std::vector<const int64_t*> shapes_pointers(shapes.size());
638   for (int i = 0; i < shapes.size(); ++i) {
639     shapes_pointers[i] = ranks[i] <= 0 ? nullptr : &shapes[i][0];
640   }
641   TF_GraphSetOutputHandleShapesAndTypes(graph, output, shapes.size(),
642                                         shapes_pointers.data(), ranks.data(),
643                                         types.data(), status);
644 }
645 
TF_GraphSetTensorShape_wrapper(TF_Graph * graph,TF_Output output,const std::vector<int64_t> & dims,bool unknown_shape,TF_Status * status)646 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
647                                     const std::vector<int64_t>& dims,
648                                     bool unknown_shape, TF_Status* status) {
649   if (unknown_shape) {
650     TF_GraphSetTensorShape(graph, output, nullptr, -1, status);
651     return;
652   }
653   TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
654 }
655 
TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(TF_ImportGraphDefResults * results)656 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
657     TF_ImportGraphDefResults* results) {
658   int num_missing_unused_input_mappings;
659   const char** src_names;
660   int* src_indexes;
661   TF_ImportGraphDefResultsMissingUnusedInputMappings(
662       results, &num_missing_unused_input_mappings, &src_names, &src_indexes);
663   std::vector<string> input_strs(num_missing_unused_input_mappings);
664   for (int i = 0; i < num_missing_unused_input_mappings; ++i) {
665     input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString();
666   }
667   return input_strs;
668 }
669 
TF_TryEvaluateConstant_wrapper(TF_Graph * graph,TF_Output output,TF_Status * status)670 PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
671                                          TF_Status* status) {
672   TF_Tensor* result_tensor;
673   bool evaluated =
674       TF_TryEvaluateConstant(graph, output, &result_tensor, status);
675   if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE;
676 
677   Safe_TF_TensorPtr safe_result_tensor(result_tensor);
678   PyObject* out;
679   Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out);
680   Set_TF_Status_from_Status(status, s);
681   if (!s.ok()) Py_RETURN_NONE;
682   return out;
683 }
684 
685 }  // namespace tensorflow
686