1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_test_util.h"
17
18 #include "tensorflow/c/c_api_experimental.h"
19 #include "tensorflow/core/framework/function.pb.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/platform/logging.h"
23 #include "tensorflow/core/platform/strcat.h"
24 #include "tensorflow/core/public/session_options.h"
25
26 using tensorflow::GraphDef;
27 using tensorflow::NodeDef;
28
// TF_NewTensor deallocator for a buffer obtained with `new bool[]`.
static void BoolDeallocator(void* data, size_t, void* /*arg*/) {
  bool* const owned = static_cast<bool*>(data);
  delete[] owned;
}
32
// TF_NewTensor deallocator for a buffer obtained with `new int32_t[]`.
static void Int32Deallocator(void* data, size_t, void* /*arg*/) {
  int32_t* const owned = static_cast<int32_t*>(data);
  delete[] owned;
}
36
// TF_NewTensor deallocator for a buffer obtained with `new double[]`.
static void DoubleDeallocator(void* data, size_t, void* /*arg*/) {
  double* const owned = static_cast<double*>(data);
  delete[] owned;
}
40
// TF_NewTensor deallocator for a buffer obtained with `new float[]`.
static void FloatDeallocator(void* data, size_t, void* /*arg*/) {
  float* const owned = static_cast<float*>(data);
  delete[] owned;
}
44
BoolTensor(bool v)45 TF_Tensor* BoolTensor(bool v) {
46 const int num_bytes = sizeof(bool);
47 bool* values = new bool[1];
48 values[0] = v;
49 return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator,
50 nullptr);
51 }
52
Int8Tensor(const int64_t * dims,int num_dims,const char * values)53 TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
54 int64_t num_values = 1;
55 for (int i = 0; i < num_dims; ++i) {
56 num_values *= dims[i];
57 }
58 TF_Tensor* t =
59 TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values);
60 memcpy(TF_TensorData(t), values, sizeof(char) * num_values);
61 return t;
62 }
63
Int32Tensor(const int64_t * dims,int num_dims,const int32_t * values)64 TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
65 const int32_t* values) {
66 int64_t num_values = 1;
67 for (int i = 0; i < num_dims; ++i) {
68 num_values *= dims[i];
69 }
70 TF_Tensor* t =
71 TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values);
72 memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values);
73 return t;
74 }
75
Int32Tensor(const std::vector<int32_t> & values)76 TF_Tensor* Int32Tensor(const std::vector<int32_t>& values) {
77 int64_t dims = values.size();
78 return Int32Tensor(&dims, 1, values.data());
79 }
80
Int32Tensor(int32_t v)81 TF_Tensor* Int32Tensor(int32_t v) {
82 const int num_bytes = sizeof(int32_t);
83 int32_t* values = new int32_t[1];
84 values[0] = v;
85 return TF_NewTensor(TF_INT32, nullptr, 0, values, num_bytes,
86 &Int32Deallocator, nullptr);
87 }
88
DoubleTensor(double v)89 TF_Tensor* DoubleTensor(double v) {
90 const int num_bytes = sizeof(double);
91 double* values = new double[1];
92 values[0] = v;
93 return TF_NewTensor(TF_DOUBLE, nullptr, 0, values, num_bytes,
94 &DoubleDeallocator, nullptr);
95 }
96
FloatTensor(float v)97 TF_Tensor* FloatTensor(float v) {
98 const int num_bytes = sizeof(float);
99 float* values = new float[1];
100 values[0] = v;
101 return TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes,
102 &FloatDeallocator, nullptr);
103 }
104
// All the *Helper functions are a workaround for the restriction that ASSERT_*
// macros cannot be used in non-void-returning functions (when exceptions are
// disabled during compilation).
PlaceholderHelper(TF_Graph * graph,TF_Status * s,const char * name,TF_DataType dtype,const std::vector<int64_t> & dims,TF_Operation ** op)108 void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
109 TF_DataType dtype, const std::vector<int64_t>& dims,
110 TF_Operation** op) {
111 TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
112 TF_SetAttrType(desc, "dtype", dtype);
113 if (!dims.empty()) {
114 TF_SetAttrShape(desc, "shape", dims.data(), dims.size());
115 }
116 *op = TF_FinishOperation(desc, s);
117 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
118 ASSERT_NE(*op, nullptr);
119 }
120
Placeholder(TF_Graph * graph,TF_Status * s,const char * name,TF_DataType dtype,const std::vector<int64_t> & dims)121 TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name,
122 TF_DataType dtype, const std::vector<int64_t>& dims) {
123 TF_Operation* op;
124 PlaceholderHelper(graph, s, name, dtype, dims, &op);
125 return op;
126 }
127
ConstHelper(TF_Tensor * t,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op)128 void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name,
129 TF_Operation** op) {
130 TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
131 TF_SetAttrTensor(desc, "value", t, s);
132 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
133 TF_SetAttrType(desc, "dtype", TF_TensorType(t));
134 *op = TF_FinishOperation(desc, s);
135 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
136 ASSERT_NE(*op, nullptr);
137 }
138
Const(TF_Tensor * t,TF_Graph * graph,TF_Status * s,const char * name)139 TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
140 const char* name) {
141 TF_Operation* op;
142 ConstHelper(t, graph, s, name, &op);
143 return op;
144 }
145
ScalarConst(bool v,TF_Graph * graph,TF_Status * s,const char * name)146 TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
147 const char* name) {
148 unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor);
149 return Const(tensor.get(), graph, s, name);
150 }
151
ScalarConst(int32_t v,TF_Graph * graph,TF_Status * s,const char * name)152 TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
153 const char* name) {
154 unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
155 return Const(tensor.get(), graph, s, name);
156 }
157
ScalarConst(double v,TF_Graph * graph,TF_Status * s,const char * name)158 TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s,
159 const char* name) {
160 unique_tensor_ptr tensor(DoubleTensor(v), TF_DeleteTensor);
161 return Const(tensor.get(), graph, s, name);
162 }
163
ScalarConst(float v,TF_Graph * graph,TF_Status * s,const char * name)164 TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s,
165 const char* name) {
166 unique_tensor_ptr tensor(FloatTensor(v), TF_DeleteTensor);
167 return Const(tensor.get(), graph, s, name);
168 }
169
AddOpHelper(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op,bool check)170 void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
171 TF_Status* s, const char* name, TF_Operation** op,
172 bool check) {
173 TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
174 TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
175 TF_AddInputList(desc, add_inputs, 2);
176 *op = TF_FinishOperation(desc, s);
177 if (check) {
178 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
179 ASSERT_NE(*op, nullptr);
180 }
181 }
182
Add(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)183 TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
184 TF_Status* s, const char* name) {
185 TF_Operation* op;
186 AddOpHelper(l, r, graph, s, name, &op, true);
187 return op;
188 }
189
AddNoCheck(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)190 TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
191 TF_Status* s, const char* name) {
192 TF_Operation* op;
193 AddOpHelper(l, r, graph, s, name, &op, false);
194 return op;
195 }
196
AddWithCtrlDependency(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Operation * ctrl_op,TF_Status * s,const char * name)197 TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
198 TF_Graph* graph, TF_Operation* ctrl_op,
199 TF_Status* s, const char* name) {
200 TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
201 TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
202 TF_AddInputList(desc, add_inputs, 2);
203 TF_AddControlInput(desc, ctrl_op);
204 return TF_FinishOperation(desc, s);
205 }
206
207 // If `op_device` is non-empty, set the created op on that device.
BinaryOpHelper(const char * op_name,TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op,const string & op_device,bool check)208 void BinaryOpHelper(const char* op_name, TF_Operation* l, TF_Operation* r,
209 TF_Graph* graph, TF_Status* s, const char* name,
210 TF_Operation** op, const string& op_device, bool check) {
211 TF_OperationDescription* desc = TF_NewOperation(graph, op_name, name);
212 if (!op_device.empty()) {
213 TF_SetDevice(desc, op_device.c_str());
214 }
215 TF_AddInput(desc, {l, 0});
216 TF_AddInput(desc, {r, 0});
217 *op = TF_FinishOperation(desc, s);
218 if (check) {
219 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
220 ASSERT_NE(*op, nullptr);
221 }
222 }
223
MinWithDevice(TF_Operation * l,TF_Operation * r,TF_Graph * graph,const string & op_device,TF_Status * s,const char * name)224 TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
225 const string& op_device, TF_Status* s,
226 const char* name) {
227 TF_Operation* op;
228 BinaryOpHelper("Min", l, r, graph, s, name, &op, op_device, true);
229 return op;
230 }
231
Min(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)232 TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
233 TF_Status* s, const char* name) {
234 return MinWithDevice(l, r, graph, /*op_device=*/"", s, name);
235 }
236
Mul(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)237 TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
238 TF_Status* s, const char* name) {
239 TF_Operation* op;
240 BinaryOpHelper("Mul", l, r, graph, s, name, &op, "", true);
241 return op;
242 }
243
Add(TF_Output l,TF_Output r,TF_Graph * graph,TF_Status * s,const char * name)244 TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
245 const char* name) {
246 TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
247 TF_Output inputs[2] = {l, r};
248 TF_AddInputList(desc, inputs, 2);
249 return TF_FinishOperation(desc, s);
250 }
251
NegHelper(TF_Operation * n,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op)252 void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s, const char* name,
253 TF_Operation** op) {
254 TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", name);
255 TF_Output neg_input = {n, 0};
256 TF_AddInput(desc, neg_input);
257 *op = TF_FinishOperation(desc, s);
258 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
259 ASSERT_NE(*op, nullptr);
260 }
261
Neg(TF_Operation * n,TF_Graph * graph,TF_Status * s,const char * name)262 TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s,
263 const char* name) {
264 TF_Operation* op;
265 NegHelper(n, graph, s, name, &op);
266 return op;
267 }
268
LessThan(TF_Output l,TF_Output r,TF_Graph * graph,TF_Status * s)269 TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
270 TF_Status* s) {
271 TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than");
272 TF_AddInput(desc, l);
273 TF_AddInput(desc, r);
274 return TF_FinishOperation(desc, s);
275 }
276
RandomUniform(TF_Operation * shape,TF_DataType dtype,TF_Graph * graph,TF_Status * s)277 TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype,
278 TF_Graph* graph, TF_Status* s) {
279 TF_OperationDescription* desc =
280 TF_NewOperation(graph, "RandomUniform", "random_uniform");
281 TF_AddInput(desc, {shape, 0});
282 TF_SetAttrType(desc, "dtype", dtype);
283 return TF_FinishOperation(desc, s);
284 }
285
Split3Helper(TF_Operation * input,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op)286 void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s,
287 const char* name, TF_Operation** op) {
288 TF_Operation* zero = ScalarConst(
289 0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str());
290 TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name);
291 TF_AddInput(desc, {zero, 0});
292 TF_AddInput(desc, {input, 0});
293 TF_SetAttrInt(desc, "num_split", 3);
294 TF_SetAttrType(desc, "T", TF_INT32);
295 // Set device to CPU since there is no version of split for int32 on GPU
296 // TODO(iga): Convert all these helpers and tests to use floats because
297 // they are usually available on GPUs. After doing this, remove TF_SetDevice
298 // call in c_api_function_test.cc
299 TF_SetDevice(desc, "/cpu:0");
300 *op = TF_FinishOperation(desc, s);
301 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
302 ASSERT_NE(*op, nullptr);
303 }
304
Split3(TF_Operation * input,TF_Graph * graph,TF_Status * s,const char * name)305 TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
306 const char* name) {
307 TF_Operation* op;
308 Split3Helper(input, graph, s, name, &op);
309 return op;
310 }
311
IsPlaceholder(const tensorflow::NodeDef & node_def)312 bool IsPlaceholder(const tensorflow::NodeDef& node_def) {
313 if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
314 return false;
315 }
316 bool found_dtype = false;
317 bool found_shape = false;
318 for (const auto& attr : node_def.attr()) {
319 if (attr.first == "dtype") {
320 if (attr.second.type() == tensorflow::DT_INT32) {
321 found_dtype = true;
322 } else {
323 return false;
324 }
325 } else if (attr.first == "shape") {
326 found_shape = true;
327 }
328 }
329 return found_dtype && found_shape;
330 }
331
IsScalarConst(const tensorflow::NodeDef & node_def,int v)332 bool IsScalarConst(const tensorflow::NodeDef& node_def, int v) {
333 if (node_def.op() != "Const" || node_def.name() != "scalar") {
334 return false;
335 }
336 bool found_dtype = false;
337 bool found_value = false;
338 for (const auto& attr : node_def.attr()) {
339 if (attr.first == "dtype") {
340 if (attr.second.type() == tensorflow::DT_INT32) {
341 found_dtype = true;
342 } else {
343 return false;
344 }
345 } else if (attr.first == "value") {
346 if (attr.second.has_tensor() &&
347 attr.second.tensor().int_val_size() == 1 &&
348 attr.second.tensor().int_val(0) == v) {
349 found_value = true;
350 } else {
351 return false;
352 }
353 }
354 }
355 return found_dtype && found_value;
356 }
357
IsAddN(const tensorflow::NodeDef & node_def,int n)358 bool IsAddN(const tensorflow::NodeDef& node_def, int n) {
359 if (node_def.op() != "AddN" || node_def.name() != "add" ||
360 node_def.input_size() != n) {
361 return false;
362 }
363 bool found_t = false;
364 bool found_n = false;
365 for (const auto& attr : node_def.attr()) {
366 if (attr.first == "T") {
367 if (attr.second.type() == tensorflow::DT_INT32) {
368 found_t = true;
369 } else {
370 return false;
371 }
372 } else if (attr.first == "N") {
373 if (attr.second.i() == n) {
374 found_n = true;
375 } else {
376 return false;
377 }
378 }
379 }
380 return found_t && found_n;
381 }
382
IsNeg(const tensorflow::NodeDef & node_def,const string & input)383 bool IsNeg(const tensorflow::NodeDef& node_def, const string& input) {
384 return node_def.op() == "Neg" && node_def.name() == "neg" &&
385 node_def.input_size() == 1 && node_def.input(0) == input;
386 }
387
GetGraphDef(TF_Graph * graph,tensorflow::GraphDef * graph_def)388 bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def) {
389 TF_Status* s = TF_NewStatus();
390 TF_Buffer* buffer = TF_NewBuffer();
391 TF_GraphToGraphDef(graph, buffer, s);
392 bool ret = TF_GetCode(s) == TF_OK;
393 EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
394 if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
395 TF_DeleteBuffer(buffer);
396 TF_DeleteStatus(s);
397 return ret;
398 }
399
GetNodeDef(TF_Operation * oper,tensorflow::NodeDef * node_def)400 bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) {
401 TF_Status* s = TF_NewStatus();
402 TF_Buffer* buffer = TF_NewBuffer();
403 TF_OperationToNodeDef(oper, buffer, s);
404 bool ret = TF_GetCode(s) == TF_OK;
405 EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
406 if (ret) ret = node_def->ParseFromArray(buffer->data, buffer->length);
407 TF_DeleteBuffer(buffer);
408 TF_DeleteStatus(s);
409 return ret;
410 }
411
GetFunctionDef(TF_Function * func,tensorflow::FunctionDef * func_def)412 bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) {
413 TF_Status* s = TF_NewStatus();
414 TF_Buffer* buffer = TF_NewBuffer();
415 TF_FunctionToFunctionDef(func, buffer, s);
416 bool ret = TF_GetCode(s) == TF_OK;
417 EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
418 if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length);
419 TF_DeleteBuffer(buffer);
420 TF_DeleteStatus(s);
421 return ret;
422 }
423
GetAttrValue(TF_Operation * oper,const char * attr_name,tensorflow::AttrValue * attr_value,TF_Status * s)424 bool GetAttrValue(TF_Operation* oper, const char* attr_name,
425 tensorflow::AttrValue* attr_value, TF_Status* s) {
426 TF_Buffer* buffer = TF_NewBuffer();
427 TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
428 bool ret = TF_GetCode(s) == TF_OK;
429 if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length);
430 TF_DeleteBuffer(buffer);
431 return ret;
432 }
433
GetGradDefs(const tensorflow::GraphDef & graph_def)434 std::vector<std::pair<string, string>> GetGradDefs(
435 const tensorflow::GraphDef& graph_def) {
436 std::vector<std::pair<string, string>> grads;
437 for (const tensorflow::GradientDef& grad : graph_def.library().gradient()) {
438 grads.emplace_back(grad.function_name(), grad.gradient_func());
439 }
440 std::sort(grads.begin(), grads.end());
441 return grads;
442 }
443
GetFuncNames(const tensorflow::GraphDef & graph_def)444 std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def) {
445 std::vector<string> names;
446 for (const tensorflow::FunctionDef& func : graph_def.library().function()) {
447 names.push_back(func.signature().name());
448 }
449 std::sort(names.begin(), names.end());
450 return names;
451 }
452
CSession(TF_Graph * graph,TF_Status * s,bool use_XLA)453 CSession::CSession(TF_Graph* graph, TF_Status* s, bool use_XLA) {
454 TF_SessionOptions* opts = TF_NewSessionOptions();
455 TF_EnableXLACompilation(opts, use_XLA);
456 session_ = TF_NewSession(graph, opts, s);
457 TF_DeleteSessionOptions(opts);
458 }
459
CSession(TF_Session * session)460 CSession::CSession(TF_Session* session) : session_(session) {}
461
~CSession()462 CSession::~CSession() {
463 TF_Status* s = TF_NewStatus();
464 CloseAndDelete(s);
465 EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
466 TF_DeleteStatus(s);
467 }
468
SetInputs(std::vector<std::pair<TF_Operation *,TF_Tensor * >> inputs)469 void CSession::SetInputs(
470 std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
471 DeleteInputValues();
472 inputs_.clear();
473 for (const auto& p : inputs) {
474 inputs_.emplace_back(TF_Output{p.first, 0});
475 input_values_.emplace_back(p.second);
476 }
477 }
478
SetOutputs(std::initializer_list<TF_Operation * > outputs)479 void CSession::SetOutputs(std::initializer_list<TF_Operation*> outputs) {
480 ResetOutputValues();
481 outputs_.clear();
482 for (TF_Operation* o : outputs) {
483 outputs_.emplace_back(TF_Output{o, 0});
484 }
485 output_values_.resize(outputs_.size());
486 }
487
SetOutputs(const std::vector<TF_Output> & outputs)488 void CSession::SetOutputs(const std::vector<TF_Output>& outputs) {
489 ResetOutputValues();
490 outputs_ = outputs;
491 output_values_.resize(outputs_.size());
492 }
493
SetTargets(std::initializer_list<TF_Operation * > targets)494 void CSession::SetTargets(std::initializer_list<TF_Operation*> targets) {
495 targets_.clear();
496 for (TF_Operation* t : targets) {
497 targets_.emplace_back(t);
498 }
499 }
500
Run(TF_Status * s)501 void CSession::Run(TF_Status* s) {
502 if (inputs_.size() != input_values_.size()) {
503 ADD_FAILURE() << "Call SetInputs() before Run()";
504 return;
505 }
506 ResetOutputValues();
507 output_values_.resize(outputs_.size(), nullptr);
508
509 const TF_Output* inputs_ptr = inputs_.empty() ? nullptr : &inputs_[0];
510 TF_Tensor* const* input_values_ptr =
511 input_values_.empty() ? nullptr : &input_values_[0];
512
513 const TF_Output* outputs_ptr = outputs_.empty() ? nullptr : &outputs_[0];
514 TF_Tensor** output_values_ptr =
515 output_values_.empty() ? nullptr : &output_values_[0];
516
517 TF_Operation* const* targets_ptr = targets_.empty() ? nullptr : &targets_[0];
518
519 TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr, inputs_.size(),
520 outputs_ptr, output_values_ptr, outputs_.size(), targets_ptr,
521 targets_.size(), nullptr, s);
522
523 DeleteInputValues();
524 }
525
CloseAndDelete(TF_Status * s)526 void CSession::CloseAndDelete(TF_Status* s) {
527 DeleteInputValues();
528 ResetOutputValues();
529 if (session_ != nullptr) {
530 TF_CloseSession(session_, s);
531 EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
532 TF_DeleteSession(session_, s);
533 session_ = nullptr;
534 }
535 }
536
DeleteInputValues()537 void CSession::DeleteInputValues() {
538 for (size_t i = 0; i < input_values_.size(); ++i) {
539 TF_DeleteTensor(input_values_[i]);
540 }
541 input_values_.clear();
542 }
543
ResetOutputValues()544 void CSession::ResetOutputValues() {
545 for (size_t i = 0; i < output_values_.size(); ++i) {
546 if (output_values_[i] != nullptr) TF_DeleteTensor(output_values_[i]);
547 }
548 output_values_.clear();
549 }
550