• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_C_C_API_INTERNAL_H_
17 #define TENSORFLOW_C_C_API_INTERNAL_H_
18 
19 #include "tensorflow/c/c_api.h"
20 
21 #include <list>
22 #include <set>
23 #include <string>
24 #include <unordered_map>
25 #include <vector>
26 
27 // clang-format off
28 // Required for IS_MOBILE_PLATFORM
29 #include "tensorflow/core/platform/platform.h"
30 // clang-format on
31 
32 #include "tensorflow/c/tf_status_internal.h"
33 #include "tensorflow/c/tf_tensor_internal.h"
34 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
35 #include "tensorflow/core/framework/op_gen_lib.h"
36 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
37 #include "tensorflow/core/common_runtime/shape_refiner.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/framework/tensor_shape.h"
40 #include "tensorflow/core/graph/graph.h"
41 #include "tensorflow/core/graph/graph_constructor.h"
42 #include "tensorflow/core/graph/node_builder.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/platform/mutex.h"
45 #include "tensorflow/core/platform/types.h"
46 #include "tensorflow/core/public/session.h"
47 
// Forward declarations of tensorflow classes so this header does not need
// their full definitions. ServerInterface is referenced below by TF_Server.
// NOTE(review): Device and DeviceMgr are not referenced in this header; they
// may be kept for includers — confirm before removing.
namespace tensorflow {
class Device;
class DeviceMgr;
class ServerInterface;
}  // namespace tensorflow
53 
54 // Internal structures used by the C API. These are likely to change and should
55 // not be depended on.
56 
57 struct TF_SessionOptions {
58   tensorflow::SessionOptions options;
59 };
60 
61 struct TF_DeprecatedSession {
62   tensorflow::Session* session;
63 };
64 
65 struct TF_Library {
66   void* lib_handle;
67   TF_Buffer op_list;
68 };
69 
70 struct TF_Graph {
71   TF_Graph();
72 
73   tensorflow::mutex mu;
74   tensorflow::Graph graph GUARDED_BY(mu);
75 
76   // Runs shape inference.
77   tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
78 
79   // Maps from name of an operation to the Node* in 'graph'.
80   std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
81       GUARDED_BY(mu);
82 
83   // The keys of this map are all the active sessions using this graph. Each
84   // value records whether the graph has been mutated since the corresponding
85   // session has been run (this is detected in RecordMutation function). If the
86   // string is empty, no mutation has occurred. Otherwise the string is a
87   // description of the mutation suitable for returning to the user.
88   //
89   // Sessions are added to this map in TF_NewSession, and removed in
90   // TF_DeleteSession.
91   // TF_Graph may only / must be deleted when
92   //   sessions.size() == 0 && delete_requested == true
93   //
94   // TODO(b/74949947): mutations currently trigger a warning instead of a bad
95   // status, this should be reverted when possible.
96   tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
97       GUARDED_BY(mu);
98   bool delete_requested GUARDED_BY(mu);  // set true by TF_DeleteGraph
99 
100   // Used to link graphs contained in TF_WhileParams to the parent graph that
101   // will eventually contain the full while loop.
102   TF_Graph* parent;
103   TF_Output* parent_inputs;
104 };
105 
106 struct TF_OperationDescription {
TF_OperationDescriptionTF_OperationDescription107   TF_OperationDescription(TF_Graph* g, const char* op_type,
108                           const char* node_name)
109       : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
110 
111   tensorflow::NodeBuilder node_builder;
112   TF_Graph* graph;
113   std::set<tensorflow::string> colocation_constraints;
114 };
115 
116 struct TF_Operation {
117   tensorflow::Node node;
118 };
119 
120 struct TF_Session {
121   TF_Session(tensorflow::Session* s, TF_Graph* g);
122 
123   tensorflow::Session* session;
124   TF_Graph* const graph;
125 
126   tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu);
127   int last_num_graph_nodes;
128 
129   // If true, TF_SessionRun and similar methods will call
130   // ExtendSessionGraphHelper before running the graph (this is the default
131   // public behavior). Can be set to false if the caller needs to call
132   // ExtendSessionGraphHelper manually.
133   std::atomic<bool> extend_before_run;
134 };
135 
136 struct TF_ImportGraphDefOptions {
137   tensorflow::ImportGraphDefOptions opts;
138 
139   // Backing memory for TensorId fields in opts.
140   // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this.
141   std::list<tensorflow::string> tensor_id_data;
142 };
143 
144 struct TF_ImportGraphDefResults {
145   std::vector<TF_Output> return_tensors;
146   std::vector<TF_Operation*> return_nodes;
147   std::vector<const char*> missing_unused_key_names;
148   std::vector<int> missing_unused_key_indexes;
149 
150   // Backing memory for missing_unused_key_names values.
151   std::list<tensorflow::string> missing_unused_key_names_data;
152 };
153 
154 struct TF_DeviceList {
155   std::vector<tensorflow::DeviceAttributes> response;
156 };
157 
158 struct TF_Function {
159   tensorflow::FunctionDef fdef;
160 };
161 
162 struct TF_ApiDefMap {
TF_ApiDefMapTF_ApiDefMap163   explicit TF_ApiDefMap(const tensorflow::OpList& op_list)
164       :
165 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
166         api_def_map(op_list),
167 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
168         update_docs_called(false) {
169   }
170 
171 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
172   tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
173 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
174   bool update_docs_called GUARDED_BY(lock);
175   tensorflow::mutex lock;
176 };
177 
178 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
179 struct TF_Server {
180   TF_Server(std::unique_ptr<tensorflow::ServerInterface> server);
181 
182   const tensorflow::string target;
183   std::unique_ptr<tensorflow::ServerInterface> server;
184 };
185 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
186 
187 namespace tensorflow {
188 
189 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
190 
191 TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
192 
193 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
194                        TF_Buffer* out);
195 
196 // Set the shapes and types of the output's handle.
197 //
198 // The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must
199 // all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the
200 // rank is known), then it must be equal to the length of `shapes[i]`; if
201 // `ranks[i] == 1`, then `shapes[i]` may be nullptr.
202 //
203 // TODO(akshayka): Implement a corresponding getter method.
204 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
205                                            int num_shapes_and_types,
206                                            const int64_t** shapes,
207                                            const int* ranks,
208                                            const TF_DataType* types,
209                                            TF_Status* status);
210 
211 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
212                     const char* mutation_type)
213     EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
214 
215 bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
216     LOCKS_EXCLUDED(session->graph->mu, session->mu);
217 
218 std::string getTF_OutputDebugString(TF_Output node);
219 
220 }  // end namespace tensorflow
221 
222 #endif  // TENSORFLOW_C_C_API_INTERNAL_H_
223