• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/tpu/virtual_device.h"

#include <utility>

#include "tensorflow/core/framework/tensor.pb.h"
19 
20 namespace tensorflow {
21 namespace {
22 
23 class VirtualDeviceContext : public DeviceContext {
24  public:
25   VirtualDeviceContext() = default;
26 
27   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
28                              Tensor* device_tensor, StatusCallback done,
29                              bool sync_dst_compute) const override;
30   void CopyDeviceTensorToCPU(const Tensor* device_tensor,
31                              StringPiece tensor_name, Device* device,
32                              Tensor* cpu_tensor, StatusCallback done) override;
33   void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
34                               Tensor* output_tensor,
35                               StatusCallback done) const override;
36 };
37 
CopyCPUTensorToDevice(const Tensor * cpu_tensor,Device * device,Tensor * device_tensor,StatusCallback done,bool sync_dst_compute) const38 void VirtualDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
39                                                  Device* device,
40                                                  Tensor* device_tensor,
41                                                  StatusCallback done,
42                                                  bool sync_dst_compute) const {
43   *device_tensor = *cpu_tensor;
44   done(Status::OK());
45 }
46 
CopyDeviceTensorToCPU(const Tensor * device_tensor,StringPiece tensor_name,Device * device,Tensor * cpu_tensor,StatusCallback done)47 void VirtualDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
48                                                  StringPiece tensor_name,
49                                                  Device* device,
50                                                  Tensor* cpu_tensor,
51                                                  StatusCallback done) {
52   *cpu_tensor = *device_tensor;
53   done(Status::OK());
54 }
55 
CopyTensorInSameDevice(const Tensor * input_tensor,Device * device,Tensor * output_tensor,StatusCallback done) const56 void VirtualDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
57                                                   Device* device,
58                                                   Tensor* output_tensor,
59                                                   StatusCallback done) const {
60   *output_tensor = *input_tensor;
61   done(Status::OK());
62 }
63 
64 }  // namespace
65 
// VirtualDevice
67 
VirtualDevice(Env * env,const DeviceAttributes & device_attributes)68 VirtualDevice::VirtualDevice(Env* env,
69                              const DeviceAttributes& device_attributes)
70     : Device(env, device_attributes) {}
71 
Sync()72 Status VirtualDevice::Sync() { return Status::OK(); }
73 
GetAllocator(AllocatorAttributes attr)74 Allocator* VirtualDevice::GetAllocator(AllocatorAttributes attr) {
75   // Tensors always live on the host.
76   return cpu_allocator();
77 }
78 
MakeTensorFromProto(const TensorProto & tensor_proto,const AllocatorAttributes alloc_attrs,Tensor * tensor)79 Status VirtualDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
80                                           const AllocatorAttributes alloc_attrs,
81                                           Tensor* tensor) {
82   Tensor parsed(tensor_proto.dtype());
83   Allocator* allocator = cpu_allocator();
84   if (!parsed.FromProto(allocator, tensor_proto)) {
85     return errors::InvalidArgument("Cannot parse tensor from proto: ",
86                                    tensor_proto.DebugString());
87   }
88   *tensor = parsed;
89   return Status::OK();
90 }
91 
TryGetDeviceContext(DeviceContext ** out_context)92 Status VirtualDevice::TryGetDeviceContext(DeviceContext** out_context) {
93   *out_context = new VirtualDeviceContext;
94   return Status::OK();
95 }
96 
97 }  // namespace tensorflow
98