• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/client/tf_session_helper.h"
17 
18 #include <cstring>
19 
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/log_memory.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/graph/tensor_id.h"
29 #include "tensorflow/core/lib/core/coding.h"
30 #include "tensorflow/core/lib/strings/stringprintf.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/equal_graph_def.h"
33 #include "tensorflow/python/client/session_ref.h"
34 #include "tensorflow/python/lib/core/ndarray_tensor.h"
35 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
36 #include "tensorflow/python/lib/core/safe_ptr.h"
37 
38 namespace tensorflow {
39 
namespace {

// Error reported when a `feed_dict` argument is not a dict, or when one of its
// keys cannot be read as a bytes string.
// A constexpr array (rather than a non-const `const char*`) makes the constant
// itself immutable; the anonymous namespace already gives internal linkage, so
// `static` is unnecessary.
constexpr char kFeedDictErrorMsg[] =
    "feed_dict must be a dictionary mapping strings to NumPy arrays.";

}  // end namespace
45 
TF_NewSessionRef(TF_Graph * graph,const TF_SessionOptions * opts,TF_Status * status)46 TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
47                              TF_Status* status) {
48   TF_Session* tf_session = TF_NewSession(graph, opts, status);
49   if (tf_session == nullptr) {
50     return nullptr;
51   }
52 
53   Session* session = reinterpret_cast<Session*>(tf_session->session);
54   SessionRef* session_ref = new SessionRef(session);
55   tf_session->session = session_ref;
56   return tf_session;
57 }
58 
TF_Run_wrapper_helper(TF_DeprecatedSession * session,const char * handle,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)59 void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
60                            const TF_Buffer* run_options, PyObject* feed_dict,
61                            const NameVector& output_names,
62                            const NameVector& target_nodes,
63                            TF_Status* out_status, PyObjectVector* out_values,
64                            TF_Buffer* run_outputs) {
65   // 1. Convert the feed inputs to the appropriate form for TF_Run.
66   if (!PyDict_Check(feed_dict)) {
67     Set_TF_Status_from_Status(out_status,
68                               errors::InvalidArgument(kFeedDictErrorMsg));
69     return;
70   }
71 
72   NameVector input_names;
73   std::vector<Safe_TF_TensorPtr> inputs_safe;  // Used to delete tensors.
74   TF_TensorVector inputs_unsafe;  // Used to contain the arg to TF_Run.
75 
76   PyObject* key;
77   PyObject* value;
78   Py_ssize_t pos = 0;
79   int index = 0;
80   Status s;
81 
82   while (PyDict_Next(feed_dict, &pos, &key, &value)) {
83     char* key_string = PyBytes_AsString(key);
84     if (!key_string) {
85       Set_TF_Status_from_Status(out_status,
86                                 errors::InvalidArgument(kFeedDictErrorMsg));
87       return;
88     }
89     input_names.push_back(key_string);
90 
91     inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
92     s = PyArrayToTF_Tensor(value, &inputs_safe.back());
93     if (!s.ok()) {
94       Set_TF_Status_from_Status(out_status, s);
95       return;
96     }
97     inputs_unsafe.push_back(inputs_safe.back().get());
98     ++index;
99   }
100 
101   // 2. Allocate a container for the output data.
102   TF_TensorVector outputs(output_names.size());
103 
104   // In case any tensors were leftover from previous runs we might as well clear
105   // them here.
106   ClearDecrefCache();
107 
108   // 3. Actually call TF_Run().
109   Py_BEGIN_ALLOW_THREADS;
110   if (handle == nullptr) {
111     TF_Run(session, run_options, input_names.data(), inputs_unsafe.data(),
112            input_names.size(), const_cast<const char**>(output_names.data()),
113            outputs.data(), output_names.size(),
114            const_cast<const char**>(target_nodes.data()), target_nodes.size(),
115            run_outputs, out_status);
116   } else {
117     TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(),
118             input_names.size(), const_cast<const char**>(output_names.data()),
119             outputs.data(), output_names.size(),
120             const_cast<const char**>(target_nodes.data()), target_nodes.size(),
121             out_status);
122   }
123 
124   Py_END_ALLOW_THREADS;
125 
126   // Decref any numpy arrays we are not using anymore.
127   ClearDecrefCache();
128 
129   if (TF_GetCode(out_status) != TF_OK) {
130     return;
131   }
132 
133   // 4. We now own the fetched tensors, so set up a safe container to
134   // delete them when we exit this scope.
135   std::vector<Safe_TF_TensorPtr> tf_outputs_safe;
136   for (const auto& output : outputs) {
137     tf_outputs_safe.emplace_back(make_safe(output));
138   }
139 
140   // 5. Convert the fetched tensors into numpy ndarrays. Store them in a safe
141   // container so that we do not leak
142   std::vector<Safe_PyObjectPtr> py_outputs_safe;
143   for (size_t i = 0; i < output_names.size(); ++i) {
144     PyObject* py_array;
145     s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
146     if (!s.ok()) {
147       Set_TF_Status_from_Status(out_status, s);
148       return;
149     }
150     py_outputs_safe.emplace_back(
151         make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
152   }
153 
154   // 6. If we reach this point, we have successfully built a list of objects
155   // so we can release them from the safe container.
156   for (auto& output : py_outputs_safe) {
157     out_values->push_back(output.release());
158   }
159 }
160 
161 // Wrapper for TF_Run that converts the arguments to appropriate types.
162 // If *out_status is OK, the caller becomes the owner of the PyObjects
163 // in *out_values.
TF_Run_wrapper(TF_DeprecatedSession * session,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)164 void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
165                     PyObject* feed_dict, const NameVector& output_names,
166                     const NameVector& target_nodes, TF_Status* out_status,
167                     PyObjectVector* out_values, TF_Buffer* run_outputs) {
168   TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names,
169                         target_nodes, out_status, out_values, run_outputs);
170   ClearDecrefCache();
171 }
172 
173 namespace {
MakeCallableHelper(tensorflow::Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * out_status)174 void MakeCallableHelper(tensorflow::Session* session,
175                         const TF_Buffer* callable_options, int64_t* out_handle,
176                         TF_Status* out_status) {
177   tensorflow::CallableOptions callable_options_proto;
178   if (callable_options != nullptr &&
179       !callable_options_proto.ParseFromArray(callable_options->data,
180                                              callable_options->length)) {
181     Set_TF_Status_from_Status(
182         out_status,
183         errors::InvalidArgument("Unparseable CallableOptions proto"));
184     return;
185   }
186   tensorflow::Session::CallableHandle handle;
187   Status s = session->MakeCallable(callable_options_proto, &handle);
188   if (!s.ok()) {
189     Set_TF_Status_from_Status(out_status, s);
190     return;
191   }
192   *out_handle = handle;
193 }
194 }  // namespace
195 
TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * status)196 void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
197                                       const TF_Buffer* callable_options,
198                                       int64_t* out_handle, TF_Status* status) {
199   MakeCallableHelper(session->session, callable_options, out_handle, status);
200 }
TF_SessionMakeCallable(TF_Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * status)201 void TF_SessionMakeCallable(TF_Session* session,
202                             const TF_Buffer* callable_options,
203                             int64_t* out_handle, TF_Status* status) {
204   MakeCallableHelper(session->session, callable_options, out_handle, status);
205 }
206 
207 namespace {
RunCallableHelper(tensorflow::Session * session,int64_t handle,PyObject * feed_values,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_metadata)208 void RunCallableHelper(tensorflow::Session* session, int64_t handle,
209                        PyObject* feed_values, TF_Status* out_status,
210                        PyObjectVector* out_values, TF_Buffer* run_metadata) {
211   // Convert feed values to a vector of tensorflow::Tensor objects.
212   std::vector<Tensor> input_tensors;
213   Status s;
214   {
215     feed_values =
216         PySequence_Fast(feed_values, "feed_values must be a sequence");
217     if (feed_values == nullptr) return;
218     Safe_PyObjectPtr feed_values_holder(make_safe(feed_values));
219     Py_ssize_t len = PySequence_Fast_GET_SIZE(feed_values);
220     input_tensors.reserve(len);
221     for (Py_ssize_t i = 0; i < len; ++i) {
222       PyObject* elem = PySequence_Fast_GET_ITEM(feed_values, i);
223       if (!elem) {
224         Set_TF_Status_from_Status(
225             out_status, errors::Internal("Could not get feed value ", i));
226         return;
227       }
228       Tensor t;
229       s = NdarrayToTensor(elem, &t);
230       if (!s.ok()) {
231         Set_TF_Status_from_Status(out_status, s);
232         return;
233       }
234       input_tensors.push_back(std::move(t));
235     }
236   }
237 
238   // Allocate a RunMetadata protobuf object to receive the metadata,
239   // if the caller is expecting any.
240   std::unique_ptr<RunMetadata> run_metadata_proto;
241   if (run_metadata != nullptr) {
242     run_metadata_proto.reset(new RunMetadata);
243   }
244 
245   // Run the callable.
246   std::vector<Tensor> output_tensors;
247   Py_BEGIN_ALLOW_THREADS;
248   s = session->RunCallable(handle, input_tensors, &output_tensors,
249                            run_metadata_proto.get());
250   Py_END_ALLOW_THREADS;
251 
252   if (!s.ok()) {
253     Set_TF_Status_from_Status(out_status, s);
254     return;
255   }
256 
257   // If requested, serialize the RunMetadata to pass it back to the caller.
258   if (run_metadata != nullptr) {
259     s = MessageToBuffer(*run_metadata_proto, run_metadata);
260     if (!s.ok()) {
261       Set_TF_Status_from_Status(out_status, s);
262       return;
263     }
264   }
265 
266   // Convert results to NumPy arrays. Since this can fail, stage the
267   // results via a safe container that takes care of decreasing the
268   // reference count on failure.
269   std::vector<Safe_PyObjectPtr> py_outputs_safe;
270   py_outputs_safe.reserve(output_tensors.size());
271   for (const Tensor& output : output_tensors) {
272     PyObject* py_array;
273     s = TensorToNdarray(output, &py_array);
274     if (!s.ok()) {
275       Set_TF_Status_from_Status(out_status, s);
276       return;
277     }
278     py_outputs_safe.push_back(
279         make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
280   }
281 
282   // If we reach this point, we have successfully built a list of objects
283   // so we can release them from the safe container.
284   out_values->reserve(py_outputs_safe.size());
285   for (auto& output : py_outputs_safe) {
286     out_values->push_back(output.release());
287   }
288 }
289 }  // namespace
290 
TF_DeprecatedSessionRunCallable(TF_DeprecatedSession * session,int64_t handle,PyObject * feed_values,PyObjectVector * out_values,TF_Buffer * run_metadata,TF_Status * status)291 void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
292                                      int64_t handle, PyObject* feed_values,
293                                      PyObjectVector* out_values,
294                                      TF_Buffer* run_metadata,
295                                      TF_Status* status) {
296   RunCallableHelper(session->session, handle, feed_values, status, out_values,
297                     run_metadata);
298   ClearDecrefCache();
299 }
TF_SessionRunCallable(TF_Session * session,int64_t handle,PyObject * feed_values,PyObjectVector * out_values,TF_Buffer * run_metadata,TF_Status * status)300 void TF_SessionRunCallable(TF_Session* session, int64_t handle,
301                            PyObject* feed_values, PyObjectVector* out_values,
302                            TF_Buffer* run_metadata, TF_Status* status) {
303   RunCallableHelper(session->session, handle, feed_values, status, out_values,
304                     run_metadata);
305   ClearDecrefCache();
306 }
307 
TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession * session,int64_t handle,TF_Status * status)308 void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
309                                          int64_t handle, TF_Status* status) {
310   Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
311 }
TF_SessionReleaseCallable(TF_Session * session,int64_t handle,TF_Status * status)312 void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
313                                TF_Status* status) {
314   Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
315 }
316 
317 // Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
318 // If *out_status is OK, the caller becomes the owner of *out_handle.
TF_PRunSetup_wrapper(TF_DeprecatedSession * session,const NameVector & input_names,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,const char ** out_handle)319 void TF_PRunSetup_wrapper(TF_DeprecatedSession* session,
320                           const NameVector& input_names,
321                           const NameVector& output_names,
322                           const NameVector& target_nodes, TF_Status* out_status,
323                           const char** out_handle) {
324   Py_BEGIN_ALLOW_THREADS;
325   TF_PRunSetup(
326       session, const_cast<const char**>(input_names.data()), input_names.size(),
327       const_cast<const char**>(output_names.data()), output_names.size(),
328       const_cast<const char**>(target_nodes.data()), target_nodes.size(),
329       out_handle, out_status);
330   Py_END_ALLOW_THREADS;
331 }
332 
333 // Wrapper for TF_PRun that converts the arguments to appropriate types.
334 // If *out_status is OK, the caller becomes the owner of the PyObjects
335 // in *out_values.
TF_PRun_wrapper(TF_DeprecatedSession * session,const char * handle,PyObject * feed_dict,const NameVector & output_names,TF_Status * out_status,PyObjectVector * out_values)336 void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
337                      PyObject* feed_dict, const NameVector& output_names,
338                      TF_Status* out_status, PyObjectVector* out_values) {
339   TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names,
340                         NameVector(), out_status, out_values, nullptr);
341   ClearDecrefCache();
342 }
343 
344 // Wrapper for TF_Reset that converts the string vectors to character arrays.
TF_Reset_wrapper(const TF_SessionOptions * opt,const NameVector & containers,TF_Status * status)345 void TF_Reset_wrapper(const TF_SessionOptions* opt,
346                       const NameVector& containers, TF_Status* status) {
347   TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
348            status);
349 }
350 
TF_SessionRun_wrapper_helper(TF_Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)351 void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
352                                   const TF_Buffer* run_options,
353                                   const std::vector<TF_Output>& inputs,
354                                   const std::vector<PyObject*>& input_ndarrays,
355                                   const std::vector<TF_Output>& outputs,
356                                   const std::vector<TF_Operation*>& targets,
357                                   TF_Buffer* run_metadata,
358                                   TF_Status* out_status,
359                                   std::vector<PyObject*>* py_outputs) {
360   DCHECK_EQ(inputs.size(), input_ndarrays.size());
361   DCHECK(py_outputs != nullptr);
362   DCHECK(py_outputs->empty());
363   Status s;
364 
365   // Convert input ndarray PyObjects to TF_Tensors. We maintain a continuous
366   // array of TF_Tensor*s as well as scoped containers to make sure they're
367   // cleaned up properly.
368   //
369   // Memory management:
370   // PyArrayToTF_Tensor() creates a new ndarray PyObject from the input
371   // ndarray. We manage the new ndarray's lifetime in order to keep the
372   // underlying data buffer alive (the new ndarray also guarantees a contiguous
373   // data buffer). The new ndarray's data buffer is used to create the
374   // corresponding TF_Tensor. The TF_Tensor's deallocator will queue the new
375   // ndarray to be decref'd by the next ClearDecrefCache() call (we can't call
376   // Py_DECREF in the deallocator directly because the GIL must be held).
377   //
378   // Note that TF_Tensor may directly delegate its data and deallocator to a
379   // TensorBuffer, which may outlive the TF_Tensor (e.g. if the tensor gets
380   // queued or assigned to a variable).
381   TF_TensorVector input_vals;
382   std::vector<Safe_TF_TensorPtr> input_vals_safe;
383   for (PyObject* ndarray : input_ndarrays) {
384     input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
385     s = PyArrayToTF_Tensor(ndarray, &input_vals_safe.back());
386     if (!s.ok()) {
387       Set_TF_Status_from_Status(out_status, s);
388       return;
389     }
390     input_vals.push_back(input_vals_safe.back().get());
391   }
392 
393   // Allocate space for output TF_Tensor*s
394   TF_TensorVector output_vals(outputs.size());
395 
396   // Clear up any unused memory leftover from previous runs
397   ClearDecrefCache();
398 
399   // Call TF_SessionRun() (and release GIL during execution)
400   Py_BEGIN_ALLOW_THREADS;
401   if (handle == nullptr) {
402     TF_SessionRun(session, run_options, inputs.data(), input_vals.data(),
403                   inputs.size(), outputs.data(), output_vals.data(),
404                   outputs.size(), targets.data(), targets.size(), run_metadata,
405                   out_status);
406   } else {
407     TF_SessionPRun(session, handle, inputs.data(), input_vals.data(),
408                    inputs.size(), outputs.data(), output_vals.data(),
409                    outputs.size(), targets.data(), targets.size(), out_status);
410   }
411   Py_END_ALLOW_THREADS;
412 
413   // Create scoped containers for output tensors
414   std::vector<Safe_TF_TensorPtr> output_vals_safe;
415   for (TF_Tensor* output : output_vals) {
416     output_vals_safe.emplace_back(make_safe(output));
417   }
418 
419   // Convert outputs to ndarrays (in scoped containers)
420   std::vector<Safe_PyObjectPtr> py_outputs_safe;
421   for (size_t i = 0; i < outputs.size(); ++i) {
422     PyObject* py_array;
423     s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array);
424     if (!s.ok()) {
425       Set_TF_Status_from_Status(out_status, s);
426       return;
427     }
428     py_outputs_safe.emplace_back(
429         make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
430   }
431 
432   // If we reach this point, we have successfully built a list of objects so we
433   // can release them from the safe container into the return vector.
434   for (size_t i = 0; i < outputs.size(); ++i) {
435     py_outputs->push_back(py_outputs_safe[i].release());
436   }
437 }
438 
TF_SessionRun_wrapper(TF_Session * session,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)439 void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
440                            const std::vector<TF_Output>& inputs,
441                            const std::vector<PyObject*>& input_ndarrays,
442                            const std::vector<TF_Output>& outputs,
443                            const std::vector<TF_Operation*>& targets,
444                            TF_Buffer* run_metadata, TF_Status* out_status,
445                            std::vector<PyObject*>* py_outputs) {
446   TF_SessionRun_wrapper_helper(session, nullptr, run_options, inputs,
447                                input_ndarrays, outputs, targets, run_metadata,
448                                out_status, py_outputs);
449   // Release any unused ndarray references (see memory management comment in
450   // TF_SessionRun_wrapper_helper)
451   ClearDecrefCache();
452 }
453 
// Parses two serialized GraphDefs and compares them with EqualGraphDef.
// Returns:
// - an empty string when the graphs are equal;
// - the diff produced by EqualGraphDef when they differ;
// - a message naming the argument (`actual` is checked first) that failed to
//   parse.
string EqualGraphDefWrapper(const string& actual, const string& expected) {
  GraphDef actual_def;
  GraphDef expected_def;
  if (!actual_def.ParseFromString(actual)) {
    return "actual is not a valid serialized GraphDef";
  }
  if (!expected_def.ParseFromString(expected)) {
    return "expected is not a valid serialized GraphDef";
  }
  string diff;
  if (EqualGraphDef(actual_def, expected_def, &diff)) {
    return "";
  }
  return diff;
}
466 
// Parses two serialized AttrValues and compares them with AreAttrValuesEqual.
// Returns:
// - an empty string when the values are equal;
// - a message containing both values' summaries when they differ;
// - a message naming the argument (`actual` is checked first) that failed to
//   parse.
string EqualAttrValueWrapper(const string& actual, const string& expected) {
  AttrValue actual_attr_value;
  if (!actual_attr_value.ParseFromString(actual)) {
    return "actual is not a valid serialized AttrValue";
  }
  AttrValue expected_attr_value;
  if (!expected_attr_value.ParseFromString(expected)) {
    return "expected is not a valid serialized AttrValue";
  }
  if (AreAttrValuesEqual(actual_attr_value, expected_attr_value)) {
    return "";
  }
  return strings::Printf(
      "Actual AttrValue %s does not match Expected AttrValue %s.",
      SummarizeAttrValue(actual_attr_value).c_str(),
      SummarizeAttrValue(expected_attr_value).c_str());
}
487 
488 // Return value set to 6 inlined elements so it fits in a 64-byte cache line.
TF_GraphGetTensorShapeHelper(TF_Graph * graph,TF_Output output,TF_Status * out_status,bool * unknown_shape)489 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
490     TF_Graph* graph, TF_Output output, TF_Status* out_status,
491     bool* unknown_shape) {
492   // Allocate a single variable for holding the result for RVO.
493   tensorflow::gtl::InlinedVector<int64_t, 6> result;
494   *unknown_shape = false;
495   int num_dims = TF_GraphGetTensorNumDims(graph, output, out_status);
496   if (TF_GetCode(out_status) != TF_OK) {
497     return result;
498   }
499   // If shape is unknown, set boolean and return.
500   if (num_dims == -1) {
501     *unknown_shape = true;
502     return result;
503   }
504 
505   // If shape is a scalar, avoid another C call and just return {}.
506   if (num_dims == 0) {
507     return result;
508   }
509 
510   result.resize(num_dims);
511   TF_GraphGetTensorShape(graph, output, result.data(), num_dims, out_status);
512   return result;
513 }
514 
TF_SessionPRunSetup_wrapper(TF_Session * session,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,const char ** out_handle,TF_Status * out_status)515 void TF_SessionPRunSetup_wrapper(TF_Session* session,
516                                  const std::vector<TF_Output>& inputs,
517                                  const std::vector<TF_Output>& outputs,
518                                  const std::vector<TF_Operation*>& targets,
519                                  const char** out_handle,
520                                  TF_Status* out_status) {
521   // Call TF_SessionPRunSetup() (and release GIL during execution)
522   Py_BEGIN_ALLOW_THREADS;
523   TF_SessionPRunSetup(session, inputs.data(), inputs.size(), outputs.data(),
524                       outputs.size(), targets.data(), targets.size(),
525                       out_handle, out_status);
526   Py_END_ALLOW_THREADS;
527 }
528 
TF_SessionPRun_wrapper(TF_Session * session,const char * handle,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,TF_Status * out_status,std::vector<PyObject * > * py_outputs)529 void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
530                             const std::vector<TF_Output>& inputs,
531                             const std::vector<PyObject*>& input_ndarrays,
532                             const std::vector<TF_Output>& outputs,
533                             TF_Status* out_status,
534                             std::vector<PyObject*>* py_outputs) {
535   const std::vector<TF_Operation*> targets;
536   TF_SessionRun_wrapper_helper(session, handle,
537                                nullptr,  // run_options
538                                inputs, input_ndarrays, outputs, targets,
539                                nullptr,  // run_metadata
540                                out_status, py_outputs);
541   // Release any unused ndarray references (see memory management comment in
542   // TF_SessionRun_wrapper_helper)
543   ClearDecrefCache();
544 }
545 
GetOperationInputs(TF_Operation * oper)546 std::vector<TF_Output> GetOperationInputs(TF_Operation* oper) {
547   int num_inputs = TF_OperationNumInputs(oper);
548   std::vector<TF_Output> inputs(num_inputs);
549   TF_OperationAllInputs(oper, inputs.data(), inputs.size());
550   return inputs;
551 }
552 
TF_OperationGetControlInputs_wrapper(TF_Operation * oper)553 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
554     TF_Operation* oper) {
555   std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper));
556   TF_OperationGetControlInputs(oper, control_inputs.data(),
557                                control_inputs.size());
558   return control_inputs;
559 }
560 
TF_OperationGetControlOutputs_wrapper(TF_Operation * oper)561 std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
562     TF_Operation* oper) {
563   std::vector<TF_Operation*> control_outputs(
564       TF_OperationNumControlOutputs(oper));
565   TF_OperationGetControlOutputs(oper, control_outputs.data(),
566                                 control_outputs.size());
567   return control_outputs;
568 }
569 
TF_OperationOutputConsumers_wrapper(TF_Output oper_out)570 std::vector<const char*> TF_OperationOutputConsumers_wrapper(
571     TF_Output oper_out) {
572   int num_consumers = TF_OperationOutputNumConsumers(oper_out);
573   std::vector<TF_Input> consumers(num_consumers);
574   TF_OperationOutputConsumers(oper_out, consumers.data(), num_consumers);
575 
576   std::vector<const char*> consumer_names(num_consumers);
577   for (int i = 0; i < num_consumers; ++i) {
578     consumer_names[i] = TF_OperationName(consumers[i].oper);
579   }
580   return consumer_names;
581 }
582 
TF_GraphToFunction_wrapper(const TF_Graph * fn_body,const char * fn_name,bool append_hash_to_fn_name,const std::vector<TF_Operation * > * opers,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const NameVector & output_names,const std::vector<TF_Operation * > * control_outputs,const NameVector & control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * out_status)583 TF_Function* TF_GraphToFunction_wrapper(
584     const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
585     const std::vector<TF_Operation*>* opers,
586     const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
587     const NameVector& output_names,
588     const std::vector<TF_Operation*>* control_outputs,
589     const NameVector& control_output_names, const TF_FunctionOptions* opts,
590     const char* description, TF_Status* out_status) {
591   if (!output_names.empty() && output_names.size() != outputs.size()) {
592     Set_TF_Status_from_Status(
593         out_status,
594         errors::InvalidArgument(
595             "output names must be either empty or equal in size to outputs. ",
596             "output names size = ", output_names.size(),
597             " outputs size = ", outputs.size()));
598     return nullptr;
599   }
600 
601   int nopers = -1;
602   const TF_Operation* const* opers_array = nullptr;
603   if (opers != nullptr) {
604     nopers = opers->size();
605     opers_array = opers->data();
606   }
607 
608   const char** output_names_ptr =
609       output_names.empty() ? nullptr
610                            : const_cast<const char**>(output_names.data());
611 
612   const char** control_output_names_ptr =
613       control_output_names.empty()
614           ? nullptr
615           : const_cast<const char**>(control_output_names.data());
616 
617   return TF_GraphToFunctionWithControlOutputs(
618       fn_body, fn_name, append_hash_to_fn_name, nopers, opers_array,
619       inputs.size(), inputs.data(), outputs.size(), outputs.data(),
620       output_names_ptr,
621       control_outputs == nullptr ? 0 : control_outputs->size(),
622       control_outputs == nullptr ? nullptr : control_outputs->data(),
623       control_output_names_ptr, opts, description, out_status);
624 }
625 
TF_GraphSetOutputHandleShapesAndTypes_wrapper(TF_Graph * graph,TF_Output output,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int> & ranks,const std::vector<TF_DataType> & types,TF_Status * status)626 void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
627     TF_Graph* graph, TF_Output output,
628     const std::vector<std::vector<int64_t>>& shapes,
629     const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
630     TF_Status* status) {
631   std::vector<const int64_t*> shapes_pointers(shapes.size());
632   for (int i = 0; i < shapes.size(); ++i) {
633     shapes_pointers[i] = ranks[i] <= 0 ? nullptr : &shapes[i][0];
634   }
635   TF_GraphSetOutputHandleShapesAndTypes(graph, output, shapes.size(),
636                                         shapes_pointers.data(), ranks.data(),
637                                         types.data(), status);
638 }
639 
CreatePlaceholder(TF_Graph * graph,TF_Status * s,string && name,TF_DataType dtype,TF_Output * output)640 void CreatePlaceholder(TF_Graph* graph, TF_Status* s, string&& name,
641                        TF_DataType dtype, TF_Output* output) {
642   TF_OperationDescription* desc =
643       TF_NewOperation(graph, "Placeholder", name.data());
644   TF_SetAttrType(desc, "dtype", dtype);
645   TF_Operation* op = TF_FinishOperation(desc, s);
646   output->oper = op;
647   output->index = 0;
648 }
649 
TF_CreatePlaceholders(TF_Graph * graph,PyObject * dtypes,const char * prefix,TF_Status * status)650 std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
651                                              const char* prefix,
652                                              TF_Status* status) {
653   std::vector<TF_Output> outputs;
654   dtypes = PySequence_Fast(dtypes, "dtypes must be a sequence");
655   if (dtypes == nullptr) {
656     Set_TF_Status_from_Status(status, errors::Internal("dtypes is nullptr"));
657     return outputs;
658   }
659   Safe_PyObjectPtr dtypes_holder(make_safe(dtypes));
660   Py_ssize_t len = PySequence_Fast_GET_SIZE(dtypes);
661   outputs.reserve(len);
662   for (size_t i = 0; i < len; i++) {
663     PyObject* dtype = PySequence_Fast_GET_ITEM(dtypes, i);
664     if (!dtype) {
665       Set_TF_Status_from_Status(status,
666                                 errors::Internal("Could not get dtype ", i));
667       return outputs;
668     }
669 #if PY_MAJOR_VERSION >= 3
670     TF_DataType tf_datatype = static_cast<TF_DataType>(PyLong_AsLong(dtype));
671 #else
672     TF_DataType tf_datatype = static_cast<TF_DataType>(PyInt_AsLong(dtype));
673 #endif
674     outputs.push_back(TF_Output());
675     CreatePlaceholder(graph, status, strings::StrCat(prefix, i), tf_datatype,
676                       &outputs.back());
677     if (!status->status.ok()) break;
678   }
679   return outputs;
680 }
681 
TF_GraphSetTensorShape_wrapper(TF_Graph * graph,TF_Output output,const std::vector<int64_t> & dims,bool unknown_shape,TF_Status * status)682 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
683                                     const std::vector<int64_t>& dims,
684                                     bool unknown_shape, TF_Status* status) {
685   if (unknown_shape) {
686     TF_GraphSetTensorShape(graph, output, nullptr, -1, status);
687     return;
688   }
689   TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
690 }
691 
TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(TF_ImportGraphDefResults * results)692 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
693     TF_ImportGraphDefResults* results) {
694   int num_missing_unused_input_mappings;
695   const char** src_names;
696   int* src_indexes;
697   TF_ImportGraphDefResultsMissingUnusedInputMappings(
698       results, &num_missing_unused_input_mappings, &src_names, &src_indexes);
699   std::vector<string> input_strs(num_missing_unused_input_mappings);
700   for (int i = 0; i < num_missing_unused_input_mappings; ++i) {
701     input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString();
702   }
703   return input_strs;
704 }
705 
TF_TryEvaluateConstant_wrapper(TF_Graph * graph,TF_Output output,TF_Status * status)706 PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
707                                          TF_Status* status) {
708   TF_Tensor* result_tensor;
709   bool evaluated =
710       TF_TryEvaluateConstant(graph, output, &result_tensor, status);
711   if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE;
712 
713   Safe_TF_TensorPtr safe_result_tensor(result_tensor);
714   PyObject* out;
715   Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out);
716   Set_TF_Status_from_Status(status, s);
717   if (!s.ok()) Py_RETURN_NONE;
718   return PyArray_Return(reinterpret_cast<PyArrayObject*>(out));
719 }
720 
721 }  // namespace tensorflow
722