1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Common kernel registrations for XLA devices. 17 18 #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_ 19 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_ 20 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/resource_mgr.h" 23 #include "tensorflow/core/kernels/cast_op.h" 24 #include "tensorflow/core/kernels/constant_op.h" 25 #include "tensorflow/core/kernels/control_flow_ops.h" 26 #include "tensorflow/core/kernels/data/generator_dataset_op.h" 27 #include "tensorflow/core/kernels/data/iterator_ops.h" 28 #include "tensorflow/core/kernels/data/optional_ops.h" 29 #include "tensorflow/core/kernels/data/prefetch_dataset_op.h" 30 #include "tensorflow/core/kernels/fifo_queue.h" 31 #include "tensorflow/core/kernels/function_ops.h" 32 #include "tensorflow/core/kernels/host_constant_op.h" 33 #include "tensorflow/core/kernels/identity_n_op.h" 34 #include "tensorflow/core/kernels/identity_op.h" 35 #include "tensorflow/core/kernels/no_op.h" 36 #include "tensorflow/core/kernels/queue_op.h" 37 #include "tensorflow/core/kernels/resource_variable_ops.h" 38 #include "tensorflow/core/kernels/sendrecv_ops.h" 39 #include "tensorflow/core/kernels/shape_ops.h" 40 #include "tensorflow/core/kernels/stack.h" 41 #include "tensorflow/core/kernels/variable_ops.h" 
namespace tensorflow {

// Dummy OpKernel, used for kernels assigned to an XLA device that should be
// compiled. Should never be called at runtime since such ops should be
// rewritten to a XlaLaunch op. If it is called, it means the placer placed an
// operator on an XLA device but the compiler did not compile it.
class XlaDeviceDummyOp : public OpKernel {
 public:
  explicit XlaDeviceDummyOp(OpKernelConstruction* ctx);
  void Compute(OpKernelContext* ctx) override;
};

// Kernel registered below (in REGISTER_XLA_DEVICE_KERNELS) as the
// implementation of "AssignVariableOp" on XLA devices. Constructor and
// Compute are defined in the corresponding .cc file; dtype_ presumably
// caches the variable's element type from the node's attrs -- confirm there.
class XlaAssignVariableOp : public OpKernel {
 public:
  explicit XlaAssignVariableOp(OpKernelConstruction* c);
  void Compute(OpKernelContext* context) override;

 private:
  DataType dtype_;
};

// Registers KERNEL as the "XlaLaunch" op on DEVICE, with the "constants" and
// "resources" inputs pinned to host memory. TYPES is accepted for signature
// parity with the other registration macros but is not used here.
#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER(Name("XlaLaunch")               \
                              .Device(DEVICE)             \
                              .HostMemory("constants")    \
                              .HostMemory("resources"),   \
                          KERNEL);

// Registers KERNEL as the "_XlaCompile" op on DEVICE; the "constants" and
// "resources" inputs and the "key" / "compilation_successful" outputs live in
// host memory. TYPES is unused (kept for signature parity).
#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES)     \
  REGISTER_KERNEL_BUILDER(Name("_XlaCompile")                  \
                              .Device(DEVICE)                  \
                              .HostMemory("constants")         \
                              .HostMemory("key")               \
                              .HostMemory("compilation_successful") \
                              .HostMemory("resources"),        \
                          KERNEL);

// Registers KERNEL as the "_XlaRun" op on DEVICE. TYPES is unused (kept for
// signature parity).
#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);

// Registers the common "plumbing" kernels that every XLA device needs,
// constrained (where a TypeConstraint on TYPES appears) to the device's
// supported type list. NOTE: `//` comments cannot appear inside the macro
// body (a comment would swallow its own continuation backslash), so the
// groups below are summarized here, in order of appearance:
//   - _Send/_Recv and _HostSend/_HostRecv/_HostCast transfer/cast kernels.
//   - NoOp, Const, HostConst, Identity (TYPES, DT_STRING, Variant, and a
//     host-memory ResourceHandle overload), IdentityN, Placeholder(V2).
//   - Resource-variable kernels (VarHandleOp, _VarHandlesOp, ReadVariableOp,
//     _ReadVariablesOp, DestroyResourceOp, and AssignVariableOp implemented
//     by XlaAssignVariableOp), with resource handles in host memory.
//   - Shape/ShapeN/Size (both int32 and int64 out_type) and Rank, with their
//     results placed in host memory.
//   - Control-flow kernels (ControlTrigger, Switch, Merge, Enter, Exit,
//     NextIteration, LoopCond).
//   - Queue kernels (QueueEnqueueV2/DequeueV2/CloseV2/SizeV2/IsClosedV2 and
//     FIFOQueueV2), with queue handles in host memory.
//   - Function-call plumbing: _Arg/_Retval via kArgOp/kRetOp/kDeviceRetOp
//     (presumably the FunctionLibraryDefinition constants -- the string
//     overloads below qualify them explicitly), plus RemoteCall.
//   - tf.data kernels (GeneratorDataset, PrefetchDataset, Iterator*,
//     Optional*).
//   - StackV2 push/pop/close kernels (push with swapping disabled).
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES)                             \
  REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp);               \
  REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp);               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_HostSend").Device(DEVICE).HostMemory("tensor"), SendOp);          \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_HostRecv").Device(DEVICE).HostMemory("tensor"), RecvOp);          \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_HostCast").Device(DEVICE).HostMemory("x").HostMemory("y"),        \
      CpuCastOp);                                                              \
  REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE), NoOp);                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES),             \
      ConstantOp);                                                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("HostConst").Device(DEVICE).HostMemory("output"), _HostConstantOp); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Identity").Device(DEVICE).TypeConstraint("T", DT_STRING),          \
      IdentityOp);                                                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Identity").Device(DEVICE).TypeConstraint<Variant>("T"),            \
      IdentityOp);                                                             \
  REGISTER_KERNEL_BUILDER(Name("Identity")                                     \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<ResourceHandle>("T")             \
                              .HostMemory("input")                             \
                              .HostMemory("output"),                           \
                          IdentityOp);                                         \
  REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp);      \
  REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp);  \
  REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE),                \
                          PlaceholderOp);                                      \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("VarHandleOp").Device(DEVICE).HostMemory("resource"),               \
      ResourceHandleOp<Var>);                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"),            \
      ResourceHandlesOp<Var>);                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"),            \
      ReadVariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"),         \
      ReadVariablesOp);                                                        \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"),         \
      DestroyResourceOp);                                                      \
  REGISTER_KERNEL_BUILDER(Name("Shape")                                        \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeOp<int32>);                                     \
  REGISTER_KERNEL_BUILDER(Name("Shape")                                        \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeOp<int64>);                                     \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeNOp<int32>);                                    \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeNOp<int64>);                                    \
  REGISTER_KERNEL_BUILDER(Name("Size")                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          SizeOp<int32>);                                      \
  REGISTER_KERNEL_BUILDER(Name("Size")                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          SizeOp<int64>);                                      \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T",     \
                                                                      TYPES), \
      RankOp);                                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"),          \
      XlaAssignVariableOp);                                                    \
  REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE),               \
                          ControlTriggerOp);                                   \
  REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"),    \
                          SwitchOp);                                           \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp);        \
  REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp);              \
  REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp);                \
  REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE),                \
                          NextIterationOp);                                    \
  REGISTER_KERNEL_BUILDER(Name("LoopCond")                                     \
                              .Device(DEVICE)                                  \
                              .HostMemory("input")                             \
                              .HostMemory("output"),                           \
                          LoopCondOp);                                         \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp);  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp);  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \
  REGISTER_KERNEL_BUILDER(Name("QueueSizeV2")                                  \
                              .Device(DEVICE)                                  \
                              .HostMemory("size")                              \
                              .HostMemory("handle"),                           \
                          QueueSizeOp);                                        \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"),             \
      QueueIsClosedOp);                                                        \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);   \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kArgOp).Device(DEVICE).TypeConstraint("T", TYPES), ArgOp);          \
  REGISTER_KERNEL_BUILDER(Name(kArgOp)                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<ResourceHandle>("T"),            \
                          ArgOp);                                              \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kArgOp).Device(DEVICE).TypeConstraint<Variant>("T"), ArgOp);        \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kRetOp).Device(DEVICE).TypeConstraint("T", TYPES), RetvalOp);       \
  REGISTER_KERNEL_BUILDER(Name(kRetOp)                                         \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<ResourceHandle>("T")             \
                              .HostMemory("input"),                            \
                          RetvalOp);                                           \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kDeviceRetOp).Device(DEVICE).TypeConstraint<int32>("T"), RetvalOp); \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp);   \
                                                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"),            \
      data::GeneratorDatasetOp);                                               \
  REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")                              \
                              .Device(DEVICE)                                  \
                              .HostMemory("buffer_size")                       \
                              .HostMemory("input_dataset")                     \
                              .HostMemory("handle"),                           \
                          data::PrefetchDatasetOp);                            \
                                                                               \
  REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE),                   \
                          data::IteratorHandleOp);                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MakeIterator").Device(DEVICE).HostMemory("dataset"),               \
      data::MakeIteratorOp);                                                   \
  REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE),            \
                          data::AnonymousIteratorHandleOp);                    \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE),              \
                          data::IteratorGetNextOp);                            \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE),    \
                          data::IteratorGetNextAsOptionalOp);                  \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE),          \
                          data::IteratorGetNextSyncOp);                        \
  REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("string_handle"),                    \
                          data::IteratorToStringHandleOp);                     \
  REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")                   \
                              .Device(DEVICE)                                  \
                              .HostMemory("string_handle"),                    \
                          data::IteratorFromStringHandleOp);                   \
  REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE),                 \
                          data::OptionalNoneOp);                               \
  REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE),            \
                          data::OptionalFromValueOp);                          \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("OptionalHasValue").Device(DEVICE).HostMemory("has_value"),         \
      data::OptionalHasValueOp);                                               \
  REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE),             \
                          data::OptionalGetValueOp);                           \
  REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp)              \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<string>("T"),                    \
                          ArgOp);                                              \
  REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp)              \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<string>("T")                     \
                              .HostMemory("input"),                            \
                          RetvalOp);                                           \
                                                                               \
  REGISTER_KERNEL_BUILDER(Name("StackV2")                                      \
                              .Device(DEVICE)                                  \
                              .HostMemory("max_size")                          \
                              .HostMemory("handle"),                           \
                          StackOp);                                            \
  REGISTER_KERNEL_BUILDER(Name("StackPushV2")                                  \
                              .Device(DEVICE)                                  \
                              .HostMemory("handle")                            \
                              .TypeConstraint("T", TYPES),                     \
                          TemplatedStackPushOp</*allow_swapping=*/false>);     \
  REGISTER_KERNEL_BUILDER(Name("StackPopV2")                                   \
                              .Device(DEVICE)                                  \
                              .HostMemory("handle")                            \
                              .TypeConstraint("elem_type", TYPES),             \
                          StackPopOp);                                         \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp);

// TODO(b/118881356): currently we do not register the QueueEnqueueMany,
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
// and write the tensors they access in order to concatenate them into a batch.
// We would need either to call out to an XLA computation to perform the
// concatenation, or we would need to refactor those kernels so the splitting
// or merging is done in a separate operator that can be compiled.

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_