/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

REGISTER_OP("KernelLabel")
    .Output("result: string")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("KernelLabelRequired")
    .Input("input: int32")
    .Output("result: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle out;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &out));
      c->set_output(0, c->Scalar());
      return OkStatus();
    });

REGISTER_OP("GraphDefVersion")
    .Output("version: int32")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("RequiresOlderGraphVersion")
    .Output("version: int32")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      if (c->graph_def_version() != TF_GRAPH_DEF_VERSION - 1) {
        return errors::InvalidArgument("Wrong graph version for shape");
      }
      return shape_inference::ScalarShape(c);
    });

REGISTER_OP("Old")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(8, "For reasons");

REGISTER_OP("GetDeadline")
    .Output("deadline_from_epoch_micros: int64")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("SleepOp")
    .Input("sleep_seconds: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("SleepIdentityOp")
    .Input("sleep_seconds: int32")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_RESOURCE_HANDLE_OP(StubResource);

REGISTER_OP("ResourceInitializedOp")
    .Input("resource: resource")
    .Output("initialized: bool")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("ResourceCreateOp")
    .Input("resource: resource")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("ResourceUsingOp")
    .Input("resource: resource")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IsResourceHandleRefCounting")
    .Input("handle: resource")
    .Output("result: bool")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("MakeWeakResourceHandle")
    .Input("handle: resource")
    .Output("dup: resource")
    .SetIsStateful()
    .SetShapeFn(tensorflow::shape_inference::ScalarShape);

REGISTER_OP("TestStringOutput")
    .Input("input: float")
    .Output("output1: float")
    .Output("output2: string")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Namespace>TestStringOutput")
    .Input("input: float")
    .Output("output1: float")
    .Output("output2: string")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TestAttr")
    .Output("out: T")
    .Attr("T: {float, double}")
    .SetDoNotOptimize()
    .SetShapeFn(shape_inference::UnknownShape);

namespace {
enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL };
}  // namespace

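// Emits a scalar string identifying which labeled kernel registration was
// selected for "KernelLabel" / "KernelLabelRequired".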
template <KernelLabel KL>
class KernelLabelOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* ctx) override {
    Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("result", TensorShape({}), &output));
    switch (KL) {
      case DEFAULT_LABEL:
        output->scalar<tstring>()() = "My label is: default";
        break;
      case OVERLOAD_1_LABEL:
        output->scalar<tstring>()() = "My label is: overload_1";
        break;
      case OVERLOAD_2_LABEL:
        output->scalar<tstring>()() = "My label is: overload_2";
        break;
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("KernelLabel").Device(DEVICE_CPU),
                        KernelLabelOp<DEFAULT_LABEL>);
REGISTER_KERNEL_BUILDER(
    Name("KernelLabel").Device(DEVICE_CPU).Label("overload_1"),
    KernelLabelOp<OVERLOAD_1_LABEL>);
REGISTER_KERNEL_BUILDER(
    Name("KernelLabel").Device(DEVICE_CPU).Label("overload_2"),
    KernelLabelOp<OVERLOAD_2_LABEL>);

// All "KernelLabelRequired" kernels have labels
REGISTER_KERNEL_BUILDER(
    Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_1"),
    KernelLabelOp<OVERLOAD_1_LABEL>);
REGISTER_KERNEL_BUILDER(
    Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_2"),
    KernelLabelOp<OVERLOAD_2_LABEL>);
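
// Note: a node picks one of the registrations above via its kernel label
// (commonly carried in the "_kernel" NodeDef attr). "KernelLabel" also has an
// unlabeled default kernel; "KernelLabelRequired" is only registered with
// labels, so an unlabeled node has no kernel to run.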
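// Returns the GraphDef version recorded when the kernel was constructed.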
class GraphDefVersionOp : public OpKernel {
 public:
  explicit GraphDefVersionOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {}

  void Compute(OpKernelContext* ctx) override {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    output->scalar<int>()() = graph_def_version_;
  }

 private:
  const int graph_def_version_;
};

REGISTER_KERNEL_BUILDER(Name("GraphDefVersion").Device(DEVICE_CPU),
                        GraphDefVersionOp);

class OldOp : public OpKernel {
 public:
  explicit OldOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {}
};

REGISTER_KERNEL_BUILDER(Name("Old").Device(DEVICE_CPU), OldOp);

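// Returns the op's deadline, if any, as microseconds since the Unix epoch.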
class GetDeadlineOp : public OpKernel {
 public:
  explicit GetDeadlineOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    if (!ctx->deadline()) {
      ctx->SetStatus(errors::InvalidArgument("Deadline has not been set."));
    }
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    output->scalar<int64_t>()() = absl::ToUnixMicros(*ctx->deadline());
  }
};

REGISTER_KERNEL_BUILDER(Name("GetDeadline").Device(DEVICE_CPU), GetDeadlineOp);

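// Blocks the calling thread for the requested number of seconds.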
class SleepOp : public OpKernel {
 public:
  explicit SleepOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    absl::SleepFor(absl::Seconds(ctx->input(0).scalar<int>()()));
  }
};

REGISTER_KERNEL_BUILDER(Name("SleepOp").Device(DEVICE_CPU), SleepOp);

class SleepIdentityOp : public OpKernel {
 public:
  explicit SleepIdentityOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    absl::SleepFor(absl::Seconds(ctx->input(0).scalar<int>()()));
    ctx->set_output(0, ctx->input(1));
  }
};

REGISTER_KERNEL_BUILDER(Name("SleepIdentityOp").Device(DEVICE_CPU),
                        SleepIdentityOp);

// Stubbed-out resource to test resource handle ops.
class StubResource : public ResourceBase {
 public:
  string DebugString() const override { return ""; }
};

REGISTER_RESOURCE_HANDLE_KERNEL(StubResource);

REGISTER_KERNEL_BUILDER(Name("ResourceInitializedOp").Device(DEVICE_CPU),
                        IsResourceInitialized<StubResource>);

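// Creates a StubResource at the handle passed as input 0.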
class ResourceCreateOp : public OpKernel {
 public:
  explicit ResourceCreateOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    OP_REQUIRES_OK(c,
                   CreateResource(c, HandleFromInput(c, 0), new StubResource));
  }
};

REGISTER_KERNEL_BUILDER(Name("ResourceCreateOp").Device(DEVICE_CPU),
                        ResourceCreateOp);

// Uses a ResourceHandle to check its validity.
class ResourceUsingOp : public OpKernel {
 public:
  explicit ResourceUsingOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    StubResource* unused;
    OP_REQUIRES_OK(ctx, LookupResource<StubResource>(
                            ctx, HandleFromInput(ctx, 0), &unused));
  }
};

REGISTER_KERNEL_BUILDER(Name("ResourceUsingOp").Device(DEVICE_CPU),
                        ResourceUsingOp);

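// Reports whether the incoming resource handle is reference-counted.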
class IsResourceHandleRefCountingOp : public OpKernel {
 public:
  explicit IsResourceHandleRefCountingOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const auto& handle = HandleFromInput(ctx, 0);
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
    output->flat<bool>()(0) = handle.IsRefCounting();
  }
};

REGISTER_KERNEL_BUILDER(Name("IsResourceHandleRefCounting").Device(DEVICE_CPU),
                        IsResourceHandleRefCountingOp);

// Duplicates a ResourceHandle as a weak ResourceHandle.
class MakeWeakResourceHandleOp : public OpKernel {
 public:
  explicit MakeWeakResourceHandleOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) override {
    Tensor tensor;
    ResourceHandleProto proto;
    HandleFromInput(ctx, 0).AsProto(&proto);

    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &tensor, attr));
    tensor.scalar<ResourceHandle>()() = ResourceHandle{proto};
    ctx->set_output(0, tensor);
  }
};

REGISTER_KERNEL_BUILDER(Name("MakeWeakResourceHandle").Device(DEVICE_CPU),
                        MakeWeakResourceHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeWeakResourceHandle").Device(DEVICE_DEFAULT),
                        MakeWeakResourceHandleOp);

class TestAttrOp : public OpKernel {
 public:
  explicit TestAttrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    output->scalar<float>()() = 1.0;
  }
};

REGISTER_KERNEL_BUILDER(
    Name("TestAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"), TestAttrOp);

// Various test ops without kernels. These are used to test graph construction.

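// A minimal sketch of how a test might put one of these kernel-less ops into a
// graph (illustrative only; NodeDefBuilder comes from
// tensorflow/core/framework/node_def_builder.h, which this file does not need):
//
//   NodeDef node_def;
//   TF_CHECK_OK(NodeDefBuilder("float_out", "FloatOutput").Finalize(&node_def));
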
328 REGISTER_OP("A")
329     .Output("out: float32")
330     .SetShapeFn(shape_inference::UnknownShape);
331 
332 REGISTER_OP("B")
333     .Output("out: float32")
334     .SetShapeFn(shape_inference::UnknownShape);
335 
336 REGISTER_OP("Foo1")
337     .Input("a: float32")
338     .Input("b: int32")
339     .Input("c: int32")
340     .Output("d: float32")
341     .Output("e: int32")
342     .SetShapeFn(shape_inference::UnknownShape);
343 
344 REGISTER_OP("Foo2")
345     .Input("a: float32")
346     .Input("b: string")
347     .Input("c: string")
348     .Output("d: float32")
349     .Output("e: int32")
350     .SetShapeFn(shape_inference::UnknownShape);
351 
352 REGISTER_OP("Foo3")
353     .Input("a: float32")
354     .Input("b: string")
355     .Input("c: float32")
356     .Output("d: float32")
357     .Output("e: int32")
358     .SetShapeFn(shape_inference::UnknownShape);
359 
360 REGISTER_OP("CopyOp").Input("a: T").Output("b: T").Attr("T: type").SetShapeFn(
361     shape_inference::UnknownShape);
362 
363 REGISTER_OP("None").SetShapeFn(shape_inference::UnknownShape);
364 
365 REGISTER_OP("IntOutput")
366     .Output("a: int32")
367     .SetShapeFn(shape_inference::UnknownShape);
368 
369 REGISTER_OP("Int64Output")
370     .Output("out: int64")
371     .SetShapeFn(shape_inference::UnknownShape);
372 
373 REGISTER_OP("RefOutput")
374     .Output("a: Ref(int32)")
375     .SetShapeFn(shape_inference::UnknownShape);
376 
377 REGISTER_OP("FloatOutput")
378     .Output("a: float32")
379     .SetShapeFn(shape_inference::UnknownShape);
380 
381 REGISTER_OP("TwoFloatOutputs")
382     .Output("a: float32")
383     .Output("b: float32")
384     .SetShapeFn(shape_inference::UnknownShape);
385 
386 REGISTER_OP("FiveFloatOutputs")
387     .Output("a: float32")
388     .Output("b: float32")
389     .Output("c: float32")
390     .Output("d: float32")
391     .Output("e: float32")
392     .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefOutputFloatOutput")
    .Output("a: Ref(float32)")
    .Output("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefInputFloatInput")
    .Input("a: Ref(float)")
    .Input("b: float")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntInput")
    .Input("a: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntInputIntOutput")
    .Input("a: int32")
    .Output("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FloatInput")
    .Input("a: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoIntOutputs")
    .Output("a: int32")
    .Output("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntOutputFloatOutput")
    .Output("a: int32")
    .Output("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FloatOutputStringOutput")
    .Output("a: float32")
    .Output("b: string")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoIntInputs")
    .Input("a: int32")
    .Input("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoFloatInputs")
    .Input("a: float32")
    .Input("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntInputFloatInput")
    .Input("a: int32")
    .Input("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefInputIntInput")
    .Input("a: Ref(int32)")
    .Input("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoFloatInputsFloatOutput")
    .Input("a: float32")
    .Input("b: float32")
    .Output("c: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoFloatInputsIntOutput")
    .Input("a: float32")
    .Input("b: float32")
    .Output("c: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefInputFloatInputIntOutput")
    .Input("a: Ref(float32)")
    .Input("b: float32")
    .Output("c: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("ListInput")
    .Input("a: N * T")
    .Attr("N: int >= 1")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("ListOutput")
    .Output("a: T")
    .Attr("T: list(type) >= 1")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Unary").Input("a: T").Output("b: T").Attr("T: type").SetShapeFn(
    shape_inference::UnknownShape);

REGISTER_OP("OpWithDefaultAttr")
    .Output("a: int32")
    .Attr("default_float: float = 123.0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("OpWithFutureDefaultAttr")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntAttr")
    .Output("out: int64")
    .Attr("foo: int = 1")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("StringListAttr")
    .Attr("a: list(string)")
    .Attr("b: string")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("DefaultAttrs")
    .Attr("string_val: string = 'abc'")
    .Attr("string_list_val: list(string) = ['abc', '']")
    .Attr("int_val: int = 123")
    .Attr("int_list_val: list(int) = [1, 2, 3]")
    .Attr("float_val: float = 10.0")
    .Attr("float_list_val: list(float) = [10.0]")
    .Attr("bool_val: bool = true")
    .Attr("bool_list_val: list(bool) = [true, false]")
    .Attr("type_val: type = DT_INT32")
    .Attr("type_list_val: list(type) = [DT_INT32, DT_FLOAT]")
    .Attr("shape_val: shape = { dim { size: 2 } dim { size: 1 } }")
    .Attr("shape_list_val: list(shape) = [{}, { dim { size: 1} }]")
    .Attr("tensor_val: tensor = { dtype: DT_INT32 tensor_shape: {} int_val: 1}")
    .Attr(
        "tensor_list_val: list(tensor) = "
        "[{ dtype: DT_INT32 tensor_shape: {} int_val: 1}]")
    .SetShapeFn(shape_inference::UnknownShape);
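
// A sketch of how the defaults above might be materialized onto a node in a
// test (illustrative only; LookUpOpDef and AddDefaultsToNodeDef are assumed to
// come from op.h and node_def_util.h respectively):
//
//   const OpDef* op_def = nullptr;
//   TF_CHECK_OK(OpRegistry::Global()->LookUpOpDef("DefaultAttrs", &op_def));
//   NodeDef node_def;
//   node_def.set_op("DefaultAttrs");
//   AddDefaultsToNodeDef(*op_def, &node_def);  // fills in every attr listed above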

REGISTER_OP("FuncAttr")
    .Attr("f: func")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FuncListAttr")
    .Attr("f: list(func)")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Simple")
    .Input("a: int32")
    .Output("out: float")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("OutT").Output("a: T").Attr("T: type").SetShapeFn(
    shape_inference::UnknownShape);

REGISTER_OP("ReservedInput")
    .Input("input: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Polymorphic")
    .Input("a: T")
    .Output("out: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("PolymorphicOut")
    .Output("out: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("PolymorphicDefaultOut")
    .Output("out: T")
    .Attr("T: type = DT_STRING")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Binary")
    .Input("a: T")
    .Input("b: T")
    .Output("out: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Restrict")
    .Input("a: T")
    .Output("out: T")
    .Attr("T: {string, bool}")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TypeList")
    .Input("a: T")
    .Attr("T: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TypeListTwice")
    .Input("a: T")
    .Input("b: T")
    .Attr("T: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("OutTypeList")
    .Output("out: T")
    .Attr("T: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TypeListRestrict")
    .Input("a: T")
    .Attr("T: list({string, bool})")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("OutTypeListRestrict")
    .Output("out: t")
    .Attr("t: list({string, bool})")
    .SetShapeFn(shape_inference::UnknownShape);

596 REGISTER_OP("Attr").Attr("a: int").SetShapeFn(shape_inference::UnknownShape);
597 
598 REGISTER_OP("AttrFloat")
599     .Attr("a: float")
600     .SetShapeFn(shape_inference::UnknownShape);
601 
602 REGISTER_OP("AttrBool")
603     .Attr("a: bool")
604     .SetShapeFn(shape_inference::UnknownShape);
605 
606 REGISTER_OP("AttrBoolList")
607     .Attr("a: list(bool)")
608     .SetShapeFn(shape_inference::UnknownShape);
609 
610 REGISTER_OP("AttrMin")
611     .Attr("a: int >= 5")
612     .SetShapeFn(shape_inference::UnknownShape);
613 
614 REGISTER_OP("AttrListMin")
615     .Attr("a: list(int) >= 2")
616     .SetShapeFn(shape_inference::UnknownShape);
617 
618 REGISTER_OP("AttrEnum")
619     .Attr("a: {'apples', 'oranges'}")
620     .SetShapeFn(shape_inference::UnknownShape);
621 
622 REGISTER_OP("AttrEnumList")
623     .Attr("a: list({'apples', 'oranges'})")
624     .SetShapeFn(shape_inference::UnknownShape);
625 
626 REGISTER_OP("AttrShape")
627     .Attr("a: shape")
628     .SetShapeFn(shape_inference::UnknownShape);
629 
630 REGISTER_OP("AttrShapeList")
631     .Attr("a: list(shape)")
632     .SetShapeFn(shape_inference::UnknownShape);
633 
634 REGISTER_OP("AttrPartialShape")
635     .Attr("a: shape")
636     .SetShapeFn(shape_inference::UnknownShape);
637 
638 REGISTER_OP("AttrPartialShapeList")
639     .Attr("a: list(shape)")
640     .SetShapeFn(shape_inference::UnknownShape);
641 
642 REGISTER_OP("AttrDefault")
643     .Attr("a: string = 'banana'")
644     .SetShapeFn(shape_inference::UnknownShape);
645 
646 REGISTER_OP("AttrListDefault")
647     .Attr("a: list(int) = [5, 15]")
648     .SetShapeFn(shape_inference::UnknownShape);
649 
650 REGISTER_OP("AttrEmptyListDefault")
651     .Attr("a: list(float) = []")
652     .SetShapeFn(shape_inference::UnknownShape);
653 
654 REGISTER_OP("ReservedAttr")
655     .Attr("range: int")
656     .SetShapeFn(shape_inference::UnknownShape);
657 
658 REGISTER_OP("AttrTypeDefault")
659     .Input("a: T")
660     .Attr("T: type = DT_INT32")
661     .SetShapeFn(shape_inference::UnknownShape);
662 
663 REGISTER_OP("AttrListTypeDefault")
664     .Input("a: N * T")
665     .Input("b: N * T")
666     .Attr("T: type = DT_INT32")
667     .Attr("N: int")
668     .SetShapeFn(shape_inference::UnknownShape);
669 
670 REGISTER_OP("NIntsIn")
671     .Input("a: N * int32")
672     .Attr("N: int >= 2")
673     .SetShapeFn(shape_inference::UnknownShape);
674 
675 REGISTER_OP("NPolymorphicIn")
676     .Input("a: N * T")
677     .Attr("T: type")
678     .Attr("N: int >= 2")
679     .SetShapeFn(shape_inference::UnknownShape);
680 
681 REGISTER_OP("NPolymorphicRestrictIn")
682     .Input("a: N * T")
683     .Attr("T: {string, bool}")
684     .Attr("N: int >= 2")
685     .SetShapeFn(shape_inference::UnknownShape);
686 
687 REGISTER_OP("NInTwice")
688     .Input("a: N * int32")
689     .Input("b: N * string")
690     .Attr("N: int >= 0")
691     .SetShapeFn(shape_inference::UnknownShape);
692 
693 REGISTER_OP("NInPolymorphicTwice")
694     .Input("a: N * T")
695     .Input("b: N * T")
696     .Attr("T: type")
697     .Attr("N: int >= 0")
698     .SetShapeFn(shape_inference::UnknownShape);
699 
700 REGISTER_OP("NInTwoTypeVariables")
701     .Input("a: N * S")
702     .Input("b: N * T")
703     .Attr("S: type")
704     .Attr("T: type")
705     .Attr("N: int >= 0")
706     .SetShapeFn(shape_inference::UnknownShape);
707 
708 REGISTER_OP("InPolymorphicTwice")
709     .Input("a: N * T")
710     .Input("b: M * T")
711     .Attr("T: type = DT_INT32")
712     .Attr("N: int >= 0")
713     .Attr("M: int >= 0")
714     .SetShapeFn(shape_inference::UnknownShape);
715 
716 REGISTER_OP("NIntsOut")
717     .Output("a: N * int32")
718     .Attr("N: int >= 2")
719     .SetShapeFn(shape_inference::UnknownShape);
720 
721 REGISTER_OP("NIntsOutDefault")
722     .Output("a: N * int32")
723     .Attr("N: int >= 2 = 3")
724     .SetShapeFn(shape_inference::UnknownShape);
725 
726 REGISTER_OP("NPolymorphicOut")
727     .Output("a: N * T")
728     .Attr("T: type")
729     .Attr("N: int >= 2")
730     .SetShapeFn(shape_inference::UnknownShape);
731 
732 REGISTER_OP("NPolymorphicOutDefault")
733     .Output("a: N * T")
734     .Attr("T: type = DT_BOOL")
735     .Attr("N: int >= 2 = 2")
736     .SetShapeFn(shape_inference::UnknownShape);
737 
738 REGISTER_OP("NPolymorphicRestrictOut")
739     .Output("a: N * T")
740     .Attr("T: {string, bool}")
741     .Attr("N: int >= 2")
742     .SetShapeFn(shape_inference::UnknownShape);
743 
744 REGISTER_OP("RefIn")
745     .Input("a: Ref(T)")
746     .Attr("T: type")
747     .SetShapeFn(shape_inference::UnknownShape);
748 
749 REGISTER_OP("TwoRefsIn")
750     .Input("a: Ref(T)")
751     .Input("b: Ref(T)")
752     .Attr("T: type")
753     .SetShapeFn(shape_inference::UnknownShape);
754 
755 REGISTER_OP("RefOut")
756     .Output("a: Ref(T)")
757     .Attr("T: type")
758     .SetShapeFn(shape_inference::UnknownShape);
759 
760 REGISTER_OP("SimpleStruct")
761     .Output("a: n_a * int32")
762     .Attr("n_a: int >= 0")
763     .SetShapeFn(shape_inference::UnknownShape);
764 
765 REGISTER_OP("MixedStruct")
766     .Output("a: n_a * int32")
767     .Output("b: float")
768     .Attr("n_a: int >= 0")
769     .SetShapeFn(shape_inference::UnknownShape);
770 
771 REGISTER_OP("ComplexStruct")
772     .Output("a: n_a * int32")
773     .Output("b: n_b * int64")
774     .Output("c: t_c")
775     .Attr("n_a: int >= 0")
776     .Attr("n_b: int >= 0")
777     .Attr("t_c: list(type) >= 0")
778     .SetShapeFn(shape_inference::UnknownShape);
779 
// An op which returns its own device placement as a string, useful for testing
// where ops get placed.
REGISTER_OP("DevicePlacementOp")
    .Output("device: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

class DevicePlacementOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* ctx) override {
    Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("device", TensorShape({}), &output));
    output->scalar<tstring>()() = ctx->device()->name();
  }
};

REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_CPU),
                        DevicePlacementOp);
REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_DEFAULT),
                        DevicePlacementOp);

// An op which returns the dtype of the tensor it was passed in. It expects
// DT_UINT8.
REGISTER_OP("DtypeWithDefaultOp")
    .Input("in: T")
    .Attr("T: type = DT_UINT8")
    .Output("dtype: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

class DTypeWithDefaultOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("dtype", TensorShape({}), &output));
    output->scalar<tstring>()() = tensorflow::DataTypeString(input.dtype());
  }
};

REGISTER_KERNEL_BUILDER(Name("DtypeWithDefaultOp").Device(DEVICE_CPU),
                        DTypeWithDefaultOp);

// An op that returns True if TensorFloat-32 execution is enabled. Useful for
// testing that enabling/disabling TensorFloat-32 works correctly, even when
// the test does not run with a GPU that supports TensorFloat-32.
REGISTER_OP("IsTensorFloat32Enabled")
    .Output("enabled: bool")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

class IsTensorFloat32Enabled : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* ctx) override {
    Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("enabled", TensorShape({}), &output));
    output->scalar<bool>()() = tensor_float_32_execution_enabled();
  }
};

REGISTER_KERNEL_BUILDER(
    Name("IsTensorFloat32Enabled").Device(DEVICE_CPU).HostMemory("enabled"),
    IsTensorFloat32Enabled);
REGISTER_KERNEL_BUILDER(
    Name("IsTensorFloat32Enabled").Device(DEVICE_GPU).HostMemory("enabled"),
    IsTensorFloat32Enabled);
}  // end namespace tensorflow