1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ 16 #define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ 17 18 #include "tensorflow/c/eager/c_api.h" 19 20 #include <algorithm> 21 #include <cstddef> 22 #include <map> 23 #include <memory> 24 #include <queue> 25 #include <string> 26 #include <thread> 27 #include <vector> 28 29 #include "tensorflow/c/c_api.h" 30 #include "tensorflow/c/c_api_internal.h" 31 #include "tensorflow/core/common_runtime/device_factory.h" 32 #include "tensorflow/core/common_runtime/eager/attr_builder.h" 33 #include "tensorflow/core/common_runtime/eager/context.h" 34 #include "tensorflow/core/common_runtime/eager/eager_executor.h" 35 #include "tensorflow/core/common_runtime/eager/eager_operation.h" 36 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" 37 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 38 #include "tensorflow/core/common_runtime/function.h" 39 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 40 #include "tensorflow/core/distributed_runtime/eager/eager_client.h" 41 #include "tensorflow/core/distributed_runtime/remote_device.h" 42 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" 43 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" 44 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" 45 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" 46 #include "tensorflow/core/distributed_runtime/server_lib.h" 47 #include "tensorflow/core/distributed_runtime/worker_env.h" 48 #include "tensorflow/core/framework/rendezvous.h" 49 #include "tensorflow/core/lib/core/stringpiece.h" 50 #include "tensorflow/core/lib/gtl/inlined_vector.h" 51 #include "tensorflow/core/lib/gtl/map_util.h" 52 #include "tensorflow/core/lib/gtl/stl_util.h" 53 #include "tensorflow/core/platform/mutex.h" 54 #include "tensorflow/core/platform/thread_annotations.h" 55 #include "tensorflow/core/profiler/lib/profiler_session.h" 56 #include "tensorflow/core/public/version.h" 57 58 struct TFE_ContextOptions { 59 TF_SessionOptions session_options; 60 // true if async execution is enabled. 61 bool async = false; 62 TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; 63 }; 64 65 struct TFE_Context { TFE_ContextTFE_Context66 TFE_Context(const tensorflow::SessionOptions& opts, 67 TFE_ContextDevicePlacementPolicy default_policy, bool async, 68 const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned, 69 tensorflow::Rendezvous* rendezvous) 70 : context(opts, 71 static_cast<tensorflow::ContextDevicePlacementPolicy>( 72 default_policy), 73 async, device_mgr, device_mgr_owned, rendezvous) {} 74 75 tensorflow::EagerContext context; 76 }; 77 78 struct TFE_TensorHandle { TFE_TensorHandleTFE_TensorHandle79 TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d, 80 tensorflow::Device* op_device) 81 : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {} 82 TFE_TensorHandleTFE_TensorHandle83 TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} 84 85 tensorflow::TensorHandle* handle; 86 87 // Create a symbolic tensor. TFE_TensorHandleTFE_TensorHandle88 TFE_TensorHandle(TF_Output t, TF_DataType dtype) 89 : handle(new tensorflow::TensorHandle( 90 tensorflow::OutputGraphNode{t.oper, t.index}, 91 static_cast<tensorflow::DataType>(dtype))) {} 92 }; 93 94 struct TFE_TensorDebugInfo { TFE_TensorDebugInfoTFE_TensorDebugInfo95 TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims) 96 : dev_dims(dims) {} 97 98 // Fully-padded, minor-to-major. 99 std::vector<tensorflow::int64> dev_dims; 100 }; 101 102 struct TFE_OpInferenceContext { TFE_OpInferenceContextTFE_OpInferenceContext103 explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def) 104 : op_def(op_def) {} 105 106 const tensorflow::OpDef* op_def; // op definition from protobuf 107 int input_arg_idx = 0; // arg definition index for the next input to be added 108 tensorflow::gtl::FlatSet<std::string> attrs; // attributes inferred so far 109 }; 110 111 struct TFE_Op { TFE_OpTFE_Op112 TFE_Op(TFE_Context* ctx, const char* op, bool is_function, 113 const tensorflow::AttrTypeMap* t, 114 TFE_OpInferenceContext* inference_ctx) 115 : operation(&ctx->context, op, is_function, t), 116 inference_ctx(inference_ctx) {} 117 118 tensorflow::EagerOperation operation; 119 std::unique_ptr<TFE_OpInferenceContext> inference_ctx; 120 }; 121 122 struct TFE_ProfilerContext { 123 tensorflow::ProfilerContext profiler_context; 124 }; 125 126 struct TFE_Profiler { TFE_ProfilerTFE_Profiler127 TFE_Profiler(TFE_ProfilerContext* ctx) { 128 profiler = tensorflow::ProfilerSession::Create(&ctx->profiler_context); 129 } 130 131 std::unique_ptr<tensorflow::ProfilerSession> profiler; 132 }; 133 134 namespace tensorflow { 135 // Set an AttrValue on the op. Doesn't handle the list types. 136 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, 137 const tensorflow::AttrValue& default_value, 138 const char* attr_name, TF_Status* status); 139 } // namespace tensorflow 140 141 struct TFE_TraceContext { 142 TF_Graph* const graph; 143 144 unsigned int node_counter = 0; 145 // Each tensor handle will have its ref count incremented when it's added as a 146 // map key, and decremented when this object is destroyed. 147 std::map<tensorflow::TensorHandle*, TF_Output> input_tensor_map; 148 std::vector<std::pair<tensorflow::TensorHandle*, TF_Output>>* input_tensors = 149 nullptr; 150 TFE_TraceContextTFE_TraceContext151 TFE_TraceContext(TF_Graph* graph) : graph(graph) {} 152 ~TFE_TraceContextTFE_TraceContext153 ~TFE_TraceContext() { 154 delete input_tensors; 155 for (auto input : input_tensor_map) { 156 input.first->Unref(); 157 } 158 } 159 }; 160 161 #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ 162