• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
16 #define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/types/optional.h"
22 #include "absl/types/span.h"
23 #include "tensorflow/c/eager/abstract_context.h"
24 #include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
25 #include "tensorflow/c/eager/immediate_execution_operation.h"
26 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
27 #include "tensorflow/c/tensor_interface.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/function.pb.h"
30 #include "tensorflow/core/framework/numeric_types.h"
31 #include "tensorflow/core/framework/tensor.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/platform/platform.h"
34 #include "tensorflow/core/platform/status.h"
35 #include "tensorflow/core/platform/tstring.h"
36 #include "tensorflow/core/protobuf/config.pb.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 
39 namespace tensorflow {
// Forward declarations of runtime types that this header refers to only by
// pointer or reference, so their full definitions need not be included here.
class CustomDevice;
class CustomDeviceOpHandler;
class EagerContext;
class EagerExecutor;
44 
45 // LINT.IfChange
46 // Note: Keep in sync with exported copy of enum in eager/c_api.h.
enum ContextDevicePlacementPolicy {
  // Strict mode: an operation whose input tensors are on the wrong device
  // fails instead of copying them.
  DEVICE_PLACEMENT_EXPLICIT = 0,
  // Inputs on the wrong device are copied to the right one, and a warning is
  // logged.
  DEVICE_PLACEMENT_WARN = 1,
  // Default policy. Inputs are copied without any warning; the operation is
  // blocked until the copy completes, which has a performance cost.
  DEVICE_PLACEMENT_SILENT = 2,
  // Only int32 tensors are copied silently; tensors of any other dtype are
  // not.
  DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
58 // LINT.ThenChange(//tensorflow/c/eager/c_api.h)
59 
60 // Abstract interface to a context.
61 //
62 // A context is responsible for creating key objects such as Tensors,
63 // TensorHandles & Operations.
64 class ImmediateExecutionContext : public AbstractContext {
65  public:
66   // Optimized scalar creation functions
67   virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
68   virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
69   virtual AbstractTensorInterface* CreateInt32Scalar(int32 value) = 0;
70   virtual AbstractTensorInterface* CreateFloatScalar(float value) = 0;
71   virtual AbstractTensorInterface* CreateDoubleScalar(double value) = 0;
72   virtual AbstractTensorInterface* CreateHalfScalar(Eigen::half value) = 0;
73   virtual AbstractTensorInterface* CreateStringScalar(tstring value) = 0;
74   virtual AbstractTensorInterface* CreateComplex128Scalar(complex128 value) = 0;
75   virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0;
76 
77   // Tensor creation functions
78   virtual AbstractTensorInterface* CreateTensor(
79       DataType dtype, absl::Span<const int64> dim_sizes) = 0;
80 
81   typedef void (*MemoryReleaser)(void* data, size_t len, void* arg);
82 
83   // Create a tensor instance from the given data buffer and description.
84   // `memory_releaser` will be called on destruction, and it's responsible for
85   // cleaning up the underlying buffer.
86   virtual AbstractTensorInterface* CreateTensor(
87       DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
88       MemoryReleaser memory_releaser, void* memory_releaser_arg) = 0;
89 
90   // Create a handle to wrap and manage a Tensor
91   virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
92       AbstractTensorInterface* t) = 0;
93   // Copy the handle to another device.
94   virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
95       ImmediateExecutionTensorHandle* handle, const char* device_name,
96       Status* status) = 0;
97 
98   // Create an operation to perform op execution
99   ImmediateExecutionOperation* CreateOperation() override = 0;
100 
101   // Returns whether the runtime is backed by TFRT or the legacy TF Eager
102   // Runtime. This is necessary to decouple runtime-dependent
103   // code that is layered on top of the runtime.
104   virtual bool UsesTFRT() = 0;
105 
106   // List attributes of available devices
107   virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
108 
109   // Block until all pending nodes are finished.
110   virtual Status AsyncWait() = 0;
111 
112   // Add a function (serialized FunctionDef protocol buffer) so that it can
113   // be executed as an op. Return error if the function with the same name
114   // already exists.
115   virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
116 
117   // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under
118   // the key of the function definition name (to be retrieved during function
119   // instantiation).
120   virtual Status AddFunctionDefWithStackTraces(
121       const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0;
122 
123   // Find and return a added function by its name.
124   virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
125 
126   // Return the ParsedName of Host CPU device.
127   virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
128   virtual const string& HostCPUName() const = 0;
129 
130   // Configure soft device placement policy.
131   virtual void SetAllowSoftPlacement(bool enable) = 0;
132 
133   // Configure device placement policy logging.
134   virtual void SetLogDevicePlacement(bool enable) = 0;
135 
136   // Sets the device placement policy for the current thread.
137   virtual void SetThreadLocalDevicePlacementPolicy(
138       ContextDevicePlacementPolicy policy) = 0;
139   // Returns the device placement policy for the current thread.
140   virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
141 
142   // Configure graph collection in RunMetadata.
143   virtual void SetShouldStoreGraphs(bool value) = 0;
144 
145   // Return the collected RunMetadata. This method will transfer the ownership
146   // to the caller.
147   virtual std::unique_ptr<RunMetadata> ExportRunMetadata() = 0;
148 
149   // For LLVM style RTTI.
classof(const AbstractContext * ptr)150   static bool classof(const AbstractContext* ptr) {
151     return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
152   }
153 
154   //===--------------------------------------------------------------------===//
155   // Experimental Custom Device.
156   //===--------------------------------------------------------------------===//
157   virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0;
158 
159   // Register a custom device. It will return error is the device name is
160   // already registered.
161   // TODO(tfrt-devs): Remove this method. Let caller register it directly into
162   // CustomDeviceOpHandler.
163   virtual Status RegisterCustomDevice(const string& name,
164                                       std::unique_ptr<CustomDevice> device) = 0;
165 
166   //===--------------------------------------------------------------------===//
167   // Following are features in current TF Eager Runtime.
168   // TODO(tfrt-devs): Figure out a way to deprecate following features after
169   // migrated to TFRT.
170   //===--------------------------------------------------------------------===//
171   // Clear pending nodes in thread executors and kernel caches.
172   virtual void ClearCachesAndThreadExecutors() = 0;
173 
174   // Initialize the step resource container for a training step. This is used
175   // in current TF runtime. For tfrt, it is used by fallback op handler.
176   virtual void StartStep() = 0;
177   // Destroy the step resource container for a training step.
178   virtual void EndStep() = 0;
179 
180   // Return the Eager Executor for current thread. Please note that Eager
181   // Executor is only used in current TF but not in TFRT.
182   virtual EagerExecutor& Executor() = 0;
183   // Update the Eager Executor for current thread.
184   virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
185 
186   // Return a list of local tensorflow::Device*.
187   virtual std::vector<tensorflow::Device*> ListLocalTfDevices() = 0;
188 
189   //===--------------------------------------------------------------------===//
190   // Following are helper functions to assist integrating TFRT with current
191   // TF eager runtime.
192   // TODO(b/172877902): These helper functions are currently used to support
193   // PyFuncOp on TFRT, and might be useful for ops that directly use low
194   // level TF APIs. Remove/replace the following functions when TFRT native
195   // ops are implemented.
196   //===--------------------------------------------------------------------===//
197   // Create an abstract tensor handle from tensorflow::Tensor.
198   virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
199       tensorflow::Tensor& t, const char* d_name) = 0;
200 
201   // Convert a TFRT TensorHandle to tensorflow::TensorHandle.
202   virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
203       ImmediateExecutionTensorHandle* handle) = 0;
204 
GetLoggedOpsTestonly()205   virtual std::vector<std::string> GetLoggedOpsTestonly() { return {}; }
206 
207   // Get a list of the names of functions that have been registered.
208   virtual std::vector<string> ListFunctionNames() = 0;
209 
210   //===--------------------------------------------------------------------===//
211   // Distributed runtime related functions.
212   //===--------------------------------------------------------------------===//
213 #if !defined(IS_MOBILE_PLATFORM)
214   // Set a distributed manager that helps set up, update, and check liveness
215   // of member tasks in the cluster.
216   virtual void SetDistributedManager(
217       std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0;
218 
219   virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0;
220 #endif  // !IS_MOBILE_PLATFORM
221 
222  protected:
ImmediateExecutionContext(AbstractContextKind kind)223   explicit ImmediateExecutionContext(AbstractContextKind kind)
224       : AbstractContext(kind) {}
~ImmediateExecutionContext()225   ~ImmediateExecutionContext() override {}
226 };
227 
228 namespace internal {
229 struct ImmediateExecutionContextDeleter {
operatorImmediateExecutionContextDeleter230   void operator()(ImmediateExecutionContext* p) const {
231     if (p != nullptr) {
232       p->Release();
233     }
234   }
235 };
236 }  // namespace internal
237 
238 using ImmediateContextPtr =
239     std::unique_ptr<ImmediateExecutionContext,
240                     internal::ImmediateExecutionContextDeleter>;
241 
242 }  // namespace tensorflow
243 
244 #endif  // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
245