1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "absl/time/clock.h" 17 #include "absl/time/time.h" 18 #include "tensorflow/core/framework/common_shape_fns.h" 19 #include "tensorflow/core/framework/op.h" 20 #include "tensorflow/core/framework/op_kernel.h" 21 #include "tensorflow/core/framework/resource_handle.pb.h" 22 #include "tensorflow/core/framework/resource_mgr.h" 23 #include "tensorflow/core/lib/core/status.h" 24 #include "tensorflow/core/platform/errors.h" 25 #include "tensorflow/core/platform/tensor_float_32_utils.h" 26 #include "tensorflow/core/public/version.h" 27 28 namespace tensorflow { 29 30 REGISTER_OP("KernelLabel") 31 .Output("result: string") 32 .SetShapeFn(shape_inference::ScalarShape); 33 34 REGISTER_OP("KernelLabelRequired") 35 .Input("input: int32") 36 .Output("result: string") __anon5e664abc0102(shape_inference::InferenceContext* c) 37 .SetShapeFn([](shape_inference::InferenceContext* c) { 38 shape_inference::ShapeHandle out; 39 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &out)); 40 c->set_output(0, c->Scalar()); 41 return OkStatus(); 42 }); 43 44 REGISTER_OP("GraphDefVersion") 45 .Output("version: int32") 46 .SetIsStateful() 47 .SetShapeFn(shape_inference::ScalarShape); 48 49 REGISTER_OP("RequiresOlderGraphVersion") 50 .Output("version: int32") 51 .SetIsStateful() __anon5e664abc0202(shape_inference::InferenceContext* c) 52 .SetShapeFn([](shape_inference::InferenceContext* c) { 53 if (c->graph_def_version() != TF_GRAPH_DEF_VERSION - 1) { 54 return errors::InvalidArgument("Wrong graph version for shape"); 55 } 56 return shape_inference::ScalarShape(c); 57 }); 58 59 REGISTER_OP("Old") 60 .SetShapeFn(shape_inference::UnknownShape) 61 .Deprecated(8, "For reasons"); 62 63 REGISTER_OP("GetDeadline") 64 .Output("deadline_from_epoch_micros: int64") 65 .SetShapeFn(shape_inference::UnknownShape); 66 67 REGISTER_OP("SleepOp") 68 .Input("sleep_seconds: int32") 69 .SetShapeFn(shape_inference::UnknownShape); 70 71 REGISTER_OP("SleepIdentityOp") 72 .Input("sleep_seconds: int32") 73 .Input("input: T") 74 .Output("output: T") 75 .Attr("T: type") 76 .SetShapeFn(shape_inference::UnchangedShape); 77 78 REGISTER_RESOURCE_HANDLE_OP(StubResource); 79 80 REGISTER_OP("ResourceInitializedOp") 81 .Input("resource: resource") 82 .Output("initialized: bool") 83 .SetShapeFn(shape_inference::ScalarShape); 84 85 REGISTER_OP("ResourceCreateOp") 86 .Input("resource: resource") 87 .SetShapeFn(shape_inference::UnknownShape); 88 89 REGISTER_OP("ResourceUsingOp") 90 .Input("resource: resource") 91 .SetShapeFn(shape_inference::UnknownShape); 92 93 REGISTER_OP("IsResourceHandleRefCounting") 94 .Input("handle: resource") 95 .Output("result: bool") 96 .SetShapeFn(shape_inference::ScalarShape); 97 98 REGISTER_OP("MakeWeakResourceHandle") 99 .Input("handle: resource") 100 .Output("dup: resource") 101 .SetIsStateful() 102 .SetShapeFn(tensorflow::shape_inference::ScalarShape); 103 104 REGISTER_OP("TestStringOutput") 105 .Input("input: float") 106 .Output("output1: float") 107 .Output("output2: string") 108 .SetShapeFn(shape_inference::UnknownShape); 109 110 REGISTER_OP("Namespace>TestStringOutput") 111 .Input("input: float") 112 .Output("output1: float") 113 .Output("output2: string") 114 .SetShapeFn(shape_inference::UnknownShape); 115 116 REGISTER_OP("TestAttr") 117 .Output("out: T") 118 .Attr("T: {float, double}") 119 .SetDoNotOptimize() 120 .SetShapeFn(shape_inference::UnknownShape); 121 122 namespace { 123 enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL }; 124 } // namespace 125 126 template <KernelLabel KL> 127 class KernelLabelOp : public OpKernel { 128 public: 129 using OpKernel::OpKernel; 130 Compute(OpKernelContext * ctx)131 void Compute(OpKernelContext* ctx) override { 132 Tensor* output; 133 OP_REQUIRES_OK(ctx, 134 ctx->allocate_output("result", TensorShape({}), &output)); 135 switch (KL) { 136 case DEFAULT_LABEL: 137 output->scalar<tstring>()() = "My label is: default"; 138 break; 139 case OVERLOAD_1_LABEL: 140 output->scalar<tstring>()() = "My label is: overload_1"; 141 break; 142 case OVERLOAD_2_LABEL: 143 output->scalar<tstring>()() = "My label is: overload_2"; 144 break; 145 } 146 } 147 }; 148 149 REGISTER_KERNEL_BUILDER(Name("KernelLabel").Device(DEVICE_CPU), 150 KernelLabelOp<DEFAULT_LABEL>); 151 REGISTER_KERNEL_BUILDER( 152 Name("KernelLabel").Device(DEVICE_CPU).Label("overload_1"), 153 KernelLabelOp<OVERLOAD_1_LABEL>); 154 REGISTER_KERNEL_BUILDER( 155 Name("KernelLabel").Device(DEVICE_CPU).Label("overload_2"), 156 KernelLabelOp<OVERLOAD_2_LABEL>); 157 158 // All "KernelLabelRequired" kernels have labels 159 REGISTER_KERNEL_BUILDER( 160 Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_1"), 161 KernelLabelOp<OVERLOAD_1_LABEL>); 162 REGISTER_KERNEL_BUILDER( 163 Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_2"), 164 KernelLabelOp<OVERLOAD_2_LABEL>); 165 166 class GraphDefVersionOp : public OpKernel { 167 public: GraphDefVersionOp(OpKernelConstruction * ctx)168 explicit GraphDefVersionOp(OpKernelConstruction* ctx) 169 : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {} 170 Compute(OpKernelContext * ctx)171 void Compute(OpKernelContext* ctx) override { 172 Tensor* output = nullptr; 173 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); 174 output->scalar<int>()() = graph_def_version_; 175 } 176 177 private: 178 const int graph_def_version_; 179 }; 180 181 REGISTER_KERNEL_BUILDER(Name("GraphDefVersion").Device(DEVICE_CPU), 182 GraphDefVersionOp); 183 184 class OldOp : public OpKernel { 185 public: OldOp(OpKernelConstruction * ctx)186 explicit OldOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 187 Compute(OpKernelContext * ctx)188 void Compute(OpKernelContext* ctx) override {} 189 }; 190 191 REGISTER_KERNEL_BUILDER(Name("Old").Device(DEVICE_CPU), OldOp); 192 193 class GetDeadlineOp : public OpKernel { 194 public: GetDeadlineOp(OpKernelConstruction * ctx)195 explicit GetDeadlineOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 196 Compute(OpKernelContext * ctx)197 void Compute(OpKernelContext* ctx) override { 198 if (!ctx->deadline()) { 199 ctx->SetStatus(errors::InvalidArgument("Deadline has not ben set.")); 200 } 201 Tensor* output; 202 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); 203 output->scalar<int64_t>()() = absl::ToUnixMicros(*ctx->deadline()); 204 } 205 }; 206 207 REGISTER_KERNEL_BUILDER(Name("GetDeadline").Device(DEVICE_CPU), GetDeadlineOp); 208 209 class SleepOp : public OpKernel { 210 public: SleepOp(OpKernelConstruction * ctx)211 explicit SleepOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 212 Compute(OpKernelContext * ctx)213 void Compute(OpKernelContext* ctx) override { 214 absl::SleepFor(absl::Seconds(ctx->input(0).scalar<int>()())); 215 } 216 }; 217 218 REGISTER_KERNEL_BUILDER(Name("SleepOp").Device(DEVICE_CPU), SleepOp); 219 220 class SleepIdentityOp : public OpKernel { 221 public: SleepIdentityOp(OpKernelConstruction * ctx)222 explicit SleepIdentityOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 223 Compute(OpKernelContext * ctx)224 void Compute(OpKernelContext* ctx) override { 225 absl::SleepFor(absl::Seconds(ctx->input(0).scalar<int>()())); 226 ctx->set_output(0, ctx->input(1)); 227 } 228 }; 229 230 REGISTER_KERNEL_BUILDER(Name("SleepIdentityOp").Device(DEVICE_CPU), 231 SleepIdentityOp); 232 233 // Stubbed-out resource to test resource handle ops. 234 class StubResource : public ResourceBase { 235 public: DebugString() const236 string DebugString() const override { return ""; } 237 }; 238 239 REGISTER_RESOURCE_HANDLE_KERNEL(StubResource); 240 241 REGISTER_KERNEL_BUILDER(Name("ResourceInitializedOp").Device(DEVICE_CPU), 242 IsResourceInitialized<StubResource>); 243 244 class ResourceCreateOp : public OpKernel { 245 public: ResourceCreateOp(OpKernelConstruction * c)246 explicit ResourceCreateOp(OpKernelConstruction* c) : OpKernel(c) {} 247 Compute(OpKernelContext * c)248 void Compute(OpKernelContext* c) override { 249 OP_REQUIRES_OK(c, 250 CreateResource(c, HandleFromInput(c, 0), new StubResource)); 251 } 252 }; 253 254 REGISTER_KERNEL_BUILDER(Name("ResourceCreateOp").Device(DEVICE_CPU), 255 ResourceCreateOp); 256 257 // Uses a ResourceHandle to check its validity. 258 class ResourceUsingOp : public OpKernel { 259 public: ResourceUsingOp(OpKernelConstruction * context)260 explicit ResourceUsingOp(OpKernelConstruction* context) : OpKernel(context) {} 261 Compute(OpKernelContext * ctx)262 void Compute(OpKernelContext* ctx) override { 263 StubResource* unused; 264 OP_REQUIRES_OK(ctx, LookupResource<StubResource>( 265 ctx, HandleFromInput(ctx, 0), &unused)); 266 } 267 }; 268 269 REGISTER_KERNEL_BUILDER(Name("ResourceUsingOp").Device(DEVICE_CPU), 270 ResourceUsingOp); 271 272 class IsResourceHandleRefCountingOp : public OpKernel { 273 public: IsResourceHandleRefCountingOp(OpKernelConstruction * ctx)274 explicit IsResourceHandleRefCountingOp(OpKernelConstruction* ctx) 275 : OpKernel(ctx) {} 276 Compute(OpKernelContext * ctx)277 void Compute(OpKernelContext* ctx) override { 278 const auto& handle = HandleFromInput(ctx, 0); 279 Tensor* output; 280 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output)); 281 output->flat<bool>()(0) = handle.IsRefCounting(); 282 } 283 }; 284 285 REGISTER_KERNEL_BUILDER(Name("IsResourceHandleRefCounting").Device(DEVICE_CPU), 286 IsResourceHandleRefCountingOp); 287 288 // Duplicates a ResourceHandle as a weak ResourceHandle. 289 class MakeWeakResourceHandleOp : public OpKernel { 290 public: MakeWeakResourceHandleOp(OpKernelConstruction * c)291 explicit MakeWeakResourceHandleOp(OpKernelConstruction* c) : OpKernel(c) {} 292 Compute(OpKernelContext * ctx)293 void Compute(OpKernelContext* ctx) override { 294 Tensor tensor; 295 ResourceHandleProto proto; 296 HandleFromInput(ctx, 0).AsProto(&proto); 297 298 AllocatorAttributes attr; 299 attr.set_on_host(true); 300 OP_REQUIRES_OK( 301 ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &tensor, attr)); 302 tensor.scalar<ResourceHandle>()() = ResourceHandle{proto}; 303 ctx->set_output(0, tensor); 304 } 305 }; 306 307 REGISTER_KERNEL_BUILDER(Name("MakeWeakResourceHandle").Device(DEVICE_CPU), 308 MakeWeakResourceHandleOp); 309 REGISTER_KERNEL_BUILDER(Name("MakeWeakResourceHandle").Device(DEVICE_DEFAULT), 310 MakeWeakResourceHandleOp); 311 312 class TestAttrOp : public OpKernel { 313 public: TestAttrOp(OpKernelConstruction * ctx)314 explicit TestAttrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 315 Compute(OpKernelContext * ctx)316 void Compute(OpKernelContext* ctx) override { 317 Tensor* output; 318 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); 319 output->scalar<float>()() = 1.0; 320 } 321 }; 322 323 REGISTER_KERNEL_BUILDER( 324 Name("TestAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"), TestAttrOp); 325 326 // Various test ops without kernels. These are used to test graph construction. 327 328 REGISTER_OP("A") 329 .Output("out: float32") 330 .SetShapeFn(shape_inference::UnknownShape); 331 332 REGISTER_OP("B") 333 .Output("out: float32") 334 .SetShapeFn(shape_inference::UnknownShape); 335 336 REGISTER_OP("Foo1") 337 .Input("a: float32") 338 .Input("b: int32") 339 .Input("c: int32") 340 .Output("d: float32") 341 .Output("e: int32") 342 .SetShapeFn(shape_inference::UnknownShape); 343 344 REGISTER_OP("Foo2") 345 .Input("a: float32") 346 .Input("b: string") 347 .Input("c: string") 348 .Output("d: float32") 349 .Output("e: int32") 350 .SetShapeFn(shape_inference::UnknownShape); 351 352 REGISTER_OP("Foo3") 353 .Input("a: float32") 354 .Input("b: string") 355 .Input("c: float32") 356 .Output("d: float32") 357 .Output("e: int32") 358 .SetShapeFn(shape_inference::UnknownShape); 359 360 REGISTER_OP("CopyOp").Input("a: T").Output("b: T").Attr("T: type").SetShapeFn( 361 shape_inference::UnknownShape); 362 363 REGISTER_OP("None").SetShapeFn(shape_inference::UnknownShape); 364 365 REGISTER_OP("IntOutput") 366 .Output("a: int32") 367 .SetShapeFn(shape_inference::UnknownShape); 368 369 REGISTER_OP("Int64Output") 370 .Output("out: int64") 371 .SetShapeFn(shape_inference::UnknownShape); 372 373 REGISTER_OP("RefOutput") 374 .Output("a: Ref(int32)") 375 .SetShapeFn(shape_inference::UnknownShape); 376 377 REGISTER_OP("FloatOutput") 378 .Output("a: float32") 379 .SetShapeFn(shape_inference::UnknownShape); 380 381 REGISTER_OP("TwoFloatOutputs") 382 .Output("a: float32") 383 .Output("b: float32") 384 .SetShapeFn(shape_inference::UnknownShape); 385 386 REGISTER_OP("FiveFloatOutputs") 387 .Output("a: float32") 388 .Output("b: float32") 389 .Output("c: float32") 390 .Output("d: float32") 391 .Output("e: float32") 392 .SetShapeFn(shape_inference::UnknownShape); 393 394 REGISTER_OP("RefOutputFloatOutput") 395 .Output("a: Ref(float32)") 396 .Output("b: float32") 397 .SetShapeFn(shape_inference::UnknownShape); 398 399 REGISTER_OP("RefInputFloatInput") 400 .Input("a: Ref(float)") 401 .Input("b: float") 402 .SetShapeFn(shape_inference::UnknownShape); 403 404 REGISTER_OP("IntInput") 405 .Input("a: int32") 406 .SetShapeFn(shape_inference::UnknownShape); 407 408 REGISTER_OP("IntInputIntOutput") 409 .Input("a: int32") 410 .Output("b: int32") 411 .SetShapeFn(shape_inference::UnknownShape); 412 413 REGISTER_OP("FloatInput") 414 .Input("a: float32") 415 .SetShapeFn(shape_inference::UnknownShape); 416 417 REGISTER_OP("TwoIntOutputs") 418 .Output("a: int32") 419 .Output("b: int32") 420 .SetShapeFn(shape_inference::UnknownShape); 421 422 REGISTER_OP("IntOutputFloatOutput") 423 .Output("a: int32") 424 .Output("b: float32") 425 .SetShapeFn(shape_inference::UnknownShape); 426 427 REGISTER_OP("FloatOutputStringOutput") 428 .Output("a: float32") 429 .Output("b: string") 430 .SetShapeFn(shape_inference::UnknownShape); 431 432 REGISTER_OP("TwoIntInputs") 433 .Input("a: int32") 434 .Input("b: int32") 435 .SetShapeFn(shape_inference::UnknownShape); 436 437 REGISTER_OP("TwoFloatInputs") 438 .Input("a: float32") 439 .Input("b: float32") 440 .SetShapeFn(shape_inference::UnknownShape); 441 442 REGISTER_OP("IntInputFloatInput") 443 .Input("a: int32") 444 .Input("b: float32") 445 .SetShapeFn(shape_inference::UnknownShape); 446 447 REGISTER_OP("RefInputIntInput") 448 .Input("a: Ref(int32)") 449 .Input("b: int32") 450 .SetShapeFn(shape_inference::UnknownShape); 451 452 REGISTER_OP("TwoFloatInputsFloatOutput") 453 .Input("a: float32") 454 .Input("b: float32") 455 .Output("c: float32") 456 .SetShapeFn(shape_inference::UnknownShape); 457 458 REGISTER_OP("TwoFloatInputsIntOutput") 459 .Input("a: float32") 460 .Input("b: float32") 461 .Output("c: int32") 462 .SetShapeFn(shape_inference::UnknownShape); 463 464 REGISTER_OP("RefInputFloatInputIntOutput") 465 .Input("a: Ref(float32)") 466 .Input("b: float32") 467 .Output("c: int32") 468 .SetShapeFn(shape_inference::UnknownShape); 469 470 REGISTER_OP("ListInput") 471 .Input("a: N * T") 472 .Attr("N: int >= 1") 473 .Attr("T: type") 474 .SetShapeFn(shape_inference::UnknownShape); 475 476 REGISTER_OP("ListOutput") 477 .Output("a: T") 478 .Attr("T: list(type) >= 1") 479 .SetShapeFn(shape_inference::UnknownShape); 480 481 REGISTER_OP("Unary").Input("a: T").Output("b: T").Attr("T: type").SetShapeFn( 482 shape_inference::UnknownShape); 483 484 REGISTER_OP("OpWithDefaultAttr") 485 .Output("a: int32") 486 .Attr("default_float: float = 123.0") 487 .SetShapeFn(shape_inference::UnknownShape); 488 489 REGISTER_OP("OpWithFutureDefaultAttr") 490 .SetShapeFn(shape_inference::UnknownShape); 491 492 REGISTER_OP("IntAttr") 493 .Output("out: int64") 494 .Attr("foo: int = 1") 495 .SetShapeFn(shape_inference::UnknownShape); 496 497 REGISTER_OP("StringListAttr") 498 .Attr("a: list(string)") 499 .Attr("b: string") 500 .SetShapeFn(shape_inference::UnknownShape); 501 502 REGISTER_OP("DefaultAttrs") 503 .Attr("string_val: string = 'abc'") 504 .Attr("string_list_val: list(string) = ['abc', '']") 505 .Attr("int_val: int = 123") 506 .Attr("int_list_val: list(int) = [1, 2, 3]") 507 .Attr("float_val: float = 10.0") 508 .Attr("float_list_val: list(float) = [10.0]") 509 .Attr("bool_val: bool = true") 510 .Attr("bool_list_val: list(bool) = [true, false]") 511 .Attr("type_val: type = DT_INT32") 512 .Attr("type_list_val: list(type) = [DT_INT32, DT_FLOAT]") 513 .Attr("shape_val: shape = { dim { size: 2 } dim { size: 1 } }") 514 .Attr("shape_list_val: list(shape) = [{}, { dim { size: 1} }]") 515 .Attr("tensor_val: tensor = { dtype: DT_INT32 tensor_shape: {} int_val: 1}") 516 .Attr( 517 "tensor_list_val: list(tensor) = " 518 "[{ dtype: DT_INT32 tensor_shape: {} int_val: 1}]") 519 .SetShapeFn(shape_inference::UnknownShape); 520 521 REGISTER_OP("FuncAttr") 522 .Attr("f: func") 523 .SetShapeFn(shape_inference::UnknownShape); 524 525 REGISTER_OP("FuncListAttr") 526 .Attr("f: list(func)") 527 .SetShapeFn(shape_inference::UnknownShape); 528 529 REGISTER_OP("Simple") 530 .Input("a: int32") 531 .Output("out: float") 532 .SetShapeFn(shape_inference::UnknownShape); 533 534 REGISTER_OP("OutT").Output("a: T").Attr("T: type").SetShapeFn( 535 shape_inference::UnknownShape); 536 537 REGISTER_OP("ReservedInput") 538 .Input("input: int32") 539 .SetShapeFn(shape_inference::UnknownShape); 540 541 REGISTER_OP("Polymorphic") 542 .Input("a: T") 543 .Output("out: T") 544 .Attr("T: type") 545 .SetShapeFn(shape_inference::UnknownShape); 546 547 REGISTER_OP("PolymorphicOut") 548 .Output("out: T") 549 .Attr("T: type") 550 .SetShapeFn(shape_inference::UnknownShape); 551 552 REGISTER_OP("PolymorphicDefaultOut") 553 .Output("out: T") 554 .Attr("T: type = DT_STRING") 555 .SetShapeFn(shape_inference::UnknownShape); 556 557 REGISTER_OP("Binary") 558 .Input("a: T") 559 .Input("b: T") 560 .Output("out: T") 561 .Attr("T: type") 562 .SetShapeFn(shape_inference::UnknownShape); 563 564 REGISTER_OP("Restrict") 565 .Input("a: T") 566 .Output("out: T") 567 .Attr("T: {string, bool}") 568 .SetShapeFn(shape_inference::UnknownShape); 569 570 REGISTER_OP("TypeList") 571 .Input("a: T") 572 .Attr("T: list(type) >= 0") 573 .SetShapeFn(shape_inference::UnknownShape); 574 575 REGISTER_OP("TypeListTwice") 576 .Input("a: T") 577 .Input("b: T") 578 .Attr("T: list(type) >= 0") 579 .SetShapeFn(shape_inference::UnknownShape); 580 581 REGISTER_OP("OutTypeList") 582 .Output("out: T") 583 .Attr("T: list(type) >= 0") 584 .SetShapeFn(shape_inference::UnknownShape); 585 586 REGISTER_OP("TypeListRestrict") 587 .Input("a: T") 588 .Attr("T: list({string, bool})") 589 .SetShapeFn(shape_inference::UnknownShape); 590 591 REGISTER_OP("OutTypeListRestrict") 592 .Output("out: t") 593 .Attr("t: list({string, bool})") 594 .SetShapeFn(shape_inference::UnknownShape); 595 596 REGISTER_OP("Attr").Attr("a: int").SetShapeFn(shape_inference::UnknownShape); 597 598 REGISTER_OP("AttrFloat") 599 .Attr("a: float") 600 .SetShapeFn(shape_inference::UnknownShape); 601 602 REGISTER_OP("AttrBool") 603 .Attr("a: bool") 604 .SetShapeFn(shape_inference::UnknownShape); 605 606 REGISTER_OP("AttrBoolList") 607 .Attr("a: list(bool)") 608 .SetShapeFn(shape_inference::UnknownShape); 609 610 REGISTER_OP("AttrMin") 611 .Attr("a: int >= 5") 612 .SetShapeFn(shape_inference::UnknownShape); 613 614 REGISTER_OP("AttrListMin") 615 .Attr("a: list(int) >= 2") 616 .SetShapeFn(shape_inference::UnknownShape); 617 618 REGISTER_OP("AttrEnum") 619 .Attr("a: {'apples', 'oranges'}") 620 .SetShapeFn(shape_inference::UnknownShape); 621 622 REGISTER_OP("AttrEnumList") 623 .Attr("a: list({'apples', 'oranges'})") 624 .SetShapeFn(shape_inference::UnknownShape); 625 626 REGISTER_OP("AttrShape") 627 .Attr("a: shape") 628 .SetShapeFn(shape_inference::UnknownShape); 629 630 REGISTER_OP("AttrShapeList") 631 .Attr("a: list(shape)") 632 .SetShapeFn(shape_inference::UnknownShape); 633 634 REGISTER_OP("AttrPartialShape") 635 .Attr("a: shape") 636 .SetShapeFn(shape_inference::UnknownShape); 637 638 REGISTER_OP("AttrPartialShapeList") 639 .Attr("a: list(shape)") 640 .SetShapeFn(shape_inference::UnknownShape); 641 642 REGISTER_OP("AttrDefault") 643 .Attr("a: string = 'banana'") 644 .SetShapeFn(shape_inference::UnknownShape); 645 646 REGISTER_OP("AttrListDefault") 647 .Attr("a: list(int) = [5, 15]") 648 .SetShapeFn(shape_inference::UnknownShape); 649 650 REGISTER_OP("AttrEmptyListDefault") 651 .Attr("a: list(float) = []") 652 .SetShapeFn(shape_inference::UnknownShape); 653 654 REGISTER_OP("ReservedAttr") 655 .Attr("range: int") 656 .SetShapeFn(shape_inference::UnknownShape); 657 658 REGISTER_OP("AttrTypeDefault") 659 .Input("a: T") 660 .Attr("T: type = DT_INT32") 661 .SetShapeFn(shape_inference::UnknownShape); 662 663 REGISTER_OP("AttrListTypeDefault") 664 .Input("a: N * T") 665 .Input("b: N * T") 666 .Attr("T: type = DT_INT32") 667 .Attr("N: int") 668 .SetShapeFn(shape_inference::UnknownShape); 669 670 REGISTER_OP("NIntsIn") 671 .Input("a: N * int32") 672 .Attr("N: int >= 2") 673 .SetShapeFn(shape_inference::UnknownShape); 674 675 REGISTER_OP("NPolymorphicIn") 676 .Input("a: N * T") 677 .Attr("T: type") 678 .Attr("N: int >= 2") 679 .SetShapeFn(shape_inference::UnknownShape); 680 681 REGISTER_OP("NPolymorphicRestrictIn") 682 .Input("a: N * T") 683 .Attr("T: {string, bool}") 684 .Attr("N: int >= 2") 685 .SetShapeFn(shape_inference::UnknownShape); 686 687 REGISTER_OP("NInTwice") 688 .Input("a: N * int32") 689 .Input("b: N * string") 690 .Attr("N: int >= 0") 691 .SetShapeFn(shape_inference::UnknownShape); 692 693 REGISTER_OP("NInPolymorphicTwice") 694 .Input("a: N * T") 695 .Input("b: N * T") 696 .Attr("T: type") 697 .Attr("N: int >= 0") 698 .SetShapeFn(shape_inference::UnknownShape); 699 700 REGISTER_OP("NInTwoTypeVariables") 701 .Input("a: N * S") 702 .Input("b: N * T") 703 .Attr("S: type") 704 .Attr("T: type") 705 .Attr("N: int >= 0") 706 .SetShapeFn(shape_inference::UnknownShape); 707 708 REGISTER_OP("InPolymorphicTwice") 709 .Input("a: N * T") 710 .Input("b: M * T") 711 .Attr("T: type = DT_INT32") 712 .Attr("N: int >= 0") 713 .Attr("M: int >= 0") 714 .SetShapeFn(shape_inference::UnknownShape); 715 716 REGISTER_OP("NIntsOut") 717 .Output("a: N * int32") 718 .Attr("N: int >= 2") 719 .SetShapeFn(shape_inference::UnknownShape); 720 721 REGISTER_OP("NIntsOutDefault") 722 .Output("a: N * int32") 723 .Attr("N: int >= 2 = 3") 724 .SetShapeFn(shape_inference::UnknownShape); 725 726 REGISTER_OP("NPolymorphicOut") 727 .Output("a: N * T") 728 .Attr("T: type") 729 .Attr("N: int >= 2") 730 .SetShapeFn(shape_inference::UnknownShape); 731 732 REGISTER_OP("NPolymorphicOutDefault") 733 .Output("a: N * T") 734 .Attr("T: type = DT_BOOL") 735 .Attr("N: int >= 2 = 2") 736 .SetShapeFn(shape_inference::UnknownShape); 737 738 REGISTER_OP("NPolymorphicRestrictOut") 739 .Output("a: N * T") 740 .Attr("T: {string, bool}") 741 .Attr("N: int >= 2") 742 .SetShapeFn(shape_inference::UnknownShape); 743 744 REGISTER_OP("RefIn") 745 .Input("a: Ref(T)") 746 .Attr("T: type") 747 .SetShapeFn(shape_inference::UnknownShape); 748 749 REGISTER_OP("TwoRefsIn") 750 .Input("a: Ref(T)") 751 .Input("b: Ref(T)") 752 .Attr("T: type") 753 .SetShapeFn(shape_inference::UnknownShape); 754 755 REGISTER_OP("RefOut") 756 .Output("a: Ref(T)") 757 .Attr("T: type") 758 .SetShapeFn(shape_inference::UnknownShape); 759 760 REGISTER_OP("SimpleStruct") 761 .Output("a: n_a * int32") 762 .Attr("n_a: int >= 0") 763 .SetShapeFn(shape_inference::UnknownShape); 764 765 REGISTER_OP("MixedStruct") 766 .Output("a: n_a * int32") 767 .Output("b: float") 768 .Attr("n_a: int >= 0") 769 .SetShapeFn(shape_inference::UnknownShape); 770 771 REGISTER_OP("ComplexStruct") 772 .Output("a: n_a * int32") 773 .Output("b: n_b * int64") 774 .Output("c: t_c") 775 .Attr("n_a: int >= 0") 776 .Attr("n_b: int >= 0") 777 .Attr("t_c: list(type) >= 0") 778 .SetShapeFn(shape_inference::UnknownShape); 779 780 // An op which returns its own device placement as a string, useful for testing 781 // where ops get placed. 782 REGISTER_OP("DevicePlacementOp") 783 .Output("device: string") 784 .SetIsStateful() 785 .SetShapeFn(shape_inference::ScalarShape); 786 787 class DevicePlacementOp : public OpKernel { 788 public: 789 using OpKernel::OpKernel; 790 Compute(OpKernelContext * ctx)791 void Compute(OpKernelContext* ctx) override { 792 Tensor* output; 793 OP_REQUIRES_OK(ctx, 794 ctx->allocate_output("device", TensorShape({}), &output)); 795 output->scalar<tstring>()() = ctx->device()->name(); 796 } 797 }; 798 799 REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_CPU), 800 DevicePlacementOp); 801 REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_DEFAULT), 802 DevicePlacementOp); 803 804 // An op which returns the dtype of the tensor it was passed in. It expects 805 // DT_UINT8. 806 REGISTER_OP("DtypeWithDefaultOp") 807 .Input("in: T") 808 .Attr("T: type = DT_UINT8") 809 .Output("dtype: string") 810 .SetIsStateful() 811 .SetShapeFn(shape_inference::ScalarShape); 812 813 class DTypeWithDefaultOp : public OpKernel { 814 public: 815 using OpKernel::OpKernel; 816 Compute(OpKernelContext * ctx)817 void Compute(OpKernelContext* ctx) override { 818 const Tensor& input = ctx->input(0); 819 Tensor* output; 820 OP_REQUIRES_OK(ctx, 821 ctx->allocate_output("dtype", TensorShape({}), &output)); 822 output->scalar<tstring>()() = tensorflow::DataTypeString(input.dtype()); 823 } 824 }; 825 826 REGISTER_KERNEL_BUILDER(Name("DtypeWithDefaultOp").Device(DEVICE_CPU), 827 DTypeWithDefaultOp); 828 829 // An op that returns True if TensorFloat-32 execution is enabled. Useful for 830 // testing that enabling/disabling TensorFloat-32 works correctly, even when 831 // the test does not run with a GPU that supports TensorFloat-32. 832 REGISTER_OP("IsTensorFloat32Enabled") 833 .Output("enabled: bool") 834 .SetIsStateful() 835 .SetShapeFn(shape_inference::ScalarShape); 836 837 class IsTensorFloat32Enabled : public OpKernel { 838 public: 839 using OpKernel::OpKernel; 840 Compute(OpKernelContext * ctx)841 void Compute(OpKernelContext* ctx) override { 842 Tensor* output; 843 OP_REQUIRES_OK(ctx, 844 ctx->allocate_output("enabled", TensorShape({}), &output)); 845 output->scalar<bool>()() = tensor_float_32_execution_enabled(); 846 } 847 }; 848 849 REGISTER_KERNEL_BUILDER( 850 Name("IsTensorFloat32Enabled").Device(DEVICE_CPU).HostMemory("enabled"), 851 IsTensorFloat32Enabled); 852 REGISTER_KERNEL_BUILDER( 853 Name("IsTensorFloat32Enabled").Device(DEVICE_GPU).HostMemory("enabled"), 854 IsTensorFloat32Enabled); 855 } // end namespace tensorflow 856