• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // A simple logging device to test custom device registration.
17 #include <memory>
18 
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/c/eager/c_api.h"
21 #include "tensorflow/c/eager/c_api_experimental.h"
22 #include "tensorflow/c/eager/c_api_test_util.h"
23 #include "tensorflow/c/tf_status.h"
24 #include "tensorflow/core/lib/gtl/cleanup.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace {
28 
29 struct LoggingDevice {
30   tensorflow::string device_name;
31   tensorflow::string underlying_device;
32   // Set to true whenever a TensorHandle is copied onto the device
33   bool* arrived_flag;
34   // Set to true whenever an operation is executed
35   bool* executed_flag;
36   // If true, only explicit op placements are accepted. If false, uses
37   // type-based dispatch.
38   bool strict_scope_placement;
39 };
40 
41 struct LoggedTensor {
42   TFE_TensorHandle* tensor;
43   LoggedTensor() = delete;
LoggedTensor__anonb3a4e9dc0111::LoggedTensor44   explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor__anonb3a4e9dc0111::LoggedTensor45   ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
46 };
47 
LoggedTensorDeallocator(void * data,size_t len,void * arg)48 void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
49   delete reinterpret_cast<LoggedTensor*>(data);
50 }
51 
MakeLoggedTensorHandle(TFE_Context * context,const tensorflow::string & logging_device_name,std::unique_ptr<LoggedTensor> t,TF_Status * status)52 TFE_TensorHandle* MakeLoggedTensorHandle(
53     TFE_Context* context, const tensorflow::string& logging_device_name,
54     std::unique_ptr<LoggedTensor> t, TF_Status* status) {
55   std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
56   if (TF_GetCode(status) != TF_OK) return nullptr;
57   for (int i = 0; i < shape.size(); ++i) {
58     shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
59     if (TF_GetCode(status) != TF_OK) return nullptr;
60   }
61   auto dtype = TFE_TensorHandleDataType(t->tensor);
62   return TFE_NewTensorHandleFromDeviceMemory(
63       context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
64       t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
65 }
66 
CopyToLoggingDevice(TFE_Context * context,TFE_TensorHandle * tensor,TF_Status * status,void * device_info)67 TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
68                                       TFE_TensorHandle* tensor,
69                                       TF_Status* status, void* device_info) {
70   LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
71   TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
72       tensor, context, dev->underlying_device.c_str(), status);
73   if (TF_GetCode(status) != TF_OK) return nullptr;
74   auto dst = std::make_unique<LoggedTensor>(t);
75   *(dev->arrived_flag) = true;
76   return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
77                                 status);
78 }
79 
CopyTensorFromLoggingDevice(TFE_Context * context,TFE_TensorHandle * tensor,const char * target_device_name,TF_Status * status,void * device_info)80 TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
81                                               TFE_TensorHandle* tensor,
82                                               const char* target_device_name,
83                                               TF_Status* status,
84                                               void* device_info) {
85   TF_SetStatus(status, TF_INTERNAL,
86                "Trying to copy a tensor out of a logging device.");
87   return nullptr;
88 }
89 
LoggingDeviceExecute(const TFE_Op * original_op,int * num_outputs,TFE_TensorHandle ** outputs,TF_Status * s,void * device_info)90 void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
91                           TFE_TensorHandle** outputs, TF_Status* s,
92                           void* device_info) {
93   const char* requested_placement = TFE_OpGetDevice(original_op, s);
94   if (TF_GetCode(s) != TF_OK) return;
95 
96   LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
97   if (dev->strict_scope_placement && *requested_placement == '\0') {
98     TF_SetStatus(s, TF_INTERNAL,
99                  "Ops must be placed on the device explicitly, or their inputs "
100                  "first copied to other devices.");
101     return;
102   }
103   TFE_Context* context = TFE_OpGetContext(original_op, s);
104   if (TF_GetCode(s) != TF_OK) return;
105   const char* operation_name = TFE_OpGetName(original_op, s);
106   if (TF_GetCode(s) != TF_OK) return;
107   const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
108 
109   TFE_Op* op(TFE_NewOp(context, operation_name, s));
110   if (TF_GetCode(s) != TF_OK) return;
111   TFE_OpAddAttrs(op, attributes);
112   TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
113   if (TF_GetCode(s) != TF_OK) return;
114   int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
115   if (TF_GetCode(s) != TF_OK) return;
116   for (int j = 0; j < num_inputs; ++j) {
117     TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
118     if (TF_GetCode(s) != TF_OK) return;
119     const char* input_device = TFE_TensorHandleDeviceName(input, s);
120     if (TF_GetCode(s) != TF_OK) return;
121     if (dev->device_name == input_device) {
122       LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
123           TFE_TensorHandleDevicePointer(input, s));
124       if (TF_GetCode(s) != TF_OK) return;
125       TFE_OpAddInput(op, t->tensor, s);
126     } else {
127       TFE_OpAddInput(op, input, s);
128     }
129     if (TF_GetCode(s) != TF_OK) return;
130   }
131   std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
132   TFE_Execute(op, op_outputs.data(), num_outputs, s);
133   TFE_DeleteOp(op);
134   if (TF_GetCode(s) != TF_OK) return;
135   std::vector<TFE_TensorHandle*> unwrapped_outputs;
136   unwrapped_outputs.reserve(op_outputs.size());
137   for (auto* handle : op_outputs) {
138     unwrapped_outputs.push_back(handle);
139   }
140   for (int i = 0; i < *num_outputs; ++i) {
141     auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
142     outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
143                                         std::move(logged_tensor), s);
144   }
145   *(dev->executed_flag) = true;
146 }
147 
DeleteLoggingDevice(void * device_info)148 void DeleteLoggingDevice(void* device_info) {
149   delete reinterpret_cast<LoggingDevice*>(device_info);
150 }
151 
152 }  // namespace
153 
RegisterLoggingDevice(TFE_Context * context,const char * name,bool strict_scope_placement,bool * arrived_flag,bool * executed_flag,TF_Status * status)154 void RegisterLoggingDevice(TFE_Context* context, const char* name,
155                            bool strict_scope_placement, bool* arrived_flag,
156                            bool* executed_flag, TF_Status* status) {
157   TFE_CustomDevice custom_device;
158   custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
159   custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
160   custom_device.delete_device = &DeleteLoggingDevice;
161   custom_device.execute = &LoggingDeviceExecute;
162   LoggingDevice* device = new LoggingDevice;
163   device->arrived_flag = arrived_flag;
164   device->executed_flag = executed_flag;
165   device->device_name = name;
166   device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
167   device->strict_scope_placement = strict_scope_placement;
168   TFE_RegisterCustomDevice(context, custom_device, name, device, status);
169 }
170 
UnpackTensorHandle(TFE_TensorHandle * logged_tensor_handle,TF_Status * status)171 TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
172                                      TF_Status* status) {
173   return reinterpret_cast<LoggedTensor*>(
174              TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
175       ->tensor;
176 }
177 
AllocateLoggingDevice(const char * name,bool * arrived_flag,bool * executed_flag,TFE_CustomDevice ** device,void ** device_info)178 void AllocateLoggingDevice(const char* name, bool* arrived_flag,
179                            bool* executed_flag, TFE_CustomDevice** device,
180                            void** device_info) {
181   TFE_CustomDevice* custom_device = new TFE_CustomDevice;
182   custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
183   custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
184   custom_device->delete_device = &DeleteLoggingDevice;
185   custom_device->execute = &LoggingDeviceExecute;
186   *device = custom_device;
187   LoggingDevice* logging_device = new LoggingDevice;
188   logging_device->arrived_flag = arrived_flag;
189   logging_device->executed_flag = executed_flag;
190   logging_device->device_name = name;
191   logging_device->underlying_device =
192       "/job:localhost/replica:0/task:0/device:CPU:0";
193   logging_device->strict_scope_placement = true;
194   *device_info = reinterpret_cast<void*>(logging_device);
195 }
196