• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/client/tf_session_helper.h"
17 
18 #include <cstring>
19 
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/c/tf_buffer_internal.h"
23 #include "tensorflow/c/tf_status_helper.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/attr_value_util.h"
27 #include "tensorflow/core/framework/log_memory.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/graph/tensor_id.h"
30 #include "tensorflow/core/lib/core/coding.h"
31 #include "tensorflow/core/lib/strings/stringprintf.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/util/equal_graph_def.h"
34 #include "tensorflow/python/client/session_ref.h"
35 #include "tensorflow/python/lib/core/ndarray_tensor.h"
36 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
37 #include "tensorflow/python/lib/core/safe_ptr.h"
38 
39 namespace tensorflow {
40 
namespace {

// Error reported when `feed_dict` is not a dict, or when one of its keys is
// not a bytes object. A constexpr array (rather than a mutable `static const
// char*` pointer) makes the constant truly immutable and needs no `static`
// inside the anonymous namespace.
constexpr char kFeedDictErrorMsg[] =
    "feed_dict must be a dictionary mapping strings to NumPy arrays.";

}  // end namespace
46 
TF_NewSessionRef(TF_Graph * graph,const TF_SessionOptions * opts,TF_Status * status)47 TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
48                              TF_Status* status) {
49   TF_Session* tf_session = TF_NewSession(graph, opts, status);
50   if (tf_session == nullptr) {
51     return nullptr;
52   }
53 
54   Session* session = reinterpret_cast<Session*>(tf_session->session);
55   SessionRef* session_ref = new SessionRef(session);
56   tf_session->session = session_ref;
57   return tf_session;
58 }
59 
TF_Run_wrapper_helper(TF_DeprecatedSession * session,const char * handle,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)60 void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
61                            const TF_Buffer* run_options, PyObject* feed_dict,
62                            const NameVector& output_names,
63                            const NameVector& target_nodes,
64                            TF_Status* out_status, PyObjectVector* out_values,
65                            TF_Buffer* run_outputs) {
66   // 1. Convert the feed inputs to the appropriate form for TF_Run.
67   if (!PyDict_Check(feed_dict)) {
68     Set_TF_Status_from_Status(out_status,
69                               errors::InvalidArgument(kFeedDictErrorMsg));
70     return;
71   }
72 
73   NameVector input_names;
74   std::vector<Safe_TF_TensorPtr> inputs_safe;  // Used to delete tensors.
75   TF_TensorVector inputs_unsafe;  // Used to contain the arg to TF_Run.
76 
77   PyObject* key;
78   PyObject* value;
79   Py_ssize_t pos = 0;
80   int index = 0;
81   Status s;
82 
83   while (PyDict_Next(feed_dict, &pos, &key, &value)) {
84     char* key_string = PyBytes_AsString(key);
85     if (!key_string) {
86       Set_TF_Status_from_Status(out_status,
87                                 errors::InvalidArgument(kFeedDictErrorMsg));
88       return;
89     }
90     input_names.push_back(key_string);
91 
92     inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
93     s = NdarrayToTensor(nullptr /*ctx*/, value, &inputs_safe.back());
94     if (!s.ok()) {
95       Set_TF_Status_from_Status(out_status, s);
96       return;
97     }
98     inputs_unsafe.push_back(inputs_safe.back().get());
99     ++index;
100   }
101 
102   // 2. Allocate a container for the output data.
103   TF_TensorVector outputs(output_names.size());
104 
105   // In case any tensors were leftover from previous runs we might as well clear
106   // them here.
107   ClearDecrefCache();
108 
109   // 3. Actually call TF_Run().
110   Py_BEGIN_ALLOW_THREADS;
111   if (handle == nullptr) {
112     TF_Run(session, run_options, input_names.data(), inputs_unsafe.data(),
113            input_names.size(), const_cast<const char**>(output_names.data()),
114            outputs.data(), output_names.size(),
115            const_cast<const char**>(target_nodes.data()), target_nodes.size(),
116            run_outputs, out_status);
117   } else {
118     TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(),
119             input_names.size(), const_cast<const char**>(output_names.data()),
120             outputs.data(), output_names.size(),
121             const_cast<const char**>(target_nodes.data()), target_nodes.size(),
122             out_status);
123   }
124 
125   Py_END_ALLOW_THREADS;
126 
127   // Decref any numpy arrays we are not using anymore.
128   ClearDecrefCache();
129 
130   if (TF_GetCode(out_status) != TF_OK) {
131     return;
132   }
133 
134   // 4. We now own the fetched tensors, so set up a safe container to
135   // delete them when we exit this scope.
136   std::vector<Safe_TF_TensorPtr> tf_outputs_safe;
137   for (const auto& output : outputs) {
138     tf_outputs_safe.emplace_back(make_safe(output));
139   }
140 
141   // 5. Convert the fetched tensors into numpy ndarrays. Store them in a safe
142   // container so that we do not leak
143   std::vector<Safe_PyObjectPtr> py_outputs_safe;
144   for (size_t i = 0; i < output_names.size(); ++i) {
145     PyObject* py_array;
146     s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
147     if (!s.ok()) {
148       Set_TF_Status_from_Status(out_status, s);
149       return;
150     }
151     py_outputs_safe.emplace_back(
152         make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
153   }
154 
155   // 6. If we reach this point, we have successfully built a list of objects
156   // so we can release them from the safe container.
157   for (auto& output : py_outputs_safe) {
158     out_values->push_back(output.release());
159   }
160 }
161 
162 // Wrapper for TF_Run that converts the arguments to appropriate types.
163 // If *out_status is OK, the caller becomes the owner of the PyObjects
164 // in *out_values.
TF_Run_wrapper(TF_DeprecatedSession * session,const TF_Buffer * run_options,PyObject * feed_dict,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_outputs)165 void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
166                     PyObject* feed_dict, const NameVector& output_names,
167                     const NameVector& target_nodes, TF_Status* out_status,
168                     PyObjectVector* out_values, TF_Buffer* run_outputs) {
169   TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names,
170                         target_nodes, out_status, out_values, run_outputs);
171   ClearDecrefCache();
172 }
173 
174 namespace {
MakeCallableHelper(tensorflow::Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * out_status)175 void MakeCallableHelper(tensorflow::Session* session,
176                         const TF_Buffer* callable_options, int64_t* out_handle,
177                         TF_Status* out_status) {
178   tensorflow::CallableOptions callable_options_proto;
179   if (callable_options != nullptr &&
180       !callable_options_proto.ParseFromArray(callable_options->data,
181                                              callable_options->length)) {
182     Set_TF_Status_from_Status(
183         out_status,
184         errors::InvalidArgument("Unparseable CallableOptions proto"));
185     return;
186   }
187   tensorflow::Session::CallableHandle handle;
188   Status s = session->MakeCallable(callable_options_proto, &handle);
189   if (!s.ok()) {
190     Set_TF_Status_from_Status(out_status, s);
191     return;
192   }
193   *out_handle = handle;
194 }
195 }  // namespace
196 
TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * status)197 void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
198                                       const TF_Buffer* callable_options,
199                                       int64_t* out_handle, TF_Status* status) {
200   MakeCallableHelper(session->session, callable_options, out_handle, status);
201 }
TF_SessionMakeCallable(TF_Session * session,const TF_Buffer * callable_options,int64_t * out_handle,TF_Status * status)202 void TF_SessionMakeCallable(TF_Session* session,
203                             const TF_Buffer* callable_options,
204                             int64_t* out_handle, TF_Status* status) {
205   MakeCallableHelper(session->session, callable_options, out_handle, status);
206 }
207 
208 namespace {
RunCallableHelper(tensorflow::Session * session,int64_t handle,PyObject * feed_values,TF_Status * out_status,PyObjectVector * out_values,TF_Buffer * run_metadata)209 void RunCallableHelper(tensorflow::Session* session, int64_t handle,
210                        PyObject* feed_values, TF_Status* out_status,
211                        PyObjectVector* out_values, TF_Buffer* run_metadata) {
212   // Convert feed values to a vector of tensorflow::Tensor objects.
213   std::vector<Tensor> input_tensors;
214   Status s;
215   {
216     feed_values =
217         PySequence_Fast(feed_values, "feed_values must be a sequence");
218     if (feed_values == nullptr) return;
219     Safe_PyObjectPtr feed_values_holder(make_safe(feed_values));
220     Py_ssize_t len = PySequence_Fast_GET_SIZE(feed_values);
221     input_tensors.reserve(len);
222     for (Py_ssize_t i = 0; i < len; ++i) {
223       PyObject* elem = PySequence_Fast_GET_ITEM(feed_values, i);
224       if (!elem) {
225         Set_TF_Status_from_Status(
226             out_status, errors::Internal("Could not get feed value ", i));
227         return;
228       }
229       Tensor t;
230       s = NdarrayToTensor(elem, &t);
231       if (!s.ok()) {
232         Set_TF_Status_from_Status(out_status, s);
233         return;
234       }
235       input_tensors.push_back(std::move(t));
236     }
237   }
238 
239   RunMetadata run_metadata_proto;
240 
241   // Run the callable.
242   std::vector<Tensor> output_tensors;
243   Py_BEGIN_ALLOW_THREADS;
244   s = session->RunCallable(handle, input_tensors, &output_tensors,
245                            &run_metadata_proto);
246   Py_END_ALLOW_THREADS;
247 
248   if (!s.ok()) {
249     Set_TF_Status_from_Status(out_status, s);
250     return;
251   }
252 
253   // If requested, serialize the RunMetadata to pass it back to the caller.
254   if (run_metadata != nullptr) {
255     s = MessageToBuffer(run_metadata_proto, run_metadata);
256     if (!s.ok()) {
257       Set_TF_Status_from_Status(out_status, s);
258       return;
259     }
260   }
261 
262   // Convert results to NumPy arrays. Since this can fail, stage the
263   // results via a safe container that takes care of decreasing the
264   // reference count on failure.
265   std::vector<Safe_PyObjectPtr> py_outputs_safe;
266   py_outputs_safe.reserve(output_tensors.size());
267   for (const Tensor& output : output_tensors) {
268     PyObject* py_array;
269     s = TensorToNdarray(output, &py_array);
270     if (!s.ok()) {
271       Set_TF_Status_from_Status(out_status, s);
272       return;
273     }
274     py_outputs_safe.push_back(
275         make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
276   }
277 
278   // If we reach this point, we have successfully built a list of objects
279   // so we can release them from the safe container.
280   out_values->reserve(py_outputs_safe.size());
281   for (auto& output : py_outputs_safe) {
282     out_values->push_back(output.release());
283   }
284 }
285 }  // namespace
286 
TF_DeprecatedSessionRunCallable(TF_DeprecatedSession * session,int64_t handle,PyObject * feed_values,PyObjectVector * out_values,TF_Buffer * run_metadata,TF_Status * status)287 void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
288                                      int64_t handle, PyObject* feed_values,
289                                      PyObjectVector* out_values,
290                                      TF_Buffer* run_metadata,
291                                      TF_Status* status) {
292   RunCallableHelper(session->session, handle, feed_values, status, out_values,
293                     run_metadata);
294   ClearDecrefCache();
295 }
TF_SessionRunCallable(TF_Session * session,int64_t handle,PyObject * feed_values,PyObjectVector * out_values,TF_Buffer * run_metadata,TF_Status * status)296 void TF_SessionRunCallable(TF_Session* session, int64_t handle,
297                            PyObject* feed_values, PyObjectVector* out_values,
298                            TF_Buffer* run_metadata, TF_Status* status) {
299   RunCallableHelper(session->session, handle, feed_values, status, out_values,
300                     run_metadata);
301   ClearDecrefCache();
302 }
303 
TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession * session,int64_t handle,TF_Status * status)304 void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
305                                          int64_t handle, TF_Status* status) {
306   Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
307 }
TF_SessionReleaseCallable(TF_Session * session,int64_t handle,TF_Status * status)308 void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
309                                TF_Status* status) {
310   Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
311 }
312 
313 // Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
314 // If *out_status is OK, the caller becomes the owner of *out_handle.
TF_PRunSetup_wrapper(TF_DeprecatedSession * session,const NameVector & input_names,const NameVector & output_names,const NameVector & target_nodes,TF_Status * out_status,const char ** out_handle)315 void TF_PRunSetup_wrapper(TF_DeprecatedSession* session,
316                           const NameVector& input_names,
317                           const NameVector& output_names,
318                           const NameVector& target_nodes, TF_Status* out_status,
319                           const char** out_handle) {
320   Py_BEGIN_ALLOW_THREADS;
321   TF_PRunSetup(
322       session, const_cast<const char**>(input_names.data()), input_names.size(),
323       const_cast<const char**>(output_names.data()), output_names.size(),
324       const_cast<const char**>(target_nodes.data()), target_nodes.size(),
325       out_handle, out_status);
326   Py_END_ALLOW_THREADS;
327 }
328 
329 // Wrapper for TF_PRun that converts the arguments to appropriate types.
330 // If *out_status is OK, the caller becomes the owner of the PyObjects
331 // in *out_values.
TF_PRun_wrapper(TF_DeprecatedSession * session,const char * handle,PyObject * feed_dict,const NameVector & output_names,TF_Status * out_status,PyObjectVector * out_values)332 void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
333                      PyObject* feed_dict, const NameVector& output_names,
334                      TF_Status* out_status, PyObjectVector* out_values) {
335   TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names,
336                         NameVector(), out_status, out_values, nullptr);
337   ClearDecrefCache();
338 }
339 
340 // Wrapper for TF_Reset that converts the string vectors to character arrays.
TF_Reset_wrapper(const TF_SessionOptions * opt,const NameVector & containers,TF_Status * status)341 void TF_Reset_wrapper(const TF_SessionOptions* opt,
342                       const NameVector& containers, TF_Status* status) {
343   TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
344            status);
345 }
346 
TF_SessionRun_wrapper_helper(TF_Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)347 void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
348                                   const TF_Buffer* run_options,
349                                   const std::vector<TF_Output>& inputs,
350                                   const std::vector<PyObject*>& input_ndarrays,
351                                   const std::vector<TF_Output>& outputs,
352                                   const std::vector<TF_Operation*>& targets,
353                                   TF_Buffer* run_metadata,
354                                   TF_Status* out_status,
355                                   std::vector<PyObject*>* py_outputs) {
356   DCHECK_EQ(inputs.size(), input_ndarrays.size());
357   DCHECK(py_outputs != nullptr);
358   DCHECK(py_outputs->empty());
359   Status s;
360 
361   // Convert input ndarray PyObjects to TF_Tensors. We maintain a continuous
362   // array of TF_Tensor*s as well as scoped containers to make sure they're
363   // cleaned up properly.
364   //
365   // Memory management:
366   // NdarrayToTensor() creates a new ndarray PyObject from the input
367   // ndarray. We manage the new ndarray's lifetime in order to keep the
368   // underlying data buffer alive (the new ndarray also guarantees a contiguous
369   // data buffer). The new ndarray's data buffer is used to create the
370   // corresponding TF_Tensor. The TF_Tensor's deallocator will queue the new
371   // ndarray to be decref'd by the next ClearDecrefCache() call (we can't call
372   // Py_DECREF in the deallocator directly because the GIL must be held).
373   //
374   // Note that TF_Tensor may directly delegate its data and deallocator to a
375   // TensorBuffer, which may outlive the TF_Tensor (e.g. if the tensor gets
376   // queued or assigned to a variable).
377   TF_TensorVector input_vals;
378   std::vector<Safe_TF_TensorPtr> input_vals_safe;
379   for (PyObject* ndarray : input_ndarrays) {
380     input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
381     s = NdarrayToTensor(nullptr, ndarray, &input_vals_safe.back());
382     if (!s.ok()) {
383       Set_TF_Status_from_Status(out_status, s);
384       return;
385     }
386     input_vals.push_back(input_vals_safe.back().get());
387   }
388 
389   // Allocate space for output TF_Tensor*s
390   TF_TensorVector output_vals(outputs.size());
391 
392   // Clear up any unused memory leftover from previous runs
393   ClearDecrefCache();
394 
395   // Call TF_SessionRun() (and release GIL during execution)
396   Py_BEGIN_ALLOW_THREADS;
397   if (handle == nullptr) {
398     TF_SessionRun(session, run_options, inputs.data(), input_vals.data(),
399                   inputs.size(), outputs.data(), output_vals.data(),
400                   outputs.size(), targets.data(), targets.size(), run_metadata,
401                   out_status);
402   } else {
403     TF_SessionPRun(session, handle, inputs.data(), input_vals.data(),
404                    inputs.size(), outputs.data(), output_vals.data(),
405                    outputs.size(), targets.data(), targets.size(), out_status);
406   }
407   Py_END_ALLOW_THREADS;
408 
409   // Create scoped containers for output tensors
410   std::vector<Safe_TF_TensorPtr> output_vals_safe;
411   for (TF_Tensor* output : output_vals) {
412     output_vals_safe.emplace_back(make_safe(output));
413   }
414 
415   // Convert outputs to ndarrays (in scoped containers)
416   std::vector<Safe_PyObjectPtr> py_outputs_safe;
417   for (size_t i = 0; i < outputs.size(); ++i) {
418     PyObject* py_array;
419     s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array);
420     if (!s.ok()) {
421       Set_TF_Status_from_Status(out_status, s);
422       return;
423     }
424     py_outputs_safe.emplace_back(
425         make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
426   }
427 
428   // If we reach this point, we have successfully built a list of objects so we
429   // can release them from the safe container into the return vector.
430   for (size_t i = 0; i < outputs.size(); ++i) {
431     py_outputs->push_back(py_outputs_safe[i].release());
432   }
433 }
434 
TF_SessionRun_wrapper(TF_Session * session,const TF_Buffer * run_options,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,TF_Buffer * run_metadata,TF_Status * out_status,std::vector<PyObject * > * py_outputs)435 void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
436                            const std::vector<TF_Output>& inputs,
437                            const std::vector<PyObject*>& input_ndarrays,
438                            const std::vector<TF_Output>& outputs,
439                            const std::vector<TF_Operation*>& targets,
440                            TF_Buffer* run_metadata, TF_Status* out_status,
441                            std::vector<PyObject*>* py_outputs) {
442   TF_SessionRun_wrapper_helper(session, nullptr, run_options, inputs,
443                                input_ndarrays, outputs, targets, run_metadata,
444                                out_status, py_outputs);
445   // Release any unused ndarray references (see memory management comment in
446   // TF_SessionRun_wrapper_helper)
447   ClearDecrefCache();
448 }
449 
// Compares two serialized GraphDefs.
// Returns:
//   * "" when they are equal;
//   * the diff produced by EqualGraphDef when they differ;
//   * a fixed message naming the input that failed to parse.
string EqualGraphDefWrapper(const string& actual, const string& expected) {
  GraphDef actual_def;
  GraphDef expected_def;
  if (!actual_def.ParseFromString(actual)) {
    return "actual is not a valid serialized GraphDef";
  }
  if (!expected_def.ParseFromString(expected)) {
    return "expected is not a valid serialized GraphDef";
  }

  string diff;
  if (EqualGraphDef(actual_def, expected_def, &diff)) {
    return "";
  }
  return diff;
}
462 
// Compares two serialized AttrValue protos with AreAttrValuesEqual.
// Returns:
//   * "" when they are equal;
//   * a message containing both values' summaries when they differ;
//   * a fixed message naming the input that failed to parse.
string EqualAttrValueWrapper(const string& actual, const string& expected) {
  AttrValue actual_attr_value;
  if (!actual_attr_value.ParseFromString(actual)) {
    return "actual is not a valid serialized AttrValue";
  }

  AttrValue expected_attr_value;
  if (!expected_attr_value.ParseFromString(expected)) {
    return "expected is not a valid serialized AttrValue";
  }

  if (AreAttrValuesEqual(actual_attr_value, expected_attr_value)) {
    return "";
  }

  const string actual_summary = SummarizeAttrValue(actual_attr_value);
  const string expected_summary = SummarizeAttrValue(expected_attr_value);
  return strings::Printf(
      "Actual AttrValue %s does not match Expected AttrValue %s.",
      actual_summary.c_str(), expected_summary.c_str());
}
483 
484 // Return value set to 6 inlined elements so it fits in a 64-byte cache line.
TF_GraphGetTensorShapeHelper(TF_Graph * graph,TF_Output output,TF_Status * out_status,bool * unknown_shape)485 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
486     TF_Graph* graph, TF_Output output, TF_Status* out_status,
487     bool* unknown_shape) {
488   // Allocate a single variable for holding the result for RVO.
489   tensorflow::gtl::InlinedVector<int64_t, 6> result;
490   *unknown_shape = false;
491   int num_dims = TF_GraphGetTensorNumDims(graph, output, out_status);
492   if (TF_GetCode(out_status) != TF_OK) {
493     return result;
494   }
495   // If shape is unknown, set boolean and return.
496   if (num_dims == -1) {
497     *unknown_shape = true;
498     return result;
499   }
500 
501   // If shape is a scalar, avoid another C call and just return {}.
502   if (num_dims == 0) {
503     return result;
504   }
505 
506   result.resize(num_dims);
507   TF_GraphGetTensorShape(graph, output, result.data(), num_dims, out_status);
508   return result;
509 }
510 
TF_SessionPRunSetup_wrapper(TF_Session * session,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const std::vector<TF_Operation * > & targets,const char ** out_handle,TF_Status * out_status)511 void TF_SessionPRunSetup_wrapper(TF_Session* session,
512                                  const std::vector<TF_Output>& inputs,
513                                  const std::vector<TF_Output>& outputs,
514                                  const std::vector<TF_Operation*>& targets,
515                                  const char** out_handle,
516                                  TF_Status* out_status) {
517   // Call TF_SessionPRunSetup() (and release GIL during execution)
518   Py_BEGIN_ALLOW_THREADS;
519   TF_SessionPRunSetup(session, inputs.data(), inputs.size(), outputs.data(),
520                       outputs.size(), targets.data(), targets.size(),
521                       out_handle, out_status);
522   Py_END_ALLOW_THREADS;
523 }
524 
TF_SessionPRun_wrapper(TF_Session * session,const char * handle,const std::vector<TF_Output> & inputs,const std::vector<PyObject * > & input_ndarrays,const std::vector<TF_Output> & outputs,TF_Status * out_status,std::vector<PyObject * > * py_outputs)525 void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
526                             const std::vector<TF_Output>& inputs,
527                             const std::vector<PyObject*>& input_ndarrays,
528                             const std::vector<TF_Output>& outputs,
529                             TF_Status* out_status,
530                             std::vector<PyObject*>* py_outputs) {
531   const std::vector<TF_Operation*> targets;
532   TF_SessionRun_wrapper_helper(session, handle,
533                                nullptr,  // run_options
534                                inputs, input_ndarrays, outputs, targets,
535                                nullptr,  // run_metadata
536                                out_status, py_outputs);
537   // Release any unused ndarray references (see memory management comment in
538   // TF_SessionRun_wrapper_helper)
539   ClearDecrefCache();
540 }
541 
GetOperationInputs(TF_Operation * oper)542 std::vector<TF_Output> GetOperationInputs(TF_Operation* oper) {
543   int num_inputs = TF_OperationNumInputs(oper);
544   std::vector<TF_Output> inputs(num_inputs);
545   TF_OperationAllInputs(oper, inputs.data(), inputs.size());
546   return inputs;
547 }
548 
TF_OperationGetControlInputs_wrapper(TF_Operation * oper)549 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
550     TF_Operation* oper) {
551   std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper));
552   TF_OperationGetControlInputs(oper, control_inputs.data(),
553                                control_inputs.size());
554   return control_inputs;
555 }
556 
TF_OperationGetControlOutputs_wrapper(TF_Operation * oper)557 std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
558     TF_Operation* oper) {
559   std::vector<TF_Operation*> control_outputs(
560       TF_OperationNumControlOutputs(oper));
561   TF_OperationGetControlOutputs(oper, control_outputs.data(),
562                                 control_outputs.size());
563   return control_outputs;
564 }
565 
TF_OperationOutputConsumers_wrapper(TF_Output oper_out)566 std::vector<const char*> TF_OperationOutputConsumers_wrapper(
567     TF_Output oper_out) {
568   int num_consumers = TF_OperationOutputNumConsumers(oper_out);
569   std::vector<TF_Input> consumers(num_consumers);
570   TF_OperationOutputConsumers(oper_out, consumers.data(), num_consumers);
571 
572   std::vector<const char*> consumer_names(num_consumers);
573   for (int i = 0; i < num_consumers; ++i) {
574     consumer_names[i] = TF_OperationName(consumers[i].oper);
575   }
576   return consumer_names;
577 }
578 
TF_GraphToFunction_wrapper(const TF_Graph * fn_body,const char * fn_name,bool append_hash_to_fn_name,const std::vector<TF_Operation * > * opers,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const NameVector & output_names,const std::vector<TF_Operation * > * control_outputs,const NameVector & control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * out_status)579 TF_Function* TF_GraphToFunction_wrapper(
580     const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
581     const std::vector<TF_Operation*>* opers,
582     const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
583     const NameVector& output_names,
584     const std::vector<TF_Operation*>* control_outputs,
585     const NameVector& control_output_names, const TF_FunctionOptions* opts,
586     const char* description, TF_Status* out_status) {
587   if (!output_names.empty() && output_names.size() != outputs.size()) {
588     Set_TF_Status_from_Status(
589         out_status,
590         errors::InvalidArgument(
591             "output names must be either empty or equal in size to outputs. ",
592             "output names size = ", output_names.size(),
593             " outputs size = ", outputs.size()));
594     return nullptr;
595   }
596 
597   int nopers = -1;
598   const TF_Operation* const* opers_array = nullptr;
599   if (opers != nullptr) {
600     nopers = opers->size();
601     opers_array = opers->data();
602   }
603 
604   const char** output_names_ptr =
605       output_names.empty() ? nullptr
606                            : const_cast<const char**>(output_names.data());
607 
608   const char** control_output_names_ptr =
609       control_output_names.empty()
610           ? nullptr
611           : const_cast<const char**>(control_output_names.data());
612 
613   return TF_GraphToFunctionWithControlOutputs(
614       fn_body, fn_name, append_hash_to_fn_name, nopers, opers_array,
615       inputs.size(), inputs.data(), outputs.size(), outputs.data(),
616       output_names_ptr,
617       control_outputs == nullptr ? 0 : control_outputs->size(),
618       control_outputs == nullptr ? nullptr : control_outputs->data(),
619       control_output_names_ptr, opts, description, out_status);
620 }
621 
TF_GraphSetOutputHandleShapesAndTypes_wrapper(TF_Graph * graph,TF_Output output,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int> & ranks,const std::vector<TF_DataType> & types,TF_Status * status)622 void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
623     TF_Graph* graph, TF_Output output,
624     const std::vector<std::vector<int64_t>>& shapes,
625     const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
626     TF_Status* status) {
627   std::vector<const int64_t*> shapes_pointers(shapes.size());
628   for (int i = 0; i < shapes.size(); ++i) {
629     shapes_pointers[i] = ranks[i] <= 0 ? nullptr : &shapes[i][0];
630   }
631   TF_GraphSetOutputHandleShapesAndTypes(graph, output, shapes.size(),
632                                         shapes_pointers.data(), ranks.data(),
633                                         types.data(), status);
634 }
635 
CreatePlaceholder(TF_Graph * graph,TF_Status * s,string && name,TF_DataType dtype,TF_Output * output)636 void CreatePlaceholder(TF_Graph* graph, TF_Status* s, string&& name,
637                        TF_DataType dtype, TF_Output* output) {
638   TF_OperationDescription* desc =
639       TF_NewOperation(graph, "Placeholder", name.data());
640   TF_SetAttrType(desc, "dtype", dtype);
641   TF_Operation* op = TF_FinishOperation(desc, s);
642   output->oper = op;
643   output->index = 0;
644 }
645 
TF_CreatePlaceholders(TF_Graph * graph,PyObject * dtypes,const char * prefix,TF_Status * status)646 std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
647                                              const char* prefix,
648                                              TF_Status* status) {
649   std::vector<TF_Output> outputs;
650   dtypes = PySequence_Fast(dtypes, "dtypes must be a sequence");
651   if (dtypes == nullptr) {
652     Set_TF_Status_from_Status(status, errors::Internal("dtypes is nullptr"));
653     return outputs;
654   }
655   Safe_PyObjectPtr dtypes_holder(make_safe(dtypes));
656   Py_ssize_t len = PySequence_Fast_GET_SIZE(dtypes);
657   outputs.reserve(len);
658   for (size_t i = 0; i < len; i++) {
659     PyObject* dtype = PySequence_Fast_GET_ITEM(dtypes, i);
660     if (!dtype) {
661       Set_TF_Status_from_Status(status,
662                                 errors::Internal("Could not get dtype ", i));
663       return outputs;
664     }
665 #if PY_MAJOR_VERSION >= 3
666     TF_DataType tf_datatype = static_cast<TF_DataType>(PyLong_AsLong(dtype));
667 #else
668     TF_DataType tf_datatype = static_cast<TF_DataType>(PyInt_AsLong(dtype));
669 #endif
670     outputs.push_back(TF_Output());
671     CreatePlaceholder(graph, status, strings::StrCat(prefix, i), tf_datatype,
672                       &outputs.back());
673     if (!status->status.ok()) break;
674   }
675   return outputs;
676 }
677 
TF_GraphSetTensorShape_wrapper(TF_Graph * graph,TF_Output output,const std::vector<int64_t> & dims,bool unknown_shape,TF_Status * status)678 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
679                                     const std::vector<int64_t>& dims,
680                                     bool unknown_shape, TF_Status* status) {
681   if (unknown_shape) {
682     TF_GraphSetTensorShape(graph, output, nullptr, -1, status);
683     return;
684   }
685   TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
686 }
687 
TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(TF_ImportGraphDefResults * results)688 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
689     TF_ImportGraphDefResults* results) {
690   int num_missing_unused_input_mappings;
691   const char** src_names;
692   int* src_indexes;
693   TF_ImportGraphDefResultsMissingUnusedInputMappings(
694       results, &num_missing_unused_input_mappings, &src_names, &src_indexes);
695   std::vector<string> input_strs(num_missing_unused_input_mappings);
696   for (int i = 0; i < num_missing_unused_input_mappings; ++i) {
697     input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString();
698   }
699   return input_strs;
700 }
701 
TF_TryEvaluateConstant_wrapper(TF_Graph * graph,TF_Output output,TF_Status * status)702 PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
703                                          TF_Status* status) {
704   TF_Tensor* result_tensor;
705   bool evaluated =
706       TF_TryEvaluateConstant(graph, output, &result_tensor, status);
707   if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE;
708 
709   Safe_TF_TensorPtr safe_result_tensor(result_tensor);
710   PyObject* out;
711   Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out);
712   Set_TF_Status_from_Status(status, s);
713   if (!s.ok()) Py_RETURN_NONE;
714   return PyArray_Return(reinterpret_cast<PyArrayObject*>(out));
715 }
716 
717 }  // namespace tensorflow
718