• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/client/tf_session_helper.h"
17 
18 #include <cstring>
19 
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/log_memory.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/graph/tensor_id.h"
29 #include "tensorflow/core/lib/core/coding.h"
30 #include "tensorflow/core/lib/strings/stringprintf.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/equal_graph_def.h"
33 #include "tensorflow/python/client/session_ref.h"
34 #include "tensorflow/python/lib/core/ndarray_tensor.h"
35 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
36 #include "tensorflow/python/lib/core/safe_ptr.h"
37 
38 namespace tensorflow {
39 
namespace {

// Error reported when a feed_dict is not a Python dict, or when one of its
// keys cannot be read with PyBytes_AsString. Declared as a constexpr array so
// the message (and not just the characters it points to) is immutable.
constexpr char kFeedDictErrorMsg[] =
    "feed_dict must be a dictionary mapping strings to NumPy arrays.";

}  // namespace
45 
TF_NewSessionRef(TF_Graph * graph,const TF_SessionOptions * opts,TF_Status * status)46 TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
47                              TF_Status* status) {
48   TF_Session* tf_session = TF_NewSession(graph, opts, status);
49   if (tf_session == nullptr) {
50     return nullptr;
51   }
52 
53   Session* session = reinterpret_cast<Session*>(tf_session->session);
54   SessionRef* session_ref = new SessionRef(session);
55   tf_session->session = session_ref;
56   return tf_session;
57 }
58 
TF_Run_wrapper_helper(TF_DeprecatedSession * session,const char * handle,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)59 void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
60                            const TF_Buffer* run_options, PyObject* feed_dict,
61                            const NameVector& output_names,
62                            const NameVector& target_nodes,
63                            TF_Status* out_status, PyObjectVector* out_values,
64                            TF_Buffer* run_outputs) {
65   // 1. Convert the feed inputs to the appropriate form for TF_Run.
66   if (!PyDict_Check(feed_dict)) {
67     Set_TF_Status_from_Status(out_status,
68                               errors::InvalidArgument(kFeedDictErrorMsg));
69     return;
70   }
71 
72   NameVector input_names;
73   std::vector<Safe_TF_TensorPtr> inputs_safe;  // Used to delete tensors.
74   TF_TensorVector inputs_unsafe;  // Used to contain the arg to TF_Run.
75 
76   PyObject* key;
77   PyObject* value;
78   Py_ssize_t pos = 0;
79   int index = 0;
80   Status s;
81 
82   while (PyDict_Next(feed_dict, &pos, &key, &value)) {
83     char* key_string = PyBytes_AsString(key);
84     if (!key_string) {
85       Set_TF_Status_from_Status(out_status,
86                                 errors::InvalidArgument(kFeedDictErrorMsg));
87       return;
88     }
89     input_names.push_back(key_string);
90 
91     inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
92     s = NdarrayToTensor(nullptr /*ctx*/, value, &inputs_safe.back());
93     if (!s.ok()) {
94       Set_TF_Status_from_Status(out_status, s);
95       return;
96     }
97     inputs_unsafe.push_back(inputs_safe.back().get());
98     ++index;
99   }
100 
101   // 2. Allocate a container for the output data.
102   TF_TensorVector outputs(output_names.size());
103 
104   // In case any tensors were leftover from previous runs we might as well clear
105   // them here.
106   ClearDecrefCache();
107 
108   // 3. Actually call TF_Run().
109   Py_BEGIN_ALLOW_THREADS;
110   if (handle == nullptr) {
111     TF_Run(session, run_options, input_names.data(), inputs_unsafe.data(),
112            input_names.size(), const_cast<const char**>(output_names.data()),
113            outputs.data(), output_names.size(),
114            const_cast<const char**>(target_nodes.data()), target_nodes.size(),
115            run_outputs, out_status);
116   } else {
117     TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(),
118             input_names.size(), const_cast<const char**>(output_names.data()),
119             outputs.data(), output_names.size(),
120             const_cast<const char**>(target_nodes.data()), target_nodes.size(),
121             out_status);
122   }
123 
124   Py_END_ALLOW_THREADS;
125 
126   // Decref any numpy arrays we are not using anymore.
127   ClearDecrefCache();
128 
129   if (TF_GetCode(out_status) != TF_OK) {
130     return;
131   }
132 
133   // 4. We now own the fetched tensors, so set up a safe container to
134   // delete them when we exit this scope.
135   std::vector<Safe_TF_TensorPtr> tf_outputs_safe;
136   for (const auto& output : outputs) {
137     tf_outputs_safe.emplace_back(make_safe(output));
138   }
139 
140   // 5. Convert the fetched tensors into numpy ndarrays. Store them in a safe
141   // container so that we do not leak
142   std::vector<Safe_PyObjectPtr> py_outputs_safe;
143   for (size_t i = 0; i < output_names.size(); ++i) {
144     PyObject* py_array;
145     s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
146     if (!s.ok()) {
147       Set_TF_Status_from_Status(out_status, s);
148       return;
149     }
150     py_outputs_safe.emplace_back(
151         make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
152   }
153 
154   // 6. If we reach this point, we have successfully built a list of objects
155   // so we can release them from the safe container.
156   for (auto& output : py_outputs_safe) {
157     out_values->push_back(output.release());
158   }
159 }
160 
161 // Wrapper for TF_Run that converts the arguments to appropriate types.
162 // If *out_status is OK, the caller becomes the owner of the PyObjects
163 // in *out_values.
TF_Run_wrapper(TF_DeprecatedSession * session,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)164 void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
165                     PyObject* feed_dict, const NameVector& output_names,
166                     const NameVector& target_nodes, TF_Status* out_status,
167                     PyObjectVector* out_values, TF_Buffer* run_outputs) {
168   TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names,
169                         target_nodes, out_status, out_values, run_outputs);
170   ClearDecrefCache();
171 }
172 
173 namespace {
MakeCallableHelper(tensorflow::Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * out_status)174 void MakeCallableHelper(tensorflow::Session* session,
175                         const TF_Buffer* callable_options, int64_t* out_handle,
176                         TF_Status* out_status) {
177   tensorflow::CallableOptions callable_options_proto;
178   if (callable_options != nullptr &&
179       !callable_options_proto.ParseFromArray(callable_options->data,
180                                              callable_options->length)) {
181     Set_TF_Status_from_Status(
182         out_status,
183         errors::InvalidArgument("Unparseable CallableOptions proto"));
184     return;
185   }
186   tensorflow::Session::CallableHandle handle;
187   Status s = session->MakeCallable(callable_options_proto, &handle);
188   if (!s.ok()) {
189     Set_TF_Status_from_Status(out_status, s);
190     return;
191   }
192   *out_handle = handle;
193 }
194 }  // namespace
195 
TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * status)196 void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
197                                       const TF_Buffer* callable_options,
198                                       int64_t* out_handle, TF_Status* status) {
199   MakeCallableHelper(session->session, callable_options, out_handle, status);
200 }
TF_SessionMakeCallable(TF_Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * status)201 void TF_SessionMakeCallable(TF_Session* session,
202                             const TF_Buffer* callable_options,
203                             int64_t* out_handle, TF_Status* status) {
204   MakeCallableHelper(session->session, callable_options, out_handle, status);
205 }
206 
207 namespace {
RunCallableHelper(tensorflow::Session * session,int64_t handle,PyObject * feed_values,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_metadata)208 void RunCallableHelper(tensorflow::Session* session, int64_t handle,
209                        PyObject* feed_values, TF_Status* out_status,
210                        PyObjectVector* out_values, TF_Buffer* run_metadata) {
211   // Convert feed values to a vector of tensorflow::Tensor objects.
212   std::vector<Tensor> input_tensors;
213   Status s;
214   {
215     feed_values =
216         PySequence_Fast(feed_values, "feed_values must be a sequence");
217     if (feed_values == nullptr) return;
218     Safe_PyObjectPtr feed_values_holder(make_safe(feed_values));
219     Py_ssize_t len = PySequence_Fast_GET_SIZE(feed_values);
220     input_tensors.reserve(len);
221     for (Py_ssize_t i = 0; i < len; ++i) {
222       PyObject* elem = PySequence_Fast_GET_ITEM(feed_values, i);
223       if (!elem) {
224         Set_TF_Status_from_Status(
225             out_status, errors::Internal("Could not get feed value ", i));
226         return;
227       }
228       Tensor t;
229       s = NdarrayToTensor(elem, &t);
230       if (!s.ok()) {
231         Set_TF_Status_from_Status(out_status, s);
232         return;
233       }
234       input_tensors.push_back(std::move(t));
235     }
236   }
237 
238   RunMetadata run_metadata_proto;
239 
240   // Run the callable.
241   std::vector<Tensor> output_tensors;
242   Py_BEGIN_ALLOW_THREADS;
243   s = session->RunCallable(handle, input_tensors, &output_tensors,
244                            &run_metadata_proto);
245   Py_END_ALLOW_THREADS;
246 
247   if (!s.ok()) {
248     Set_TF_Status_from_Status(out_status, s);
249     return;
250   }
251 
252   // If requested, serialize the RunMetadata to pass it back to the caller.
253   if (run_metadata != nullptr) {
254     s = MessageToBuffer(run_metadata_proto, run_metadata);
255     if (!s.ok()) {
256       Set_TF_Status_from_Status(out_status, s);
257       return;
258     }
259   }
260 
261   // Convert results to NumPy arrays. Since this can fail, stage the
262   // results via a safe container that takes care of decreasing the
263   // reference count on failure.
264   std::vector<Safe_PyObjectPtr> py_outputs_safe;
265   py_outputs_safe.reserve(output_tensors.size());
266   for (const Tensor& output : output_tensors) {
267     PyObject* py_array;
268     s = TensorToNdarray(output, &py_array);
269     if (!s.ok()) {
270       Set_TF_Status_from_Status(out_status, s);
271       return;
272     }
273     py_outputs_safe.push_back(
274         make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
275   }
276 
277   // If we reach this point, we have successfully built a list of objects
278   // so we can release them from the safe container.
279   out_values->reserve(py_outputs_safe.size());
280   for (auto& output : py_outputs_safe) {
281     out_values->push_back(output.release());
282   }
283 }
284 }  // namespace
285 
TF_DeprecatedSessionRunCallable(TF_DeprecatedSession * session,int64_t handle,PyObject * feed_values,PyObjectVector * out_values,TF_Buffer * run_metadata,TF_Status * status)286 void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
287                                      int64_t handle, PyObject* feed_values,
288                                      PyObjectVector* out_values,
289                                      TF_Buffer* run_metadata,
290                                      TF_Status* status) {
291   RunCallableHelper(session->session, handle, feed_values, status, out_values,
292                     run_metadata);
293   ClearDecrefCache();
294 }
TF_SessionRunCallable(TF_Session * session,int64_t handle,PyObject * feed_values,PyObjectVector * out_values,TF_Buffer * run_metadata,TF_Status * status)295 void TF_SessionRunCallable(TF_Session* session, int64_t handle,
296                            PyObject* feed_values, PyObjectVector* out_values,
297                            TF_Buffer* run_metadata, TF_Status* status) {
298   RunCallableHelper(session->session, handle, feed_values, status, out_values,
299                     run_metadata);
300   ClearDecrefCache();
301 }
302 
TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession * session,int64_t handle,TF_Status * status)303 void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
304                                          int64_t handle, TF_Status* status) {
305   Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
306 }
TF_SessionReleaseCallable(TF_Session * session,int64_t handle,TF_Status * status)307 void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
308                                TF_Status* status) {
309   Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
310 }
311 
312 // Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
313 // If *out_status is OK, the caller becomes the owner of *out_handle.
TF_PRunSetup_wrapper(TF_DeprecatedSession * session,const NameVector & input_names,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,const char ** out_handle)314 void TF_PRunSetup_wrapper(TF_DeprecatedSession* session,
315                           const NameVector& input_names,
316                           const NameVector& output_names,
317                           const NameVector& target_nodes, TF_Status* out_status,
318                           const char** out_handle) {
319   Py_BEGIN_ALLOW_THREADS;
320   TF_PRunSetup(
321       session, const_cast<const char**>(input_names.data()), input_names.size(),
322       const_cast<const char**>(output_names.data()), output_names.size(),
323       const_cast<const char**>(target_nodes.data()), target_nodes.size(),
324       out_handle, out_status);
325   Py_END_ALLOW_THREADS;
326 }
327 
328 // Wrapper for TF_PRun that converts the arguments to appropriate types.
329 // If *out_status is OK, the caller becomes the owner of the PyObjects
330 // in *out_values.
TF_PRun_wrapper(TF_DeprecatedSession * session,const char * handle,PyObject * feed_dict,const NameVector & output_names,TF_Status * out_status,PyObjectVector * out_values)331 void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
332                      PyObject* feed_dict, const NameVector& output_names,
333                      TF_Status* out_status, PyObjectVector* out_values) {
334   TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names,
335                         NameVector(), out_status, out_values, nullptr);
336   ClearDecrefCache();
337 }
338 
339 // Wrapper for TF_Reset that converts the string vectors to character arrays.
TF_Reset_wrapper(const TF_SessionOptions * opt,const NameVector & containers,TF_Status * status)340 void TF_Reset_wrapper(const TF_SessionOptions* opt,
341                       const NameVector& containers, TF_Status* status) {
342   TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
343            status);
344 }
345 
TF_SessionRun_wrapper_helper(TF_Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)346 void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
347                                   const TF_Buffer* run_options,
348                                   const std::vector<TF_Output>& inputs,
349                                   const std::vector<PyObject*>& input_ndarrays,
350                                   const std::vector<TF_Output>& outputs,
351                                   const std::vector<TF_Operation*>& targets,
352                                   TF_Buffer* run_metadata,
353                                   TF_Status* out_status,
354                                   std::vector<PyObject*>* py_outputs) {
355   DCHECK_EQ(inputs.size(), input_ndarrays.size());
356   DCHECK(py_outputs != nullptr);
357   DCHECK(py_outputs->empty());
358   Status s;
359 
360   // Convert input ndarray PyObjects to TF_Tensors. We maintain a continuous
361   // array of TF_Tensor*s as well as scoped containers to make sure they're
362   // cleaned up properly.
363   //
364   // Memory management:
365   // NdarrayToTensor() creates a new ndarray PyObject from the input
366   // ndarray. We manage the new ndarray's lifetime in order to keep the
367   // underlying data buffer alive (the new ndarray also guarantees a contiguous
368   // data buffer). The new ndarray's data buffer is used to create the
369   // corresponding TF_Tensor. The TF_Tensor's deallocator will queue the new
370   // ndarray to be decref'd by the next ClearDecrefCache() call (we can't call
371   // Py_DECREF in the deallocator directly because the GIL must be held).
372   //
373   // Note that TF_Tensor may directly delegate its data and deallocator to a
374   // TensorBuffer, which may outlive the TF_Tensor (e.g. if the tensor gets
375   // queued or assigned to a variable).
376   TF_TensorVector input_vals;
377   std::vector<Safe_TF_TensorPtr> input_vals_safe;
378   for (PyObject* ndarray : input_ndarrays) {
379     input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
380     s = NdarrayToTensor(nullptr, ndarray, &input_vals_safe.back());
381     if (!s.ok()) {
382       Set_TF_Status_from_Status(out_status, s);
383       return;
384     }
385     input_vals.push_back(input_vals_safe.back().get());
386   }
387 
388   // Allocate space for output TF_Tensor*s
389   TF_TensorVector output_vals(outputs.size());
390 
391   // Clear up any unused memory leftover from previous runs
392   ClearDecrefCache();
393 
394   // Call TF_SessionRun() (and release GIL during execution)
395   Py_BEGIN_ALLOW_THREADS;
396   if (handle == nullptr) {
397     TF_SessionRun(session, run_options, inputs.data(), input_vals.data(),
398                   inputs.size(), outputs.data(), output_vals.data(),
399                   outputs.size(), targets.data(), targets.size(), run_metadata,
400                   out_status);
401   } else {
402     TF_SessionPRun(session, handle, inputs.data(), input_vals.data(),
403                    inputs.size(), outputs.data(), output_vals.data(),
404                    outputs.size(), targets.data(), targets.size(), out_status);
405   }
406   Py_END_ALLOW_THREADS;
407 
408   // Create scoped containers for output tensors
409   std::vector<Safe_TF_TensorPtr> output_vals_safe;
410   for (TF_Tensor* output : output_vals) {
411     output_vals_safe.emplace_back(make_safe(output));
412   }
413 
414   // Convert outputs to ndarrays (in scoped containers)
415   std::vector<Safe_PyObjectPtr> py_outputs_safe;
416   for (size_t i = 0; i < outputs.size(); ++i) {
417     PyObject* py_array;
418     s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array);
419     if (!s.ok()) {
420       Set_TF_Status_from_Status(out_status, s);
421       return;
422     }
423     py_outputs_safe.emplace_back(
424         make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
425   }
426 
427   // If we reach this point, we have successfully built a list of objects so we
428   // can release them from the safe container into the return vector.
429   for (size_t i = 0; i < outputs.size(); ++i) {
430     py_outputs->push_back(py_outputs_safe[i].release());
431   }
432 }
433 
TF_SessionRun_wrapper(TF_Session * session,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)434 void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
435                            const std::vector<TF_Output>& inputs,
436                            const std::vector<PyObject*>& input_ndarrays,
437                            const std::vector<TF_Output>& outputs,
438                            const std::vector<TF_Operation*>& targets,
439                            TF_Buffer* run_metadata, TF_Status* out_status,
440                            std::vector<PyObject*>* py_outputs) {
441   TF_SessionRun_wrapper_helper(session, nullptr, run_options, inputs,
442                                input_ndarrays, outputs, targets, run_metadata,
443                                out_status, py_outputs);
444   // Release any unused ndarray references (see memory management comment in
445   // TF_SessionRun_wrapper_helper)
446   ClearDecrefCache();
447 }
448 
// Compares two serialized GraphDefs. Returns an empty string when
// EqualGraphDef reports them equal, the diff text it produced otherwise, or a
// message naming whichever argument failed to parse.
string EqualGraphDefWrapper(const string& actual, const string& expected) {
  GraphDef actual_def;
  if (!actual_def.ParseFromString(actual)) {
    return "actual is not a valid serialized GraphDef";
  }
  GraphDef expected_def;
  if (!expected_def.ParseFromString(expected)) {
    return "expected is not a valid serialized GraphDef";
  }
  string diff;
  const bool equal = EqualGraphDef(actual_def, expected_def, &diff);
  if (equal) {
    return "";
  }
  return diff;
}
461 
// Compares two serialized AttrValue protos. Returns an empty string when
// AreAttrValuesEqual considers them equal, a message summarizing both values
// otherwise, or a message naming whichever argument failed to parse.
string EqualAttrValueWrapper(const string& actual, const string& expected) {
  AttrValue actual_attr_value;
  if (!actual_attr_value.ParseFromString(actual)) {
    return "actual is not a valid serialized AttrValue";
  }

  AttrValue expected_attr_value;
  if (!expected_attr_value.ParseFromString(expected)) {
    return "expected is not a valid serialized AttrValue";
  }

  if (AreAttrValuesEqual(actual_attr_value, expected_attr_value)) {
    return "";
  }
  const string actual_summary = SummarizeAttrValue(actual_attr_value);
  const string expected_summary = SummarizeAttrValue(expected_attr_value);
  return strings::Printf(
      "Actual AttrValue %s does not match Expected AttrValue %s.",
      actual_summary.c_str(), expected_summary.c_str());
}
482 
483 // Return value set to 6 inlined elements so it fits in a 64-byte cache line.
TF_GraphGetTensorShapeHelper(TF_Graph * graph,TF_Output output,TF_Status * out_status,bool * unknown_shape)484 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
485     TF_Graph* graph, TF_Output output, TF_Status* out_status,
486     bool* unknown_shape) {
487   // Allocate a single variable for holding the result for RVO.
488   tensorflow::gtl::InlinedVector<int64_t, 6> result;
489   *unknown_shape = false;
490   int num_dims = TF_GraphGetTensorNumDims(graph, output, out_status);
491   if (TF_GetCode(out_status) != TF_OK) {
492     return result;
493   }
494   // If shape is unknown, set boolean and return.
495   if (num_dims == -1) {
496     *unknown_shape = true;
497     return result;
498   }
499 
500   // If shape is a scalar, avoid another C call and just return {}.
501   if (num_dims == 0) {
502     return result;
503   }
504 
505   result.resize(num_dims);
506   TF_GraphGetTensorShape(graph, output, result.data(), num_dims, out_status);
507   return result;
508 }
509 
TF_SessionPRunSetup_wrapper(TF_Session * session,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,const char ** out_handle,TF_Status * out_status)510 void TF_SessionPRunSetup_wrapper(TF_Session* session,
511                                  const std::vector<TF_Output>& inputs,
512                                  const std::vector<TF_Output>& outputs,
513                                  const std::vector<TF_Operation*>& targets,
514                                  const char** out_handle,
515                                  TF_Status* out_status) {
516   // Call TF_SessionPRunSetup() (and release GIL during execution)
517   Py_BEGIN_ALLOW_THREADS;
518   TF_SessionPRunSetup(session, inputs.data(), inputs.size(), outputs.data(),
519                       outputs.size(), targets.data(), targets.size(),
520                       out_handle, out_status);
521   Py_END_ALLOW_THREADS;
522 }
523 
TF_SessionPRun_wrapper(TF_Session * session,const char * handle,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,TF_Status * out_status,std::vector<PyObject * > * py_outputs)524 void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
525                             const std::vector<TF_Output>& inputs,
526                             const std::vector<PyObject*>& input_ndarrays,
527                             const std::vector<TF_Output>& outputs,
528                             TF_Status* out_status,
529                             std::vector<PyObject*>* py_outputs) {
530   const std::vector<TF_Operation*> targets;
531   TF_SessionRun_wrapper_helper(session, handle,
532                                nullptr,  // run_options
533                                inputs, input_ndarrays, outputs, targets,
534                                nullptr,  // run_metadata
535                                out_status, py_outputs);
536   // Release any unused ndarray references (see memory management comment in
537   // TF_SessionRun_wrapper_helper)
538   ClearDecrefCache();
539 }
540 
GetOperationInputs(TF_Operation * oper)541 std::vector<TF_Output> GetOperationInputs(TF_Operation* oper) {
542   int num_inputs = TF_OperationNumInputs(oper);
543   std::vector<TF_Output> inputs(num_inputs);
544   TF_OperationAllInputs(oper, inputs.data(), inputs.size());
545   return inputs;
546 }
547 
TF_OperationGetControlInputs_wrapper(TF_Operation * oper)548 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
549     TF_Operation* oper) {
550   std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper));
551   TF_OperationGetControlInputs(oper, control_inputs.data(),
552                                control_inputs.size());
553   return control_inputs;
554 }
555 
TF_OperationGetControlOutputs_wrapper(TF_Operation * oper)556 std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
557     TF_Operation* oper) {
558   std::vector<TF_Operation*> control_outputs(
559       TF_OperationNumControlOutputs(oper));
560   TF_OperationGetControlOutputs(oper, control_outputs.data(),
561                                 control_outputs.size());
562   return control_outputs;
563 }
564 
TF_OperationOutputConsumers_wrapper(TF_Output oper_out)565 std::vector<const char*> TF_OperationOutputConsumers_wrapper(
566     TF_Output oper_out) {
567   int num_consumers = TF_OperationOutputNumConsumers(oper_out);
568   std::vector<TF_Input> consumers(num_consumers);
569   TF_OperationOutputConsumers(oper_out, consumers.data(), num_consumers);
570 
571   std::vector<const char*> consumer_names(num_consumers);
572   for (int i = 0; i < num_consumers; ++i) {
573     consumer_names[i] = TF_OperationName(consumers[i].oper);
574   }
575   return consumer_names;
576 }
577 
TF_GraphToFunction_wrapper(const TF_Graph * fn_body,const char * fn_name,bool append_hash_to_fn_name,const std::vector<TF_Operation * > * opers,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const NameVector & output_names,const std::vector<TF_Operation * > * control_outputs,const NameVector & control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * out_status)578 TF_Function* TF_GraphToFunction_wrapper(
579     const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
580     const std::vector<TF_Operation*>* opers,
581     const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
582     const NameVector& output_names,
583     const std::vector<TF_Operation*>* control_outputs,
584     const NameVector& control_output_names, const TF_FunctionOptions* opts,
585     const char* description, TF_Status* out_status) {
586   if (!output_names.empty() && output_names.size() != outputs.size()) {
587     Set_TF_Status_from_Status(
588         out_status,
589         errors::InvalidArgument(
590             "output names must be either empty or equal in size to outputs. ",
591             "output names size = ", output_names.size(),
592             " outputs size = ", outputs.size()));
593     return nullptr;
594   }
595 
596   int nopers = -1;
597   const TF_Operation* const* opers_array = nullptr;
598   if (opers != nullptr) {
599     nopers = opers->size();
600     opers_array = opers->data();
601   }
602 
603   const char** output_names_ptr =
604       output_names.empty() ? nullptr
605                            : const_cast<const char**>(output_names.data());
606 
607   const char** control_output_names_ptr =
608       control_output_names.empty()
609           ? nullptr
610           : const_cast<const char**>(control_output_names.data());
611 
612   return TF_GraphToFunctionWithControlOutputs(
613       fn_body, fn_name, append_hash_to_fn_name, nopers, opers_array,
614       inputs.size(), inputs.data(), outputs.size(), outputs.data(),
615       output_names_ptr,
616       control_outputs == nullptr ? 0 : control_outputs->size(),
617       control_outputs == nullptr ? nullptr : control_outputs->data(),
618       control_output_names_ptr, opts, description, out_status);
619 }
620 
TF_GraphSetOutputHandleShapesAndTypes_wrapper(TF_Graph * graph,TF_Output output,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int> & ranks,const std::vector<TF_DataType> & types,TF_Status * status)621 void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
622     TF_Graph* graph, TF_Output output,
623     const std::vector<std::vector<int64_t>>& shapes,
624     const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
625     TF_Status* status) {
626   std::vector<const int64_t*> shapes_pointers(shapes.size());
627   for (int i = 0; i < shapes.size(); ++i) {
628     shapes_pointers[i] = ranks[i] <= 0 ? nullptr : &shapes[i][0];
629   }
630   TF_GraphSetOutputHandleShapesAndTypes(graph, output, shapes.size(),
631                                         shapes_pointers.data(), ranks.data(),
632                                         types.data(), status);
633 }
634 
CreatePlaceholder(TF_Graph * graph,TF_Status * s,string && name,TF_DataType dtype,TF_Output * output)635 void CreatePlaceholder(TF_Graph* graph, TF_Status* s, string&& name,
636                        TF_DataType dtype, TF_Output* output) {
637   TF_OperationDescription* desc =
638       TF_NewOperation(graph, "Placeholder", name.data());
639   TF_SetAttrType(desc, "dtype", dtype);
640   TF_Operation* op = TF_FinishOperation(desc, s);
641   output->oper = op;
642   output->index = 0;
643 }
644 
TF_CreatePlaceholders(TF_Graph * graph,PyObject * dtypes,const char * prefix,TF_Status * status)645 std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
646                                              const char* prefix,
647                                              TF_Status* status) {
648   std::vector<TF_Output> outputs;
649   dtypes = PySequence_Fast(dtypes, "dtypes must be a sequence");
650   if (dtypes == nullptr) {
651     Set_TF_Status_from_Status(status, errors::Internal("dtypes is nullptr"));
652     return outputs;
653   }
654   Safe_PyObjectPtr dtypes_holder(make_safe(dtypes));
655   Py_ssize_t len = PySequence_Fast_GET_SIZE(dtypes);
656   outputs.reserve(len);
657   for (size_t i = 0; i < len; i++) {
658     PyObject* dtype = PySequence_Fast_GET_ITEM(dtypes, i);
659     if (!dtype) {
660       Set_TF_Status_from_Status(status,
661                                 errors::Internal("Could not get dtype ", i));
662       return outputs;
663     }
664 #if PY_MAJOR_VERSION >= 3
665     TF_DataType tf_datatype = static_cast<TF_DataType>(PyLong_AsLong(dtype));
666 #else
667     TF_DataType tf_datatype = static_cast<TF_DataType>(PyInt_AsLong(dtype));
668 #endif
669     outputs.push_back(TF_Output());
670     CreatePlaceholder(graph, status, strings::StrCat(prefix, i), tf_datatype,
671                       &outputs.back());
672     if (!status->status.ok()) break;
673   }
674   return outputs;
675 }
676 
TF_GraphSetTensorShape_wrapper(TF_Graph * graph,TF_Output output,const std::vector<int64_t> & dims,bool unknown_shape,TF_Status * status)677 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
678                                     const std::vector<int64_t>& dims,
679                                     bool unknown_shape, TF_Status* status) {
680   if (unknown_shape) {
681     TF_GraphSetTensorShape(graph, output, nullptr, -1, status);
682     return;
683   }
684   TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
685 }
686 
TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(TF_ImportGraphDefResults * results)687 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
688     TF_ImportGraphDefResults* results) {
689   int num_missing_unused_input_mappings;
690   const char** src_names;
691   int* src_indexes;
692   TF_ImportGraphDefResultsMissingUnusedInputMappings(
693       results, &num_missing_unused_input_mappings, &src_names, &src_indexes);
694   std::vector<string> input_strs(num_missing_unused_input_mappings);
695   for (int i = 0; i < num_missing_unused_input_mappings; ++i) {
696     input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString();
697   }
698   return input_strs;
699 }
700 
TF_TryEvaluateConstant_wrapper(TF_Graph * graph,TF_Output output,TF_Status * status)701 PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
702                                          TF_Status* status) {
703   TF_Tensor* result_tensor;
704   bool evaluated =
705       TF_TryEvaluateConstant(graph, output, &result_tensor, status);
706   if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE;
707 
708   Safe_TF_TensorPtr safe_result_tensor(result_tensor);
709   PyObject* out;
710   Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out);
711   Set_TF_Status_from_Status(status, s);
712   if (!s.ok()) Py_RETURN_NONE;
713   return PyArray_Return(reinterpret_cast<PyArrayObject*>(out));
714 }
715 
716 }  // namespace tensorflow
717