/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_C_API_INTERNAL_H_
#define TENSORFLOW_C_C_API_INTERNAL_H_

#include "tensorflow/c/c_api.h"

#include <atomic>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"  // NO_LINT

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/core/framework/op_gen_lib.h"
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
class Device;
class DeviceMgr;
class ServerInterface;
}  // namespace tensorflow

// Internal structures used by the C API. These are likely to change and should
// not be depended on.
52 53 struct TF_Status { 54 tensorflow::Status status; 55 }; 56 57 struct TF_Tensor { 58 ~TF_Tensor(); 59 60 TF_DataType dtype; 61 tensorflow::TensorShape shape; 62 tensorflow::TensorBuffer* buffer; 63 }; 64 65 struct TF_SessionOptions { 66 tensorflow::SessionOptions options; 67 }; 68 69 struct TF_DeprecatedSession { 70 tensorflow::Session* session; 71 }; 72 73 struct TF_Library { 74 void* lib_handle; 75 TF_Buffer op_list; 76 }; 77 78 struct TF_Graph { 79 TF_Graph(); 80 81 tensorflow::mutex mu; 82 tensorflow::Graph graph GUARDED_BY(mu); 83 84 // Runs shape inference. 85 tensorflow::ShapeRefiner refiner GUARDED_BY(mu); 86 87 // Maps from name of an operation to the Node* in 'graph'. 88 std::unordered_map<tensorflow::string, tensorflow::Node*> name_map 89 GUARDED_BY(mu); 90 91 // The keys of this map are all the active sessions using this graph. Each 92 // value records whether the graph has been mutated since the corresponding 93 // session has been run (this is detected in RecordMutation function). If the 94 // string is empty, no mutation has occurred. Otherwise the string is a 95 // description of the mutation suitable for returning to the user. 96 // 97 // Sessions are added to this map in TF_NewSession, and removed in 98 // TF_DeleteSession. 99 // TF_Graph may only / must be deleted when 100 // sessions.size() == 0 && delete_requested == true 101 // 102 // TODO(b/74949947): mutations currently trigger a warning instead of a bad 103 // status, this should be reverted when possible. 104 tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions 105 GUARDED_BY(mu); 106 bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph 107 108 // Used to link graphs contained in TF_WhileParams to the parent graph that 109 // will eventually contain the full while loop. 
110 TF_Graph* parent; 111 TF_Output* parent_inputs; 112 }; 113 114 struct TF_OperationDescription { TF_OperationDescriptionTF_OperationDescription115 TF_OperationDescription(TF_Graph* g, const char* op_type, 116 const char* node_name) 117 : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {} 118 119 tensorflow::NodeBuilder node_builder; 120 TF_Graph* graph; 121 std::set<tensorflow::string> colocation_constraints; 122 }; 123 124 struct TF_Operation { 125 tensorflow::Node node; 126 }; 127 128 struct TF_Session { 129 TF_Session(tensorflow::Session* s, TF_Graph* g); 130 131 tensorflow::Session* session; 132 TF_Graph* const graph; 133 134 tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu); 135 int last_num_graph_nodes; 136 137 // If true, TF_SessionRun and similar methods will call 138 // ExtendSessionGraphHelper before running the graph (this is the default 139 // public behavior). Can be set to false if the caller needs to call 140 // ExtendSessionGraphHelper manually. 141 std::atomic<bool> extend_before_run; 142 }; 143 144 struct TF_ImportGraphDefOptions { 145 tensorflow::ImportGraphDefOptions opts; 146 147 // Backing memory for TensorId fields in opts. 148 // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this. 149 std::list<tensorflow::string> tensor_id_data; 150 }; 151 152 struct TF_ImportGraphDefResults { 153 std::vector<TF_Output> return_tensors; 154 std::vector<TF_Operation*> return_nodes; 155 std::vector<const char*> missing_unused_key_names; 156 std::vector<int> missing_unused_key_indexes; 157 158 // Backing memory for missing_unused_key_names values. 
159 std::list<tensorflow::string> missing_unused_key_names_data; 160 }; 161 162 struct TF_DeviceList { 163 std::vector<tensorflow::DeviceAttributes> response; 164 }; 165 166 struct TF_Function { 167 tensorflow::FunctionDef fdef; 168 }; 169 170 struct TF_ApiDefMap { TF_ApiDefMapTF_ApiDefMap171 explicit TF_ApiDefMap(const tensorflow::OpList& op_list) 172 : 173 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 174 api_def_map(op_list), 175 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 176 update_docs_called(false) { 177 } 178 179 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 180 tensorflow::ApiDefMap api_def_map GUARDED_BY(lock); 181 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 182 bool update_docs_called GUARDED_BY(lock); 183 tensorflow::mutex lock; 184 }; 185 186 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 187 struct TF_Server { 188 TF_Server(std::unique_ptr<tensorflow::ServerInterface> server); 189 190 const tensorflow::string target; 191 std::unique_ptr<tensorflow::ServerInterface> server; 192 }; 193 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 194 195 namespace tensorflow { 196 197 class TensorCApi { 198 public: Buffer(const Tensor & tensor)199 static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; } MakeTensor(TF_DataType type,const TensorShape & shape,TensorBuffer * buf)200 static Tensor MakeTensor(TF_DataType type, const TensorShape& shape, 201 TensorBuffer* buf) { 202 return Tensor(static_cast<DataType>(type), shape, buf); 203 } 204 }; 205 206 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); 207 208 TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); 209 210 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, 211 TF_Buffer* out); 212 213 // Set the shapes and types of the output's handle. 
214 // 215 // The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must 216 // all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the 217 // rank is known), then it must be equal to the length of `shapes[i]`; if 218 // `ranks[i] == 1`, then `shapes[i]` may be nullptr. 219 // 220 // TODO(akshayka): Implement a corresponding getter method. 221 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, 222 int num_shapes_and_types, 223 const int64_t** shapes, 224 const int* ranks, 225 const TF_DataType* types, 226 TF_Status* status); 227 228 void RecordMutation(TF_Graph* graph, const TF_Operation& op, 229 const char* mutation_type) 230 EXCLUSIVE_LOCKS_REQUIRED(graph->mu); 231 232 bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) 233 LOCKS_EXCLUDED(session->graph->mu, session->mu); 234 235 std::string getTF_OutputDebugString(TF_Output node); 236 237 } // end namespace tensorflow 238 239 #endif // TENSORFLOW_C_C_API_INTERNAL_H_ 240