• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
16 #define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
17 
#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>
25 
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/c/c_api_internal.h"
28 #include "tensorflow/c/eager/c_api.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/tensor_handle_interface.h"
31 #include "tensorflow/core/common_runtime/device_factory.h"
32 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
33 #include "tensorflow/core/common_runtime/eager/context.h"
34 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
35 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
36 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
37 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
38 #include "tensorflow/core/common_runtime/function.h"
39 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
40 #include "tensorflow/core/framework/cancellation.h"
41 #include "tensorflow/core/framework/rendezvous.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/stringpiece.h"
44 #include "tensorflow/core/lib/gtl/inlined_vector.h"
45 #include "tensorflow/core/lib/gtl/map_util.h"
46 #include "tensorflow/core/lib/monitoring/counter.h"
47 #include "tensorflow/core/lib/monitoring/gauge.h"
48 #include "tensorflow/core/lib/monitoring/sampler.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/thread_annotations.h"
51 #include "tensorflow/core/profiler/lib/profiler_session.h"
52 #include "tensorflow/core/public/version.h"
53 
54 struct TFE_ContextOptions {
55   TF_SessionOptions session_options;
56   // true if async execution is enabled.
57   bool async = false;
58   TFE_ContextDevicePlacementPolicy device_placement_policy{
59       TFE_DEVICE_PLACEMENT_SILENT};
60   TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
61   // If true, lazily copy the remote inputs of a function to the target devices.
62   bool lazy_remote_inputs_copy = true;
63 };
64 
65 struct TFE_Context {
66   tensorflow::EagerContext* context;
67 };
68 
69 struct TFE_TensorHandle {
CreateLocalHandleTFE_TensorHandle70   static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
71                                              TF_Status* s) {
72     tensorflow::TensorHandle* handle;
73     s->status = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
74     if (!s->status.ok()) {
75       return nullptr;
76     }
77     return new TFE_TensorHandle{
78         std::make_unique<tensorflow::TensorHandleInterface>(handle)};
79   }
80 
81   std::unique_ptr<AbstractTensorHandleInterface> handle;
82 };
83 
84 struct TFE_TensorDebugInfo {
TFE_TensorDebugInfoTFE_TensorDebugInfo85   explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
86       : dev_dims(dims) {}
87 
88   // Fully-padded, minor-to-major.
89   std::vector<tensorflow::int64> dev_dims;
90 };
91 
92 struct TFE_Op {
93   tensorflow::EagerOperation operation;
94 };
95 
96 struct TFE_Profiler {
TFE_ProfilerTFE_Profiler97   explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
98 
99   std::unique_ptr<tensorflow::ProfilerSession> profiler;
100 };
101 
102 struct TFE_MonitoringCounterCell {
103   tensorflow::monitoring::CounterCell cell;
104 };
105 
106 template <int NumLabels>
107 struct TFE_MonitoringCounter {
108   template <typename... LabelDesc>
TFE_MonitoringCounterTFE_MonitoringCounter109   TFE_MonitoringCounter(const char* name, const char* description,
110                         LabelDesc&&... label) {
111     counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
112         name, description, label...));
113   }
114 
115   std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
116 };
117 
118 struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
119   using TFE_MonitoringCounter::TFE_MonitoringCounter;
120 };
121 struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
122   using TFE_MonitoringCounter::TFE_MonitoringCounter;
123 };
124 struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
125   using TFE_MonitoringCounter::TFE_MonitoringCounter;
126 };
127 
128 struct TFE_MonitoringIntGaugeCell {
129   tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
130 };
131 struct TFE_MonitoringStringGaugeCell {
132   tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
133 };
134 struct TFE_MonitoringBoolGaugeCell {
135   tensorflow::monitoring::GaugeCell<bool> cell;
136 };
137 
138 template <typename ValueType, int NumLabels>
139 struct TFE_MonitoringGauge {
140   template <typename... LabelDesc>
TFE_MonitoringGaugeTFE_MonitoringGauge141   TFE_MonitoringGauge(const char* name, const char* description,
142                       LabelDesc&&... label) {
143     gauge = absl::WrapUnique(
144         tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
145             name, description, label...));
146   }
147 
148   std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
149 };
150 
151 struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
152   using TFE_MonitoringGauge::TFE_MonitoringGauge;
153 };
154 struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
155   using TFE_MonitoringGauge::TFE_MonitoringGauge;
156 };
157 struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
158   using TFE_MonitoringGauge::TFE_MonitoringGauge;
159 };
160 
161 struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
162   using TFE_MonitoringGauge::TFE_MonitoringGauge;
163 };
164 struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
165   using TFE_MonitoringGauge::TFE_MonitoringGauge;
166 };
167 struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
168   using TFE_MonitoringGauge::TFE_MonitoringGauge;
169 };
170 
171 struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
172   using TFE_MonitoringGauge::TFE_MonitoringGauge;
173 };
174 struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
175   using TFE_MonitoringGauge::TFE_MonitoringGauge;
176 };
177 struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
178   using TFE_MonitoringGauge::TFE_MonitoringGauge;
179 };
180 
181 struct TFE_MonitoringBuckets {
TFE_MonitoringBucketsTFE_MonitoringBuckets182   explicit TFE_MonitoringBuckets(
183       std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
184           fn) {
185     create_buckets = fn;
186   }
187 
188   std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
189       create_buckets;
190 };
191 
192 struct TFE_MonitoringSamplerCell {
193   tensorflow::monitoring::SamplerCell cell;
194 };
195 
196 template <int NumLabels>
197 struct TFE_MonitoringSampler {
198   template <typename... LabelDesc>
TFE_MonitoringSamplerTFE_MonitoringSampler199   TFE_MonitoringSampler(
200       const char* name,
201       std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
202       const char* description, LabelDesc&&... label) {
203     sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
204         {name, description, label...}, std::move(buckets)));
205   }
206 
207   std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
208 };
209 
210 struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
211   using TFE_MonitoringSampler::TFE_MonitoringSampler;
212 };
213 struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
214   using TFE_MonitoringSampler::TFE_MonitoringSampler;
215 };
216 struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
217   using TFE_MonitoringSampler::TFE_MonitoringSampler;
218 };
219 
220 namespace tensorflow {
221 // Set an AttrValue on the op. Doesn't handle the list types.
222 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
223                           const tensorflow::AttrValue& default_value,
224                           const char* attr_name, TF_Status* status);
225 }  // namespace tensorflow
226 
227 struct TFE_CancellationManager {
228   tensorflow::CancellationManager cancellation_manager;
229 };
230 
231 struct TFE_Executor {
TFE_ExecutorTFE_Executor232   explicit TFE_Executor(bool async)
233       : owned_executor(new tensorflow::EagerExecutor(async)) {}
234 
TFE_ExecutorTFE_Executor235   explicit TFE_Executor(tensorflow::EagerExecutor* executor)
236       : owned_executor(nullptr), unowned_executor(executor) {}
237 
executorTFE_Executor238   tensorflow::EagerExecutor* executor() {
239     return owned_executor == nullptr ? unowned_executor : owned_executor.get();
240   }
241 
242   std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
243   tensorflow::EagerExecutor* unowned_executor;
244 };
245 
246 #endif  // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
247