• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_C_C_TEST_UTIL_H_
17 #define TENSORFLOW_C_C_TEST_UTIL_H_
18 
#include "tensorflow/c/c_api.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/test.h"

using ::tensorflow::string;
30 
31 typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
32     unique_tensor_ptr;
33 
// Tensor factories.
//
// NOTE(review): ownership of the returned TF_Tensor* is not stated in this
// header; presumably the caller owns it and must release it with
// TF_DeleteTensor (e.g. by wrapping it in a unique_tensor_ptr) — confirm
// against the implementation.

// Presumably creates a scalar TF_BOOL tensor from `v`.
// NOTE(review): `v` is an int32_t rather than a bool — TODO confirm how
// nonzero values are mapped.
TF_Tensor* BoolTensor(int32_t v);

// Create a tensor with values of type TF_INT8 provided by `values`.
// `dims` presumably points to `num_dims` dimension sizes, and `values` must
// then hold one element per entry of that shape — confirm.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);

// Create a tensor with values of type TF_INT32 provided by `values`.
// `dims`/`num_dims` presumably describe the shape as for Int8Tensor.
TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
                       const int32_t* values);

// Create a 1-dimensional tensor with values from `values`.
TF_Tensor* Int32Tensor(const std::vector<int32_t>& values);

// Presumably creates a scalar TF_INT32 tensor holding `v`.
TF_Tensor* Int32Tensor(int32_t v);

// Presumably creates a scalar TF_DOUBLE tensor holding `v`.
TF_Tensor* DoubleTensor(double v);

// Presumably creates a scalar TF_FLOAT tensor holding `v`.
TF_Tensor* FloatTensor(float v);
51 
// Graph-building helpers.
//
// NOTE(review): judging from the signatures, each helper adds an operation to
// `graph` (named `name` where that parameter exists), returns it, and reports
// errors through `s`. This is inferred from naming — confirm against the
// implementation.

// Defaults: node name "feed", dtype TF_INT32 and an empty `dims` (presumably
// meaning the shape is left unspecified — confirm).
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
                          const char* name = "feed",
                          TF_DataType dtype = TF_INT32,
                          const std::vector<int64_t>& dims = {});

// Presumably adds a Const op whose value is `t`.
// NOTE(review): whether `t` is copied or its ownership is taken is not
// visible here.
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
                    const char* name = "const");

// Scalar-constant overloads; all default the node name to "scalar".
TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
                          const char* name = "scalar");

TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
                          const char* name = "scalar");

TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s,
                          const char* name = "scalar");

TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s,
                          const char* name = "scalar");

// NOTE(review): presumably Add checks that `s` is OK after building the op,
// while AddNoCheck leaves that check to the caller — confirm.
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                  TF_Status* s, const char* name = "add");

TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                         TF_Status* s, const char* name = "add");

// Like Add, with `ctrl_op` presumably attached as a control dependency.
TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
                                    TF_Graph* graph, TF_Operation* ctrl_op,
                                    TF_Status* s, const char* name = "add");

// Overload taking specific operation outputs instead of whole operations.
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
                  const char* name = "add");

TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                  TF_Status* s, const char* name = "min");

TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                  TF_Status* s, const char* name = "mul");

// If `op_device` is non-empty, set the created op on that device.
TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                            const string& op_device, TF_Status* s,
                            const char* name = "min");

TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s,
                  const char* name = "neg");

// Unlike most helpers here, LessThan and RandomUniform take no `name`
// parameter.
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s);

// `shape` is presumably an operation producing the output shape and `dtype`
// the element type — confirm.
TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype,
                            TF_Graph* graph, TF_Status* s);

// Split `input` along the first dimension into 3 tensors.
TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
                     const char* name = "split3");
107 
// NodeDef predicates.
// NOTE(review): presumably each checks the node's op type and, where an
// extra argument is given, that property too (`v`: the constant's value,
// `n`: the number of AddN inputs, `input`: the name of Neg's input) — confirm
// against the implementation.
bool IsPlaceholder(const tensorflow::NodeDef& node_def);

bool IsScalarConst(const tensorflow::NodeDef& node_def, int v);

bool IsAddN(const tensorflow::NodeDef& node_def, int n);

bool IsNeg(const tensorflow::NodeDef& node_def, const string& input);

// Serialization helpers that fill the given proto from a C API handle.
// The bool result presumably reports success — confirm.
bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);

bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def);

bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def);

// Reads attribute `attr_name` of `oper` into `attr_value`; error details are
// presumably also reported through `s`.
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
                  tensorflow::AttrValue* attr_value, TF_Status* s);

// Returns a sorted vector of std::pair<function_name, gradient_func> from
// graph_def.library().gradient()
std::vector<std::pair<string, string>> GetGradDefs(
    const tensorflow::GraphDef& graph_def);

// Returns a sorted vector of names contained in `graph_def`.
// NOTE(review): presumably these are the names of the functions in
// graph_def.library() — confirm against the implementation.
std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def);
132 
133 class CSession {
134  public:
135   CSession(TF_Graph* graph, TF_Status* s, bool use_XLA = false);
136   explicit CSession(TF_Session* session);
137 
138   ~CSession();
139 
140   void SetInputs(std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs);
141   void SetOutputs(std::initializer_list<TF_Operation*> outputs);
142   void SetOutputs(const std::vector<TF_Output>& outputs);
143   void SetTargets(std::initializer_list<TF_Operation*> targets);
144 
145   void Run(TF_Status* s);
146 
147   void CloseAndDelete(TF_Status* s);
148 
output_tensor(int i)149   TF_Tensor* output_tensor(int i) { return output_values_[i]; }
150 
mutable_session()151   TF_Session* mutable_session() { return session_; }
152 
153  private:
154   void DeleteInputValues();
155   void ResetOutputValues();
156 
157   TF_Session* session_;
158   std::vector<TF_Output> inputs_;
159   std::vector<TF_Tensor*> input_values_;
160   std::vector<TF_Output> outputs_;
161   std::vector<TF_Tensor*> output_values_;
162   std::vector<TF_Operation*> targets_;
163 };
164 
165 #endif  // TENSORFLOW_C_C_TEST_UTIL_H_
166