1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_C_C_TEST_UTIL_H_ 17 #define TENSORFLOW_C_C_TEST_UTIL_H_ 18 19 #include "tensorflow/c/c_api.h" 20 21 #include <vector> 22 #include "tensorflow/core/framework/attr_value.pb.h" 23 #include "tensorflow/core/framework/function.pb.h" 24 #include "tensorflow/core/framework/graph.pb.h" 25 #include "tensorflow/core/framework/node_def.pb.h" 26 #include "tensorflow/core/framework/types.pb.h" 27 #include "tensorflow/core/platform/test.h" 28 29 using ::tensorflow::string; 30 31 typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> 32 unique_tensor_ptr; 33 34 TF_Tensor* BoolTensor(int32_t v); 35 36 // Create a tensor with values of type TF_INT8 provided by `values`. 37 TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); 38 39 // Create a tensor with values of type TF_INT32 provided by `values`. 40 TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims, 41 const int32_t* values); 42 43 // Create 1 dimensional tensor with values from `values` 44 TF_Tensor* Int32Tensor(const std::vector<int32_t>& values); 45 46 TF_Tensor* Int32Tensor(int32_t v); 47 48 TF_Tensor* DoubleTensor(double v); 49 50 TF_Tensor* FloatTensor(float v); 51 52 TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, 53 const char* name = "feed", 54 TF_DataType dtype = TF_INT32, 55 const std::vector<int64_t>& dims = {}); 56 57 TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, 58 const char* name = "const"); 59 60 TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, 61 const char* name = "scalar"); 62 63 TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, 64 const char* name = "scalar"); 65 66 TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, 67 const char* name = "scalar"); 68 69 TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s, 70 const char* name = "scalar"); 71 72 TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 73 TF_Status* s, const char* name = "add"); 74 75 TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 76 TF_Status* s, const char* name = "add"); 77 78 TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, 79 TF_Graph* graph, TF_Operation* ctrl_op, 80 TF_Status* s, const char* name = "add"); 81 82 TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, 83 const char* name = "add"); 84 85 TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 86 TF_Status* s, const char* name = "min"); 87 88 TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 89 TF_Status* s, const char* name = "mul"); 90 91 // If `op_device` is non-empty, set the created op on that device. 92 TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 93 const string& op_device, TF_Status* s, 94 const char* name = "min"); 95 96 TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s, 97 const char* name = "neg"); 98 99 TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s); 100 101 TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, 102 TF_Graph* graph, TF_Status* s); 103 104 // Split `input` along the first dimension into 3 tensors 105 TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, 106 const char* name = "split3"); 107 108 bool IsPlaceholder(const tensorflow::NodeDef& node_def); 109 110 bool IsScalarConst(const tensorflow::NodeDef& node_def, int v); 111 112 bool IsAddN(const tensorflow::NodeDef& node_def, int n); 113 114 bool IsNeg(const tensorflow::NodeDef& node_def, const string& input); 115 116 bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def); 117 118 bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def); 119 120 bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def); 121 122 bool GetAttrValue(TF_Operation* oper, const char* attr_name, 123 tensorflow::AttrValue* attr_value, TF_Status* s); 124 125 // Returns a sorted vector of std::pair<function_name, gradient_func> from 126 // graph_def.library().gradient() 127 std::vector<std::pair<string, string>> GetGradDefs( 128 const tensorflow::GraphDef& graph_def); 129 130 // Returns a sorted vector of names contained in `grad_def` 131 std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def); 132 133 class CSession { 134 public: 135 CSession(TF_Graph* graph, TF_Status* s, bool use_XLA = false); 136 explicit CSession(TF_Session* session); 137 138 ~CSession(); 139 140 void SetInputs(std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs); 141 void SetOutputs(std::initializer_list<TF_Operation*> outputs); 142 void SetOutputs(const std::vector<TF_Output>& outputs); 143 void SetTargets(std::initializer_list<TF_Operation*> targets); 144 145 void Run(TF_Status* s); 146 147 void CloseAndDelete(TF_Status* s); 148 output_tensor(int i)149 TF_Tensor* output_tensor(int i) { return output_values_[i]; } 150 mutable_session()151 TF_Session* mutable_session() { return session_; } 152 153 private: 154 void DeleteInputValues(); 155 void ResetOutputValues(); 156 157 TF_Session* session_; 158 std::vector<TF_Output> inputs_; 159 std::vector<TF_Tensor*> input_values_; 160 std::vector<TF_Output> outputs_; 161 std::vector<TF_Tensor*> output_values_; 162 std::vector<TF_Operation*> targets_; 163 }; 164 165 #endif // TENSORFLOW_C_C_TEST_UTIL_H_ 166