1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Classes for allocating XLA literals in device memory and managing handles 17 // that refer to them. 18 19 #include "tensorflow/compiler/xrt/kernels/xrt_state_ops.h" 20 21 #include <memory> 22 #include <string> 23 24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 25 #include "tensorflow/compiler/xla/client/local_client.h" 26 #include "tensorflow/compiler/xrt/xrt_metrics.h" 27 28 namespace tensorflow { 29 namespace { 30 31 class XRTMetricsCollectOp : public OpKernel { 32 public: XRTMetricsCollectOp(OpKernelConstruction * ctx)33 explicit XRTMetricsCollectOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 34 Compute(OpKernelContext * ctx)35 void Compute(OpKernelContext* ctx) override { 36 VLOG(1) << "XRTMetricsCollectOp::Compute"; 37 38 const Tensor& metrics_proto = ctx->input(0); 39 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(metrics_proto.shape()), 40 errors::Internal("request input should be a string scalar")); 41 xrt::XRTMetricsCollect metrics; 42 OP_REQUIRES(ctx, 43 ParseFromTString(metrics_proto.scalar<tstring>()(), &metrics), 44 errors::InvalidArgument( 45 "Unable to parse request input to XRTMetricsCollect")); 46 47 xla::StatusOr<xrt::MetricsReport> collected_metrics_or = 48 CollectMetrics(metrics); 49 OP_REQUIRES_OK(ctx, collected_metrics_or.status()); 50 xrt::MetricsReport collected_metrics = 51 collected_metrics_or.ConsumeValueOrDie(); 52 Tensor output(DT_STRING, TensorShape({})); 53 output.scalar<tstring>()() = collected_metrics.SerializeAsString(); 54 ctx->set_output(0, output); 55 } 56 }; 57 58 } // namespace 59 60 REGISTER_KERNEL_BUILDER(Name("XRTAllocate") 61 .Device(DEVICE_XLA_GPU) 62 .HostMemory("allocation") 63 .HostMemory("handle"), 64 XRTAllocateOp<XRTGenericDeviceAccessor>); 65 REGISTER_KERNEL_BUILDER(Name("XRTAllocate") 66 .Device(DEVICE_XLA_CPU) 67 .HostMemory("allocation") 68 .HostMemory("handle"), 69 XRTAllocateOp<XRTGenericDeviceAccessor>); 70 71 REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized") 72 .Device(DEVICE_XLA_GPU) 73 .HostMemory("handle"), 74 XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>); 75 REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized") 76 .Device(DEVICE_XLA_CPU) 77 .HostMemory("handle"), 78 XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>); 79 80 REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") 81 .Device(DEVICE_XLA_GPU) 82 .HostMemory("inputs") 83 .HostMemory("handle"), 84 XRTAllocateFromTensorOp<XRTGenericDeviceAccessor>); 85 REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") 86 .Device(DEVICE_XLA_CPU) 87 .HostMemory("inputs") 88 .HostMemory("handle"), 89 XRTAllocateFromTensorOp<XRTGenericDeviceAccessor>); 90 91 REGISTER_KERNEL_BUILDER(Name("XRTSubTuple") 92 .Device(DEVICE_XLA_GPU) 93 .HostMemory("base_handle") 94 .HostMemory("shape_index") 95 .HostMemory("output_handle"), 96 XRTSubTupleOp<false, XRTGenericDeviceAccessor>); 97 REGISTER_KERNEL_BUILDER(Name("XRTSubTuple") 98 .Device(DEVICE_XLA_CPU) 99 .HostMemory("base_handle") 100 .HostMemory("shape_index") 101 .HostMemory("output_handle"), 102 XRTSubTupleOp<false, XRTGenericDeviceAccessor>); 103 104 REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease") 105 .Device(DEVICE_XLA_GPU) 106 .HostMemory("base_handle") 107 .HostMemory("shape_index") 108 .HostMemory("output_handle"), 109 XRTSubTupleOp<true, XRTGenericDeviceAccessor>); 110 REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease") 111 .Device(DEVICE_XLA_CPU) 112 .HostMemory("base_handle") 113 .HostMemory("shape_index") 114 .HostMemory("output_handle"), 115 XRTSubTupleOp<true, XRTGenericDeviceAccessor>); 116 117 REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple") 118 .Device(DEVICE_XLA_GPU) 119 .HostMemory("tuple_description") 120 .HostMemory("input_handles") 121 .HostMemory("output_handle"), 122 XRTMakeTupleOp<XRTGenericDeviceAccessor>); 123 REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple") 124 .Device(DEVICE_XLA_CPU) 125 .HostMemory("tuple_description") 126 .HostMemory("input_handles") 127 .HostMemory("output_handle"), 128 XRTMakeTupleOp<XRTGenericDeviceAccessor>); 129 130 REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") 131 .Device(DEVICE_XLA_GPU) 132 .HostMemory("handle") 133 .HostMemory("literal"), 134 XRTReadLiteralOp<false, XRTGenericDeviceAccessor>); 135 REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") 136 .Device(DEVICE_XLA_CPU) 137 .HostMemory("handle") 138 .HostMemory("literal"), 139 XRTReadLiteralOp<false, XRTGenericDeviceAccessor>); 140 141 REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") 142 .Device(DEVICE_XLA_GPU) 143 .HostMemory("handle") 144 .HostMemory("literal") 145 .HostMemory("output_handle"), 146 XRTWriteLiteralOp<XRTGenericDeviceAccessor>); 147 REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") 148 .Device(DEVICE_XLA_CPU) 149 .HostMemory("handle") 150 .HostMemory("literal") 151 .HostMemory("output_handle"), 152 XRTWriteLiteralOp<XRTGenericDeviceAccessor>); 153 154 REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") 155 .Device(DEVICE_XLA_GPU) 156 .HostMemory("handle") 157 .HostMemory("literal"), 158 XRTReadLiteralOp<true, XRTGenericDeviceAccessor>); 159 REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") 160 .Device(DEVICE_XLA_CPU) 161 .HostMemory("handle") 162 .HostMemory("literal"), 163 XRTReadLiteralOp<true, XRTGenericDeviceAccessor>); 164 165 REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") 166 .Device(DEVICE_XLA_GPU) 167 .HostMemory("handles") 168 .HostMemory("tensors"), 169 XRTReadToTensorOp<XRTGenericDeviceAccessor>); 170 REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") 171 .Device(DEVICE_XLA_CPU) 172 .HostMemory("handles") 173 .HostMemory("tensors"), 174 XRTReadToTensorOp<XRTGenericDeviceAccessor>); 175 176 REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") 177 .Device(DEVICE_XLA_GPU) 178 .HostMemory("handle"), 179 XRTReleaseAllocationOp<XRTGenericDeviceAccessor>); 180 REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") 181 .Device(DEVICE_XLA_CPU) 182 .HostMemory("handle"), 183 XRTReleaseAllocationOp<XRTGenericDeviceAccessor>); 184 185 REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_GPU), 186 XRTReleaseAllAllocationsOp<XRTGenericDeviceAccessor>); 187 REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_CPU), 188 XRTReleaseAllAllocationsOp<XRTGenericDeviceAccessor>); 189 190 REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_GPU), 191 XRTCompactAllocationsOp<XRTGenericDeviceAccessor>); 192 REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_CPU), 193 XRTCompactAllocationsOp<XRTGenericDeviceAccessor>); 194 195 REGISTER_KERNEL_BUILDER(Name("XRTMetricsCollect").Device(DEVICE_CPU), 196 XRTMetricsCollectOp); 197 198 REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_GPU), 199 XRTMemoryInfoOp<XRTGenericDeviceAccessor>); 200 REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_CPU), 201 XRTMemoryInfoOp<XRTGenericDeviceAccessor>); 202 203 } // namespace tensorflow 204