1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ 16 #define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ 17 18 #include <algorithm> 19 #include <cstddef> 20 #include <map> 21 #include <memory> 22 #include <queue> 23 #include <string> 24 #include <vector> 25 26 #include "tensorflow/c/c_api.h" 27 #include "tensorflow/c/c_api_internal.h" 28 #include "tensorflow/c/eager/c_api.h" 29 #include "tensorflow/c/eager/c_api_experimental.h" 30 #include "tensorflow/c/eager/tensor_handle_interface.h" 31 #include "tensorflow/core/common_runtime/device_factory.h" 32 #include "tensorflow/core/common_runtime/eager/attr_builder.h" 33 #include "tensorflow/core/common_runtime/eager/context.h" 34 #include "tensorflow/core/common_runtime/eager/eager_executor.h" 35 #include "tensorflow/core/common_runtime/eager/eager_operation.h" 36 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" 37 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 38 #include "tensorflow/core/common_runtime/function.h" 39 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 40 #include "tensorflow/core/framework/cancellation.h" 41 #include "tensorflow/core/framework/rendezvous.h" 42 #include "tensorflow/core/lib/core/errors.h" 43 #include "tensorflow/core/lib/core/stringpiece.h" 44 #include "tensorflow/core/lib/gtl/inlined_vector.h" 45 #include "tensorflow/core/lib/gtl/map_util.h" 46 #include "tensorflow/core/lib/monitoring/counter.h" 47 #include "tensorflow/core/lib/monitoring/gauge.h" 48 #include "tensorflow/core/lib/monitoring/sampler.h" 49 #include "tensorflow/core/platform/mutex.h" 50 #include "tensorflow/core/platform/thread_annotations.h" 51 #include "tensorflow/core/profiler/lib/profiler_session.h" 52 #include "tensorflow/core/public/version.h" 53 54 struct TFE_ContextOptions { 55 TF_SessionOptions session_options; 56 // true if async execution is enabled. 57 bool async = false; 58 TFE_ContextDevicePlacementPolicy device_placement_policy{ 59 TFE_DEVICE_PLACEMENT_SILENT}; 60 TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE}; 61 // If true, lazily copy the remote inputs of a function to the target devices. 62 bool lazy_remote_inputs_copy = true; 63 }; 64 65 struct TFE_Context { 66 tensorflow::EagerContext* context; 67 }; 68 69 struct TFE_TensorHandle { CreateLocalHandleTFE_TensorHandle70 static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t, 71 TF_Status* s) { 72 tensorflow::TensorHandle* handle; 73 s->status = tensorflow::TensorHandle::CreateLocalHandle(t, &handle); 74 if (!s->status.ok()) { 75 return nullptr; 76 } 77 return new TFE_TensorHandle{ 78 std::make_unique<tensorflow::TensorHandleInterface>(handle)}; 79 } 80 81 std::unique_ptr<AbstractTensorHandleInterface> handle; 82 }; 83 84 struct TFE_TensorDebugInfo { TFE_TensorDebugInfoTFE_TensorDebugInfo85 explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims) 86 : dev_dims(dims) {} 87 88 // Fully-padded, minor-to-major. 89 std::vector<tensorflow::int64> dev_dims; 90 }; 91 92 struct TFE_Op { 93 tensorflow::EagerOperation operation; 94 }; 95 96 struct TFE_Profiler { TFE_ProfilerTFE_Profiler97 explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); } 98 99 std::unique_ptr<tensorflow::ProfilerSession> profiler; 100 }; 101 102 struct TFE_MonitoringCounterCell { 103 tensorflow::monitoring::CounterCell cell; 104 }; 105 106 template <int NumLabels> 107 struct TFE_MonitoringCounter { 108 template <typename... LabelDesc> TFE_MonitoringCounterTFE_MonitoringCounter109 TFE_MonitoringCounter(const char* name, const char* description, 110 LabelDesc&&... label) { 111 counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New( 112 name, description, label...)); 113 } 114 115 std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter; 116 }; 117 118 struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> { 119 using TFE_MonitoringCounter::TFE_MonitoringCounter; 120 }; 121 struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> { 122 using TFE_MonitoringCounter::TFE_MonitoringCounter; 123 }; 124 struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> { 125 using TFE_MonitoringCounter::TFE_MonitoringCounter; 126 }; 127 128 struct TFE_MonitoringIntGaugeCell { 129 tensorflow::monitoring::GaugeCell<tensorflow::int64> cell; 130 }; 131 struct TFE_MonitoringStringGaugeCell { 132 tensorflow::monitoring::GaugeCell<tensorflow::string> cell; 133 }; 134 struct TFE_MonitoringBoolGaugeCell { 135 tensorflow::monitoring::GaugeCell<bool> cell; 136 }; 137 138 template <typename ValueType, int NumLabels> 139 struct TFE_MonitoringGauge { 140 template <typename... LabelDesc> TFE_MonitoringGaugeTFE_MonitoringGauge141 TFE_MonitoringGauge(const char* name, const char* description, 142 LabelDesc&&... label) { 143 gauge = absl::WrapUnique( 144 tensorflow::monitoring::Gauge<ValueType, NumLabels>::New( 145 name, description, label...)); 146 } 147 148 std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge; 149 }; 150 151 struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> { 152 using TFE_MonitoringGauge::TFE_MonitoringGauge; 153 }; 154 struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> { 155 using TFE_MonitoringGauge::TFE_MonitoringGauge; 156 }; 157 struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> { 158 using TFE_MonitoringGauge::TFE_MonitoringGauge; 159 }; 160 161 struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> { 162 using TFE_MonitoringGauge::TFE_MonitoringGauge; 163 }; 164 struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> { 165 using TFE_MonitoringGauge::TFE_MonitoringGauge; 166 }; 167 struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> { 168 using TFE_MonitoringGauge::TFE_MonitoringGauge; 169 }; 170 171 struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> { 172 using TFE_MonitoringGauge::TFE_MonitoringGauge; 173 }; 174 struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> { 175 using TFE_MonitoringGauge::TFE_MonitoringGauge; 176 }; 177 struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> { 178 using TFE_MonitoringGauge::TFE_MonitoringGauge; 179 }; 180 181 struct TFE_MonitoringBuckets { TFE_MonitoringBucketsTFE_MonitoringBuckets182 explicit TFE_MonitoringBuckets( 183 std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)> 184 fn) { 185 create_buckets = fn; 186 } 187 188 std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)> 189 create_buckets; 190 }; 191 192 struct TFE_MonitoringSamplerCell { 193 tensorflow::monitoring::SamplerCell cell; 194 }; 195 196 template <int NumLabels> 197 struct TFE_MonitoringSampler { 198 template <typename... LabelDesc> TFE_MonitoringSamplerTFE_MonitoringSampler199 TFE_MonitoringSampler( 200 const char* name, 201 std::unique_ptr<tensorflow::monitoring::Buckets> buckets, 202 const char* description, LabelDesc&&... label) { 203 sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New( 204 {name, description, label...}, std::move(buckets))); 205 } 206 207 std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler; 208 }; 209 210 struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> { 211 using TFE_MonitoringSampler::TFE_MonitoringSampler; 212 }; 213 struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> { 214 using TFE_MonitoringSampler::TFE_MonitoringSampler; 215 }; 216 struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> { 217 using TFE_MonitoringSampler::TFE_MonitoringSampler; 218 }; 219 220 namespace tensorflow { 221 // Set an AttrValue on the op. Doesn't handle the list types. 222 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, 223 const tensorflow::AttrValue& default_value, 224 const char* attr_name, TF_Status* status); 225 } // namespace tensorflow 226 227 struct TFE_CancellationManager { 228 tensorflow::CancellationManager cancellation_manager; 229 }; 230 231 struct TFE_Executor { TFE_ExecutorTFE_Executor232 explicit TFE_Executor(bool async) 233 : owned_executor(new tensorflow::EagerExecutor(async)) {} 234 TFE_ExecutorTFE_Executor235 explicit TFE_Executor(tensorflow::EagerExecutor* executor) 236 : owned_executor(nullptr), unowned_executor(executor) {} 237 executorTFE_Executor238 tensorflow::EagerExecutor* executor() { 239 return owned_executor == nullptr ? unowned_executor : owned_executor.get(); 240 } 241 242 std::unique_ptr<tensorflow::EagerExecutor> owned_executor; 243 tensorflow::EagerExecutor* unowned_executor; 244 }; 245 246 #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ 247