• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Classes for allocating XLA literals in device memory and managing handles
17 // that refer to them.
18 
19 #include "tensorflow/compiler/xrt/kernels/xrt_state_ops.h"
20 
21 #include <memory>
22 #include <string>
23 
24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xrt/xrt_metrics.h"
27 
28 namespace tensorflow {
29 namespace {
30 
31 class XRTMetricsCollectOp : public OpKernel {
32  public:
XRTMetricsCollectOp(OpKernelConstruction * ctx)33   explicit XRTMetricsCollectOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
34 
Compute(OpKernelContext * ctx)35   void Compute(OpKernelContext* ctx) override {
36     VLOG(1) << "XRTMetricsCollectOp::Compute";
37 
38     const Tensor& metrics_proto = ctx->input(0);
39     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(metrics_proto.shape()),
40                 errors::Internal("request input should be a string scalar"));
41     xrt::XRTMetricsCollect metrics;
42     OP_REQUIRES(ctx,
43                 ParseFromTString(metrics_proto.scalar<tstring>()(), &metrics),
44                 errors::InvalidArgument(
45                     "Unable to parse request input to XRTMetricsCollect"));
46 
47     xla::StatusOr<xrt::MetricsReport> collected_metrics_or =
48         CollectMetrics(metrics);
49     OP_REQUIRES_OK(ctx, collected_metrics_or.status());
50     xrt::MetricsReport collected_metrics =
51         collected_metrics_or.ConsumeValueOrDie();
52     Tensor output(DT_STRING, TensorShape({}));
53     output.scalar<tstring>()() = collected_metrics.SerializeAsString();
54     ctx->set_output(0, output);
55   }
56 };
57 
58 }  // namespace
59 
60 REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
61                             .Device(DEVICE_XLA_GPU)
62                             .HostMemory("allocation")
63                             .HostMemory("handle"),
64                         XRTAllocateOp<XRTGenericDeviceAccessor>);
65 REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
66                             .Device(DEVICE_XLA_CPU)
67                             .HostMemory("allocation")
68                             .HostMemory("handle"),
69                         XRTAllocateOp<XRTGenericDeviceAccessor>);
70 
71 REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
72                             .Device(DEVICE_XLA_GPU)
73                             .HostMemory("handle"),
74                         XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>);
75 REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
76                             .Device(DEVICE_XLA_CPU)
77                             .HostMemory("handle"),
78                         XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>);
79 
80 REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor")
81                             .Device(DEVICE_XLA_GPU)
82                             .HostMemory("inputs")
83                             .HostMemory("handle"),
84                         XRTAllocateFromTensorOp<XRTGenericDeviceAccessor>);
85 REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor")
86                             .Device(DEVICE_XLA_CPU)
87                             .HostMemory("inputs")
88                             .HostMemory("handle"),
89                         XRTAllocateFromTensorOp<XRTGenericDeviceAccessor>);
90 
91 REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
92                             .Device(DEVICE_XLA_GPU)
93                             .HostMemory("base_handle")
94                             .HostMemory("shape_index")
95                             .HostMemory("output_handle"),
96                         XRTSubTupleOp<false, XRTGenericDeviceAccessor>);
97 REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
98                             .Device(DEVICE_XLA_CPU)
99                             .HostMemory("base_handle")
100                             .HostMemory("shape_index")
101                             .HostMemory("output_handle"),
102                         XRTSubTupleOp<false, XRTGenericDeviceAccessor>);
103 
104 REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
105                             .Device(DEVICE_XLA_GPU)
106                             .HostMemory("base_handle")
107                             .HostMemory("shape_index")
108                             .HostMemory("output_handle"),
109                         XRTSubTupleOp<true, XRTGenericDeviceAccessor>);
110 REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
111                             .Device(DEVICE_XLA_CPU)
112                             .HostMemory("base_handle")
113                             .HostMemory("shape_index")
114                             .HostMemory("output_handle"),
115                         XRTSubTupleOp<true, XRTGenericDeviceAccessor>);
116 
117 REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
118                             .Device(DEVICE_XLA_GPU)
119                             .HostMemory("tuple_description")
120                             .HostMemory("input_handles")
121                             .HostMemory("output_handle"),
122                         XRTMakeTupleOp<XRTGenericDeviceAccessor>);
123 REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
124                             .Device(DEVICE_XLA_CPU)
125                             .HostMemory("tuple_description")
126                             .HostMemory("input_handles")
127                             .HostMemory("output_handle"),
128                         XRTMakeTupleOp<XRTGenericDeviceAccessor>);
129 
130 REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
131                             .Device(DEVICE_XLA_GPU)
132                             .HostMemory("handle")
133                             .HostMemory("literal"),
134                         XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
135 REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
136                             .Device(DEVICE_XLA_CPU)
137                             .HostMemory("handle")
138                             .HostMemory("literal"),
139                         XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
140 
141 REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
142                             .Device(DEVICE_XLA_GPU)
143                             .HostMemory("handle")
144                             .HostMemory("literal")
145                             .HostMemory("output_handle"),
146                         XRTWriteLiteralOp<XRTGenericDeviceAccessor>);
147 REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
148                             .Device(DEVICE_XLA_CPU)
149                             .HostMemory("handle")
150                             .HostMemory("literal")
151                             .HostMemory("output_handle"),
152                         XRTWriteLiteralOp<XRTGenericDeviceAccessor>);
153 
154 REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
155                             .Device(DEVICE_XLA_GPU)
156                             .HostMemory("handle")
157                             .HostMemory("literal"),
158                         XRTReadLiteralOp<true, XRTGenericDeviceAccessor>);
159 REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
160                             .Device(DEVICE_XLA_CPU)
161                             .HostMemory("handle")
162                             .HostMemory("literal"),
163                         XRTReadLiteralOp<true, XRTGenericDeviceAccessor>);
164 
165 REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor")
166                             .Device(DEVICE_XLA_GPU)
167                             .HostMemory("handles")
168                             .HostMemory("tensors"),
169                         XRTReadToTensorOp<XRTGenericDeviceAccessor>);
170 REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor")
171                             .Device(DEVICE_XLA_CPU)
172                             .HostMemory("handles")
173                             .HostMemory("tensors"),
174                         XRTReadToTensorOp<XRTGenericDeviceAccessor>);
175 
176 REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
177                             .Device(DEVICE_XLA_GPU)
178                             .HostMemory("handle"),
179                         XRTReleaseAllocationOp<XRTGenericDeviceAccessor>);
180 REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
181                             .Device(DEVICE_XLA_CPU)
182                             .HostMemory("handle"),
183                         XRTReleaseAllocationOp<XRTGenericDeviceAccessor>);
184 
185 REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_GPU),
186                         XRTReleaseAllAllocationsOp<XRTGenericDeviceAccessor>);
187 REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_CPU),
188                         XRTReleaseAllAllocationsOp<XRTGenericDeviceAccessor>);
189 
190 REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_GPU),
191                         XRTCompactAllocationsOp<XRTGenericDeviceAccessor>);
192 REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_CPU),
193                         XRTCompactAllocationsOp<XRTGenericDeviceAccessor>);
194 
195 REGISTER_KERNEL_BUILDER(Name("XRTMetricsCollect").Device(DEVICE_CPU),
196                         XRTMetricsCollectOp);
197 
198 REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_GPU),
199                         XRTMemoryInfoOp<XRTGenericDeviceAccessor>);
200 REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_CPU),
201                         XRTMemoryInfoOp<XRTGenericDeviceAccessor>);
202 
203 }  // namespace tensorflow
204