• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
16 #define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
17 
18 #include "tensorflow/c/eager/c_api.h"
19 
20 #include <algorithm>
21 #include <cstddef>
22 #include <map>
23 #include <memory>
24 #include <queue>
25 #include <string>
26 #include <thread>
27 #include <vector>
28 
29 #include "tensorflow/c/c_api.h"
30 #include "tensorflow/c/c_api_internal.h"
31 #include "tensorflow/core/common_runtime/device_factory.h"
32 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
33 #include "tensorflow/core/common_runtime/eager/context.h"
34 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
35 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
36 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
37 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
38 #include "tensorflow/core/common_runtime/function.h"
39 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
40 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
41 #include "tensorflow/core/distributed_runtime/remote_device.h"
42 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
43 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
44 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
45 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
46 #include "tensorflow/core/distributed_runtime/server_lib.h"
47 #include "tensorflow/core/distributed_runtime/worker_env.h"
48 #include "tensorflow/core/framework/rendezvous.h"
49 #include "tensorflow/core/lib/core/stringpiece.h"
50 #include "tensorflow/core/lib/gtl/inlined_vector.h"
51 #include "tensorflow/core/lib/gtl/map_util.h"
52 #include "tensorflow/core/lib/gtl/stl_util.h"
53 #include "tensorflow/core/platform/mutex.h"
54 #include "tensorflow/core/platform/thread_annotations.h"
55 #include "tensorflow/core/profiler/lib/profiler_session.h"
56 #include "tensorflow/core/public/version.h"
57 
// Options passed to TFE_NewContext when constructing an eager context.
struct TFE_ContextOptions {
  TF_SessionOptions session_options;
  // true if async execution is enabled.
  bool async = false;
  // Device placement policy; defaults to TFE_DEVICE_PLACEMENT_SILENT.
  TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
};
64 
// C API wrapper around a tensorflow::EagerContext.
struct TFE_Context {
  // Forwards all arguments to the wrapped EagerContext, translating the C API
  // placement-policy enum into its core-runtime counterpart.
  TFE_Context(const tensorflow::SessionOptions& opts,
              TFE_ContextDevicePlacementPolicy default_policy, bool async,
              const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
              tensorflow::Rendezvous* rendezvous)
      : context(opts,
                static_cast<tensorflow::ContextDevicePlacementPolicy>(
                    default_policy),
                async, device_mgr, device_mgr_owned, rendezvous) {}

  tensorflow::EagerContext context;
};
77 
// C API wrapper around a tensorflow::TensorHandle.
struct TFE_TensorHandle {
  // Wraps a concrete tensor `t` resident on device `d`; `op_device` is passed
  // through to TensorHandle (presumably the device of the producing op —
  // TODO confirm against TensorHandle's constructor contract).
  TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d,
                   tensorflow::Device* op_device)
      : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {}

  // Wraps an existing TensorHandle*; ownership presumably transfers to this
  // wrapper — verify against TFE_DeleteTensorHandle.
  // NOTE(review): this single-argument constructor is not `explicit`, so a
  // tensorflow::TensorHandle* converts implicitly to TFE_TensorHandle;
  // consider marking it explicit if no caller relies on the conversion.
  TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {}

  tensorflow::TensorHandle* handle;

  // Create a symbolic tensor.
  TFE_TensorHandle(TF_Output t, TF_DataType dtype)
      : handle(new tensorflow::TensorHandle(
            tensorflow::OutputGraphNode{t.oper, t.index},
            static_cast<tensorflow::DataType>(dtype))) {}
};
93 
94 struct TFE_TensorDebugInfo {
TFE_TensorDebugInfoTFE_TensorDebugInfo95   TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
96       : dev_dims(dims) {}
97 
98   // Fully-padded, minor-to-major.
99   std::vector<tensorflow::int64> dev_dims;
100 };
101 
// State used while inferring op attributes from the inputs added so far,
// so callers do not have to set every attribute manually.
struct TFE_OpInferenceContext {
  explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def)
      : op_def(op_def) {}

  const tensorflow::OpDef* op_def;  // op definition from protobuf; not owned
  int input_arg_idx = 0;  // arg definition index for the next input to be added
  tensorflow::gtl::FlatSet<std::string> attrs;  // attributes inferred so far
};
110 
// C API wrapper around a tensorflow::EagerOperation.
struct TFE_Op {
  // Takes ownership of `inference_ctx` (stored in a unique_ptr below);
  // `t` is the attribute-type map for `op` — presumably not owned, since it
  // is forwarded to EagerOperation as a raw pointer (TODO confirm).
  TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
         const tensorflow::AttrTypeMap* t,
         TFE_OpInferenceContext* inference_ctx)
      : operation(&ctx->context, op, is_function, t),
        inference_ctx(inference_ctx) {}

  tensorflow::EagerOperation operation;
  // Owned. Presumably null when attribute inference is disabled for this op —
  // verify against callers of the constructor.
  std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
};
121 
// C API wrapper around a tensorflow::ProfilerContext.
struct TFE_ProfilerContext {
  tensorflow::ProfilerContext profiler_context;
};
125 
126 struct TFE_Profiler {
TFE_ProfilerTFE_Profiler127   TFE_Profiler(TFE_ProfilerContext* ctx) {
128     profiler = tensorflow::ProfilerSession::Create(&ctx->profiler_context);
129   }
130 
131   std::unique_ptr<tensorflow::ProfilerSession> profiler;
132 };
133 
namespace tensorflow {
// Set an AttrValue on the op. Doesn't handle the list types.
// `attr_name` names the attribute to set and `default_value` supplies its
// value; errors are reported through `status`.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
                          const tensorflow::AttrValue& default_value,
                          const char* attr_name, TF_Status* status);
}  // namespace tensorflow
140 
141 struct TFE_TraceContext {
142   TF_Graph* const graph;
143 
144   unsigned int node_counter = 0;
145   // Each tensor handle will have its ref count incremented when it's added as a
146   // map key, and decremented when this object is destroyed.
147   std::map<tensorflow::TensorHandle*, TF_Output> input_tensor_map;
148   std::vector<std::pair<tensorflow::TensorHandle*, TF_Output>>* input_tensors =
149       nullptr;
150 
TFE_TraceContextTFE_TraceContext151   TFE_TraceContext(TF_Graph* graph) : graph(graph) {}
152 
~TFE_TraceContextTFE_TraceContext153   ~TFE_TraceContext() {
154     delete input_tensors;
155     for (auto input : input_tensor_map) {
156       input.first->Unref();
157     }
158   }
159 };
160 
161 #endif  // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
162