1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_C_C_API_INTERNAL_H_ 17 #define TENSORFLOW_C_C_API_INTERNAL_H_ 18 19 #include "tensorflow/c/c_api.h" 20 21 #include <list> 22 #include <set> 23 #include <string> 24 #include <unordered_map> 25 #include <vector> 26 27 // clang-format off 28 // Required for IS_MOBILE_PLATFORM 29 #include "tensorflow/core/platform/platform.h" 30 // clang-format on 31 32 #include "tensorflow/c/tf_status_internal.h" 33 #include "tensorflow/c/tf_tensor_internal.h" 34 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 35 #include "tensorflow/core/framework/op_gen_lib.h" 36 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 37 #include "tensorflow/core/common_runtime/shape_refiner.h" 38 #include "tensorflow/core/framework/tensor.h" 39 #include "tensorflow/core/framework/tensor_shape.h" 40 #include "tensorflow/core/graph/graph.h" 41 #include "tensorflow/core/graph/graph_constructor.h" 42 #include "tensorflow/core/graph/node_builder.h" 43 #include "tensorflow/core/lib/core/status.h" 44 #include "tensorflow/core/platform/mutex.h" 45 #include "tensorflow/core/platform/types.h" 46 #include "tensorflow/core/public/session.h" 47 48 namespace tensorflow { 49 class Device; 50 class DeviceMgr; 51 class ServerInterface; 52 } // namespace tensorflow 53 54 // Internal structures used by the C API. These are likely to change and should 55 // not be depended on. 56 57 struct TF_SessionOptions { 58 tensorflow::SessionOptions options; 59 }; 60 61 struct TF_DeprecatedSession { 62 tensorflow::Session* session; 63 }; 64 65 struct TF_Library { 66 void* lib_handle; 67 TF_Buffer op_list; 68 }; 69 70 struct TF_Graph { 71 TF_Graph(); 72 73 tensorflow::mutex mu; 74 tensorflow::Graph graph GUARDED_BY(mu); 75 76 // Runs shape inference. 77 tensorflow::ShapeRefiner refiner GUARDED_BY(mu); 78 79 // Maps from name of an operation to the Node* in 'graph'. 80 std::unordered_map<tensorflow::string, tensorflow::Node*> name_map 81 GUARDED_BY(mu); 82 83 // The keys of this map are all the active sessions using this graph. Each 84 // value records whether the graph has been mutated since the corresponding 85 // session has been run (this is detected in RecordMutation function). If the 86 // string is empty, no mutation has occurred. Otherwise the string is a 87 // description of the mutation suitable for returning to the user. 88 // 89 // Sessions are added to this map in TF_NewSession, and removed in 90 // TF_DeleteSession. 91 // TF_Graph may only / must be deleted when 92 // sessions.size() == 0 && delete_requested == true 93 // 94 // TODO(b/74949947): mutations currently trigger a warning instead of a bad 95 // status, this should be reverted when possible. 96 tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions 97 GUARDED_BY(mu); 98 bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph 99 100 // Used to link graphs contained in TF_WhileParams to the parent graph that 101 // will eventually contain the full while loop. 102 TF_Graph* parent; 103 TF_Output* parent_inputs; 104 }; 105 106 struct TF_OperationDescription { TF_OperationDescriptionTF_OperationDescription107 TF_OperationDescription(TF_Graph* g, const char* op_type, 108 const char* node_name) 109 : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {} 110 111 tensorflow::NodeBuilder node_builder; 112 TF_Graph* graph; 113 std::set<tensorflow::string> colocation_constraints; 114 }; 115 116 struct TF_Operation { 117 tensorflow::Node node; 118 }; 119 120 struct TF_Session { 121 TF_Session(tensorflow::Session* s, TF_Graph* g); 122 123 tensorflow::Session* session; 124 TF_Graph* const graph; 125 126 tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu); 127 int last_num_graph_nodes; 128 129 // If true, TF_SessionRun and similar methods will call 130 // ExtendSessionGraphHelper before running the graph (this is the default 131 // public behavior). Can be set to false if the caller needs to call 132 // ExtendSessionGraphHelper manually. 133 std::atomic<bool> extend_before_run; 134 }; 135 136 struct TF_ImportGraphDefOptions { 137 tensorflow::ImportGraphDefOptions opts; 138 139 // Backing memory for TensorId fields in opts. 140 // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this. 141 std::list<tensorflow::string> tensor_id_data; 142 }; 143 144 struct TF_ImportGraphDefResults { 145 std::vector<TF_Output> return_tensors; 146 std::vector<TF_Operation*> return_nodes; 147 std::vector<const char*> missing_unused_key_names; 148 std::vector<int> missing_unused_key_indexes; 149 150 // Backing memory for missing_unused_key_names values. 151 std::list<tensorflow::string> missing_unused_key_names_data; 152 }; 153 154 struct TF_DeviceList { 155 std::vector<tensorflow::DeviceAttributes> response; 156 }; 157 158 struct TF_Function { 159 tensorflow::FunctionDef fdef; 160 }; 161 162 struct TF_ApiDefMap { TF_ApiDefMapTF_ApiDefMap163 explicit TF_ApiDefMap(const tensorflow::OpList& op_list) 164 : 165 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 166 api_def_map(op_list), 167 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 168 update_docs_called(false) { 169 } 170 171 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 172 tensorflow::ApiDefMap api_def_map GUARDED_BY(lock); 173 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 174 bool update_docs_called GUARDED_BY(lock); 175 tensorflow::mutex lock; 176 }; 177 178 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 179 struct TF_Server { 180 TF_Server(std::unique_ptr<tensorflow::ServerInterface> server); 181 182 const tensorflow::string target; 183 std::unique_ptr<tensorflow::ServerInterface> server; 184 }; 185 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) 186 187 namespace tensorflow { 188 189 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); 190 191 TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status); 192 193 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, 194 TF_Buffer* out); 195 196 // Set the shapes and types of the output's handle. 197 // 198 // The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must 199 // all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the 200 // rank is known), then it must be equal to the length of `shapes[i]`; if 201 // `ranks[i] == 1`, then `shapes[i]` may be nullptr. 202 // 203 // TODO(akshayka): Implement a corresponding getter method. 204 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, 205 int num_shapes_and_types, 206 const int64_t** shapes, 207 const int* ranks, 208 const TF_DataType* types, 209 TF_Status* status); 210 211 void RecordMutation(TF_Graph* graph, const TF_Operation& op, 212 const char* mutation_type) 213 EXCLUSIVE_LOCKS_REQUIRED(graph->mu); 214 215 bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) 216 LOCKS_EXCLUDED(session->graph->mu, session->mu); 217 218 std::string getTF_OutputDebugString(TF_Output node); 219 220 } // end namespace tensorflow 221 222 #endif // TENSORFLOW_C_C_API_INTERNAL_H_ 223