• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // A simple logging device to test custom device registration.
#include <memory>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
26 
27 namespace {
28 
29 struct LoggingDevice {
30   tensorflow::string device_name;
31   tensorflow::string underlying_device;
32   // Set to true whenever a TensorHandle is copied onto the device
33   bool* arrived_flag;
34   // Set to true whenever an operation is executed
35   bool* executed_flag;
36   // If true, only explicit op placements are accepted. If false, uses
37   // type-based dispatch.
38   bool strict_scope_placement;
39 };
40 
41 struct LoggedTensor {
42   TFE_TensorHandle* tensor;
43   LoggedTensor() = delete;
LoggedTensor__anon239fa2ab0111::LoggedTensor44   explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor__anon239fa2ab0111::LoggedTensor45   ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
46 };
47 
LoggedTensorDim(void * data,int dim_index,TF_Status * status)48 int64_t LoggedTensorDim(void* data, int dim_index, TF_Status* status) {
49   return TFE_TensorHandleDim(reinterpret_cast<LoggedTensor*>(data)->tensor,
50                              dim_index, status);
51 }
52 
LoggedTensorNumDims(void * data,TF_Status * status)53 int LoggedTensorNumDims(void* data, TF_Status* status) {
54   return TFE_TensorHandleNumDims(reinterpret_cast<LoggedTensor*>(data)->tensor,
55                                  status);
56 }
57 
LoggedTensorDeallocator(void * data)58 void LoggedTensorDeallocator(void* data) {
59   delete reinterpret_cast<LoggedTensor*>(data);
60 }
61 
MakeLoggedTensorHandle(TFE_Context * context,const tensorflow::string & logging_device_name,std::unique_ptr<LoggedTensor> t,TF_Status * status)62 TFE_TensorHandle* MakeLoggedTensorHandle(
63     TFE_Context* context, const tensorflow::string& logging_device_name,
64     std::unique_ptr<LoggedTensor> t, TF_Status* status) {
65   auto dtype = TFE_TensorHandleDataType(t->tensor);
66   TFE_CustomDeviceTensorHandleMethods handle_methods;
67   handle_methods.num_dims = &LoggedTensorNumDims;
68   handle_methods.dim = &LoggedTensorDim;
69   handle_methods.deallocator = &LoggedTensorDeallocator;
70   return TFE_NewCustomDeviceTensorHandle(context, logging_device_name.c_str(),
71                                          dtype, t.release(), handle_methods,
72                                          status);
73 }
74 
CopyToLoggingDevice(TFE_Context * context,TFE_TensorHandle * tensor,TF_Status * status,void * device_info)75 TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
76                                       TFE_TensorHandle* tensor,
77                                       TF_Status* status, void* device_info) {
78   LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
79   TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
80       tensor, context, dev->underlying_device.c_str(), status);
81   if (TF_GetCode(status) != TF_OK) return nullptr;
82   auto dst = std::make_unique<LoggedTensor>(t);
83   *(dev->arrived_flag) = true;
84   return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
85                                 status);
86 }
87 
CopyTensorFromLoggingDevice(TFE_Context * context,TFE_TensorHandle * tensor,const char * target_device_name,TF_Status * status,void * device_info)88 TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
89                                               TFE_TensorHandle* tensor,
90                                               const char* target_device_name,
91                                               TF_Status* status,
92                                               void* device_info) {
93   TF_SetStatus(status, TF_INTERNAL,
94                "Trying to copy a tensor out of a logging device.");
95   return nullptr;
96 }
97 
LoggingDeviceExecute(const TFE_Op * original_op,int * num_outputs,TFE_TensorHandle ** outputs,TF_Status * s,void * device_info)98 void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
99                           TFE_TensorHandle** outputs, TF_Status* s,
100                           void* device_info) {
101   const char* requested_placement = TFE_OpGetDevice(original_op, s);
102   if (TF_GetCode(s) != TF_OK) return;
103 
104   LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
105   if (dev->strict_scope_placement && *requested_placement == '\0') {
106     TF_SetStatus(s, TF_INTERNAL,
107                  "Ops must be placed on the device explicitly, or their inputs "
108                  "first copied to other devices.");
109     return;
110   }
111   TFE_Context* context = TFE_OpGetContext(original_op, s);
112   if (TF_GetCode(s) != TF_OK) return;
113   const char* operation_name = TFE_OpGetName(original_op, s);
114   if (TF_GetCode(s) != TF_OK) return;
115   const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
116 
117   TFE_Op* op(TFE_NewOp(context, operation_name, s));
118   if (TF_GetCode(s) != TF_OK) return;
119   TFE_OpAddAttrs(op, attributes);
120   TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
121   if (TF_GetCode(s) != TF_OK) return;
122   int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
123   if (TF_GetCode(s) != TF_OK) return;
124   for (int j = 0; j < num_inputs; ++j) {
125     TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
126     if (TF_GetCode(s) != TF_OK) return;
127     const char* input_device = TFE_TensorHandleDeviceName(input, s);
128     if (TF_GetCode(s) != TF_OK) return;
129     if (dev->device_name == input_device) {
130       LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
131           TFE_TensorHandleDevicePointer(input, s));
132       if (TF_GetCode(s) != TF_OK) return;
133       TFE_OpAddInput(op, t->tensor, s);
134     } else {
135       TFE_OpAddInput(op, input, s);
136     }
137     if (TF_GetCode(s) != TF_OK) return;
138   }
139   std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
140   TFE_Execute(op, op_outputs.data(), num_outputs, s);
141   TFE_DeleteOp(op);
142   if (TF_GetCode(s) != TF_OK) return;
143   std::vector<TFE_TensorHandle*> unwrapped_outputs;
144   unwrapped_outputs.reserve(op_outputs.size());
145   for (auto* handle : op_outputs) {
146     unwrapped_outputs.push_back(handle);
147   }
148   for (int i = 0; i < *num_outputs; ++i) {
149     auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
150     outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
151                                         std::move(logged_tensor), s);
152   }
153   *(dev->executed_flag) = true;
154 }
155 
DeleteLoggingDevice(void * device_info)156 void DeleteLoggingDevice(void* device_info) {
157   delete reinterpret_cast<LoggingDevice*>(device_info);
158 }
159 
160 }  // namespace
161 
RegisterLoggingDevice(TFE_Context * context,const char * name,bool strict_scope_placement,bool * arrived_flag,bool * executed_flag,TF_Status * status)162 void RegisterLoggingDevice(TFE_Context* context, const char* name,
163                            bool strict_scope_placement, bool* arrived_flag,
164                            bool* executed_flag, TF_Status* status) {
165   TFE_CustomDevice custom_device;
166   custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
167   custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
168   custom_device.delete_device = &DeleteLoggingDevice;
169   custom_device.execute = &LoggingDeviceExecute;
170   LoggingDevice* device = new LoggingDevice;
171   device->arrived_flag = arrived_flag;
172   device->executed_flag = executed_flag;
173   device->device_name = name;
174   device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
175   device->strict_scope_placement = strict_scope_placement;
176   TFE_RegisterCustomDevice(context, custom_device, name, device, status);
177 }
178 
UnpackTensorHandle(TFE_TensorHandle * logged_tensor_handle,TF_Status * status)179 TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
180                                      TF_Status* status) {
181   return reinterpret_cast<LoggedTensor*>(
182              TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
183       ->tensor;
184 }
185 
AllocateLoggingDevice(const char * name,bool * arrived_flag,bool * executed_flag,TFE_CustomDevice ** device,void ** device_info)186 void AllocateLoggingDevice(const char* name, bool* arrived_flag,
187                            bool* executed_flag, TFE_CustomDevice** device,
188                            void** device_info) {
189   TFE_CustomDevice* custom_device = new TFE_CustomDevice;
190   custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
191   custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
192   custom_device->delete_device = &DeleteLoggingDevice;
193   custom_device->execute = &LoggingDeviceExecute;
194   *device = custom_device;
195   LoggingDevice* logging_device = new LoggingDevice;
196   logging_device->arrived_flag = arrived_flag;
197   logging_device->executed_flag = executed_flag;
198   logging_device->device_name = name;
199   logging_device->underlying_device =
200       "/job:localhost/replica:0/task:0/device:CPU:0";
201   logging_device->strict_scope_placement = true;
202   *device_info = reinterpret_cast<void*>(logging_device);
203 }
204