/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Common kernel registrations for XLA devices.

#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/data/generator_dataset_op.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/function_ops.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/shape_ops.h"
#include "tensorflow/core/kernels/variable_ops.h"

namespace tensorflow {

// Dummy OpKernel, used for kernels assigned to an XLA device that should be
// compiled. Should never be called at runtime since such ops should be
// rewritten to an XlaLaunch op. If it is called, it means the placer placed an
// operator on an XLA device but the compiler did not compile it.
class XlaDeviceDummyOp : public OpKernel {
 public:
  explicit XlaDeviceDummyOp(OpKernelConstruction* ctx);
  void Compute(OpKernelContext* ctx) override;
};

// Kernel registered as the implementation of AssignVariableOp on XLA devices
// (see REGISTER_XLA_DEVICE_KERNELS below).
class XlaAssignVariableOp : public OpKernel {
 public:
  explicit XlaAssignVariableOp(OpKernelConstruction* c);
  void Compute(OpKernelContext* context) override;

 private:
  DataType dtype_;
};

// Note: TYPES is currently unused by the three macros below.
#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER(Name("XlaLaunch")               \
                              .Device(DEVICE)             \
                              .HostMemory("constants")    \
                              .HostMemory("resources"),   \
                          KERNEL);

#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES)          \
  REGISTER_KERNEL_BUILDER(Name("_XlaCompile")                       \
                              .Device(DEVICE)                       \
                              .HostMemory("constants")              \
                              .HostMemory("key")                    \
                              .HostMemory("compilation_successful") \
                              .HostMemory("resources"),             \
                          KERNEL);

#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
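// Illustrative sketch only, not part of this header: a device library is
// expected to instantiate the three macros above with its own kernel classes
// and type list. XlaLocalLaunchOp, XlaCompileOp and XlaRunOp are the kernels
// defined in tensorflow/compiler/jit/kernels/xla_ops.h; kSupportedTypes is a
// hypothetical DataType list standing in for the device's supported types.
//
//   REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp,
//                              kSupportedTypes);
//   REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kSupportedTypes);
//   REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kSupportedTypes);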
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES)                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES),             \
      ConstantOp);                                                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), VarHandleOp); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"),            \
      ResourceHandlesOp<Var>);                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"),            \
      ReadVariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"),         \
      ReadVariablesOp);                                                        \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"),         \
      DestroyResourceOp);                                                      \
  REGISTER_KERNEL_BUILDER(Name("Shape")                                        \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeOp<int32>);                                     \
  REGISTER_KERNEL_BUILDER(Name("Shape")                                        \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeOp<int64>);                                     \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeNOp<int32>);                                    \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeNOp<int64>);                                    \
  REGISTER_KERNEL_BUILDER(Name("VariableShape")                                \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<int32>("out_type")               \
                              .HostMemory("output")                            \
                              .HostMemory("input"),                            \
                          VariableShapeOp<int32>);                             \
  REGISTER_KERNEL_BUILDER(Name("VariableShape")                                \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<int64>("out_type")               \
                              .HostMemory("output")                            \
                              .HostMemory("input"),                            \
                          VariableShapeOp<int64>);                             \
  REGISTER_KERNEL_BUILDER(Name("Size")                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          SizeOp<int32>);                                      \
  REGISTER_KERNEL_BUILDER(Name("Size")                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          SizeOp<int64>);                                      \
  REGISTER_KERNEL_BUILDER(Name("Rank")                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint("T", TYPES),                     \
                          RankOp);                                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"),          \
      XlaAssignVariableOp);                                                    \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);   \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kArgOp).Device(DEVICE).TypeConstraint("T", TYPES), ArgOp);          \
  REGISTER_KERNEL_BUILDER(Name(kArgOp)                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<ResourceHandle>("T"),            \
                          ArgOp);                                              \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kArgOp).Device(DEVICE).TypeConstraint<Variant>("T"), ArgOp);        \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kRetOp).Device(DEVICE).TypeConstraint("T", TYPES), RetvalOp);       \
  REGISTER_KERNEL_BUILDER(Name(kRetOp)                                         \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<ResourceHandle>("T")             \
                              .HostMemory("input"),                            \
                          RetvalOp);                                           \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kDeviceRetOp).Device(DEVICE).TypeConstraint<int32>("T"), RetvalOp); \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp);   \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"),            \
      data::GeneratorDatasetOp);                                               \
  REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")                              \
                              .Device(DEVICE)                                  \
                              .HostMemory("buffer_size")                       \
                              .HostMemory("input_dataset")                     \
                              .HostMemory("handle"),                           \
                          data::PrefetchDatasetOp);                            \
                                                                               \
  REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE),                   \
                          data::IteratorHandleOp);                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MakeIterator").Device(DEVICE).HostMemory("dataset"),               \
      data::MakeIteratorOp);                                                   \
  REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE),            \
                          data::AnonymousIteratorHandleOp);                    \
  REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2").Device(DEVICE),          \
                          data::AnonymousIteratorHandleOp);                    \
  REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE),               \
                          data::DeleteIteratorOp);                             \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE),              \
                          data::IteratorGetNextOp);                            \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE),    \
                          data::IteratorGetNextAsOptionalOp);                  \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE),          \
                          data::IteratorGetNextOp);                            \
  REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("string_handle"),                    \
                          data::IteratorToStringHandleOp);                     \
  REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")                   \
                              .Device(DEVICE)                                  \
                              .HostMemory("string_handle"),                    \
                          data::IteratorFromStringHandleOp);                   \
  REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE),                 \
                          data::OptionalNoneOp);                               \
  REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE),            \
                          data::OptionalFromValueOp);                          \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("OptionalHasValue").Device(DEVICE).HostMemory("has_value"),         \
      data::OptionalHasValueOp);                                               \
  REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE),             \
                          data::OptionalGetValueOp);                           \
  REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp)              \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<tstring>("T"),                   \
                          ArgOp);                                              \
  REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp)              \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<tstring>("T")                    \
                              .HostMemory("input"),                            \
                          RetvalOp);
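// Likewise illustrative: REGISTER_XLA_DEVICE_KERNELS is instantiated once per
// device. TYPES must be an allowed-type list that
// KernelDefBuilder::TypeConstraint accepts, e.g. a gtl::ArraySlice<DataType>;
// kSupportedTypes is again a hypothetical placeholder.
//
//   REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kSupportedTypes);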
// TODO(b/118881356): we currently do not register the QueueEnqueueMany,
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
// and write the tensors they access in order to concatenate them into a batch.
// Supporting them would require either calling out to an XLA computation to
// perform the concatenation, or refactoring those kernels so that the
// splitting or merging is done in a separate operator that can be compiled.

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_