• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_kernel.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/allocator.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/attr_value_util.h"
25 #include "tensorflow/core/framework/fake_input.h"
26 #include "tensorflow/core/framework/node_def_builder.h"
27 #include "tensorflow/core/framework/op.h"
28 #include "tensorflow/core/framework/tensor_shape.pb.h"
29 #include "tensorflow/core/framework/tensor_util.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/graph/graph.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/protobuf.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/test_benchmark.h"
40 #include "tensorflow/core/public/version.h"
41 #include "tensorflow/core/util/device_name_utils.h"
42 
43 class DummyKernel : public tensorflow::OpKernel {
44  public:
DummyKernel(tensorflow::OpKernelConstruction * context)45   explicit DummyKernel(tensorflow::OpKernelConstruction* context)
46       : OpKernel(context) {}
Compute(tensorflow::OpKernelContext * context)47   void Compute(tensorflow::OpKernelContext* context) override {}
48 };
49 
50 // Test that registration works outside a namespace.
51 REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8");
52 REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),
53                         DummyKernel);
54 
55 namespace foo {
56 bool match_signature_ = false;
57 
58 // Test that registration works inside a different namespace.
59 class TestOp2 : public ::tensorflow::OpKernel {
60  public:
TestOp2(::tensorflow::OpKernelConstruction * context)61   explicit TestOp2(::tensorflow::OpKernelConstruction* context)
62       : OpKernel(context) {
63     ::tensorflow::Status status = context->MatchSignature(
64         {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32});
65     match_signature_ = status.ok();
66     context->SetStatus(status);
67   }
Compute(::tensorflow::OpKernelContext * context)68   void Compute(::tensorflow::OpKernelContext* context) override {}
69 };
70 
71 REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type");
72 REGISTER_KERNEL_BUILDER(Name("Test2")
73                             .Device(::tensorflow::DEVICE_GPU)
74                             .HostMemory("i")
75                             .HostMemory("o"),
76                         TestOp2);
77 }  // namespace foo
78 
79 namespace tensorflow {
80 
81 // Two operations with the same name but different devices.
82 REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type");
83 
84 class TestOp3Cpu : public tensorflow::OpKernel {
85  public:
TestOp3Cpu(OpKernelConstruction * context)86   explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)87   void Compute(OpKernelContext* context) override {}
88 };
89 
90 REGISTER_KERNEL_BUILDER(
91     Name("Test3").Device(DEVICE_CPU).TypeConstraint<int8>("T"), TestOp3Cpu);
92 
93 namespace {
94 
95 class TestOp3Gpu : public tensorflow::OpKernel {
96  public:
TestOp3Gpu(OpKernelConstruction * context)97   explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)98   void Compute(OpKernelContext* context) override {}
99 };
100 
101 REGISTER_KERNEL_BUILDER(
102     Name("Test3").Device(DEVICE_GPU).TypeConstraint<float>("T"), TestOp3Cpu);
103 
104 // An Op registered for both
105 REGISTER_OP("Test4").Input("i: float").Output("o: float");
106 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel);
107 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel);
108 
109 // Kernels with different priorities.
110 REGISTER_OP("Test5").Input("a: T").Input("b: T").Attr("T: type");
111 
112 REGISTER_OP("OpWithoutKernel").Input("a: T").Input("b: T").Attr("T: type");
113 
114 class TestOp5Cpu : public tensorflow::OpKernel {
115  public:
TestOp5Cpu(OpKernelConstruction * context)116   explicit TestOp5Cpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)117   void Compute(OpKernelContext* context) override {}
118 };
119 
120 REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_CPU).Priority(2),
121                         TestOp5Cpu);
122 
123 class TestOp5Gpu : public tensorflow::OpKernel {
124  public:
TestOp5Gpu(OpKernelConstruction * context)125   explicit TestOp5Gpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)126   void Compute(OpKernelContext* context) override {}
127 };
128 
129 REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_GPU).Priority(1),
130                         TestOp5Gpu);
131 
DeviceTypes()132 static std::vector<DeviceType> DeviceTypes() {
133   return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)};
134 }
135 
136 class OpKernelTest : public ::testing::Test {
137  public:
OpKernelTest()138   OpKernelTest() : device_(Env::Default()) {}
139 
140  protected:
CreateNodeDef(const string & op_type,const DataTypeVector & inputs,const string & device="")141   NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs,
142                         const string& device = "") {
143     NodeDefBuilder builder(op_type + "-op", op_type);
144     for (DataType dt : inputs) {
145       builder.Input(FakeInput(dt));
146     }
147     builder.Device(device);
148     NodeDef node_def;
149     TF_CHECK_OK(builder.Finalize(&node_def));
150     return node_def;
151   }
152 
ExpectEqual(const string & what,const DataTypeVector & expected,const DataTypeVector & observed)153   void ExpectEqual(const string& what, const DataTypeVector& expected,
154                    const DataTypeVector& observed) {
155     EXPECT_EQ(expected.size(), observed.size()) << what;
156     const size_t size = std::min(expected.size(), observed.size());
157     for (size_t i = 0; i < size; ++i) {
158       bool match = TypesCompatible(expected[i], observed[i]);
159       EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i]
160                          << ", observed: " << observed[i];
161     }
162   }
163 
ExpectSuccess(const string & op_type,DeviceType device_type,const DataTypeVector & inputs,const DataTypeVector & outputs)164   void ExpectSuccess(const string& op_type, DeviceType device_type,
165                      const DataTypeVector& inputs,
166                      const DataTypeVector& outputs) {
167     Status status;
168     std::unique_ptr<OpKernel> op(CreateOpKernel(
169         std::move(device_type), &device_, cpu_allocator(),
170         CreateNodeDef(op_type, inputs), TF_GRAPH_DEF_VERSION, &status));
171     EXPECT_TRUE(status.ok()) << status;
172     EXPECT_TRUE(op != nullptr);
173     if (op != nullptr) {
174       ExpectEqual("inputs", op->input_types(), inputs);
175       ExpectEqual("outputs", op->output_types(), outputs);
176     }
177   }
178 
ExpectFailure(const string & ascii_node_def,DeviceType device_type,error::Code code)179   void ExpectFailure(const string& ascii_node_def, DeviceType device_type,
180                      error::Code code) {
181     NodeDef node_def;
182     protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def);
183     Status status;
184     std::unique_ptr<OpKernel> op(
185         CreateOpKernel(std::move(device_type), &device_, cpu_allocator(),
186                        node_def, TF_GRAPH_DEF_VERSION, &status));
187     EXPECT_TRUE(op == nullptr);
188     EXPECT_FALSE(status.ok());
189     if (!status.ok()) {
190       LOG(INFO) << "Status message: " << status.error_message();
191       EXPECT_EQ(code, status.code());
192     }
193   }
194 
195  private:
196   DeviceBase device_;
197 };
198 
// Test1 builds on CPU both with a plain and with a ref-typed float input.
TEST_F(OpKernelTest, SuccessCpu) {
  const DataTypeVector outputs = {DT_UINT8};
  ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, outputs);
  ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32}, outputs);
}
203 
// Test2 builds on GPU, and its constructor reports a matching signature.
TEST_F(OpKernelTest, SuccessGpu) {
  foo::match_signature_ = false;  // Reset so the kernel must set it.
  const DataTypeVector int32_only = {DT_INT32};
  ExpectSuccess("Test2", DEVICE_GPU, int32_only, int32_only);
  EXPECT_TRUE(foo::match_signature_);
}
209 
// Test3 has no outputs; CPU accepts int8 and GPU accepts float for T.
TEST_F(OpKernelTest, SuccessBothCpuAndGpu) {
  const DataTypeVector no_outputs;
  ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, no_outputs);
  ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, no_outputs);
}
214 
// Test1 only has a CPU kernel, so CPU must be the sole supported device.
TEST_F(OpKernelTest, CpuTypeRegistered) {
  NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  PrioritizedDeviceTypeVector devs;
  TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
  // Fatal check so devs[0] below is never read out of bounds.
  ASSERT_EQ(1, devs.size());
  EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
}
222 
// For an op with no kernels, SupportedDeviceTypesForNode() reports every
// candidate device when the node is placed in a different address space than
// `local_device_name`, and none when it is placed in the same one.
TEST_F(OpKernelTest, KernelNotRegistered) {
  // Held by value rather than as const references bound to temporaries.
  const string local_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string remote_device = "/job:worker/replica:0/task:0/device";
  DeviceNameUtils::ParsedName local_device_name;
  ASSERT_TRUE(DeviceNameUtils::ParseFullName(local_device, &local_device_name))
      << local_device;
  {
    // Try a node def of an op which does not have kernel. And the requested
    // device in NodeDef is on a different address space than the local device.
    NodeDef ndef =
        CreateNodeDef("OpWithoutKernel", {DT_STRING, DT_STRING}, remote_device);
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs,
                                             &local_device_name));
    // Fatal check so the indexing below stays in bounds.
    ASSERT_EQ(2, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1].first);
  }

  {
    // Try a node def of an op which does not have kernel. And the requested
    // device in NodeDef is on the same address space as the local device.
    NodeDef ndef =
        CreateNodeDef("OpWithoutKernel", {DT_STRING, DT_STRING}, local_device);
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs,
                                             &local_device_name));
    EXPECT_EQ(0, devs.size());
  }
}
254 
// Checks the device list (and priorities) reported for ops registered on
// CPU only, GPU only, neither, or both. Size checks are fatal (ASSERT) so a
// short list never leads to out-of-bounds indexing.
TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
  {
    // Try a node def of an op that is registered for a specific type
    // only on CPU.
    NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    ASSERT_EQ(1, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
  }
  {
    // Try a node def of an op that is registered for a specific type
    // only on GPU.
    NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    ASSERT_EQ(1, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
  }
  {
    // Try a node def of an op that is only registered for other types.
    NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(0, devs.size());
  }

  {
    // Try a node def of an op that is registered for both.
    NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    ASSERT_EQ(2, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1].first);
  }

  {
    // Try a node def of an op where kernels have priorities: the higher
    // priority (CPU, 2) is expected first.
    NodeDef ndef = CreateNodeDef("Test5", {DT_STRING, DT_STRING});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    ASSERT_EQ(2, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
    EXPECT_EQ(2, devs[0].second);
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[1].first);
    EXPECT_EQ(1, devs[1].second);
  }
}
304 
// Kernel creation must report NOT_FOUND when no registration matches.
TEST_F(OpKernelTest, NotFound) {
  const auto not_found = error::NOT_FOUND;
  // Text form of a NodeDef for `op_type` with the given input types.
  const auto as_text = [this](const string& op_type,
                              const DataTypeVector& inputs) {
    return CreateNodeDef(op_type, inputs).DebugString();
  };

  // The op type exists, but only for a different DeviceType.
  ExpectFailure(as_text("Test1", {DT_FLOAT, DT_INT32}), DEVICE_GPU, not_found);
  ExpectFailure(as_text("Test3", {DT_INT8, DT_INT8}), DEVICE_GPU, not_found);
  ExpectFailure(as_text("Test3", {DT_FLOAT, DT_FLOAT}), DEVICE_CPU, not_found);

  // No kernel is registered for this type signature.
  ExpectFailure(as_text("Test3", {DT_INT32, DT_INT32}), DEVICE_GPU, not_found);

  // The op type itself does not exist.
  const string unknown_op = "name: 'NF' op: 'Testnotfound'";
  ExpectFailure(unknown_op, DEVICE_CPU, not_found);
  ExpectFailure(unknown_op, DEVICE_GPU, not_found);
}
324 
// Test1 declares two inputs; zero or one input must be rejected.
TEST_F(OpKernelTest, TooFewInputs) {
  NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});

  node_def.clear_input();
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, error::INVALID_ARGUMENT);

  node_def.add_input("a");
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, error::INVALID_ARGUMENT);
}
333 
// Test1 declares two inputs; a third one must be rejected.
TEST_F(OpKernelTest, TooManyInputs) {
  NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  node_def.add_input("c");
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, error::INVALID_ARGUMENT);
}
340 
// A float input does not match TestOp2's int32 -> int32 signature, so the
// kernel's constructor must fail and clear foo::match_signature_.
// NOTE(review): "Failes" is misspelled; presumably kept as-is so existing
// test filters keep working.
TEST_F(OpKernelTest, MatchSignatureFailes) {
  foo::match_signature_ = true;  // The kernel is expected to reset this.
  const string ndef_text = CreateNodeDef("Test2", {DT_FLOAT}).DebugString();
  ExpectFailure(ndef_text, DEVICE_GPU, error::INVALID_ARGUMENT);
  EXPECT_FALSE(foo::match_signature_);
}
348 
349 class DummyDevice : public DeviceBase {
350  public:
DummyDevice(Env * env)351   explicit DummyDevice(Env* env) : DeviceBase(env) {}
GetAllocator(AllocatorAttributes)352   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
353     return cpu_allocator();
354   }
355 };
356 
// OpKernelContext::input_dtype() resolves input names declared by the op and
// rejects unknown names.
TEST_F(OpKernelTest, InputDtype) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  DummyDevice device(env);
  params.device = &device;
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  // Fatal checks: otherwise the OpKernelContext below would be built around
  // a null op_kernel.
  ASSERT_TRUE(status.ok()) << status;
  ASSERT_TRUE(op != nullptr);
  params.op_kernel = op.get();
  Tensor a(DT_FLOAT, TensorShape({}));
  Tensor b(DT_INT32, TensorShape({}));
  Tensor c(DT_UINT8, TensorShape({}));
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&a), TensorValue(&b),
                                            TensorValue(&c)};
  params.inputs = &inputs;
  auto ctx = absl::make_unique<OpKernelContext>(&params);

  DataType dtype;
  EXPECT_FALSE(ctx->input_dtype("non_existent_input", &dtype).ok());
  ASSERT_TRUE(ctx->input_dtype("a", &dtype).ok());
  EXPECT_EQ(dtype, DT_FLOAT);
  ASSERT_TRUE(ctx->input_dtype("b", &dtype).ok());
  EXPECT_EQ(dtype, DT_INT32);
}
384 
385 // A mock device that mimics the behavior of scoped allocator upon calling
386 // GetAllocator with a positive scope_id.
387 class ScopedAllocatorDevice : public DeviceBase {
388  public:
ScopedAllocatorDevice(Env * env)389   explicit ScopedAllocatorDevice(Env* env)
390       : DeviceBase(env),
391         scope_allocated_(false),
392         num_allocations_(0),
393         num_scoped_allocations_(0) {}
394 
GetAllocator(AllocatorAttributes attrs)395   Allocator* GetAllocator(AllocatorAttributes attrs) override {
396     CHECK_LE(attrs.scope_id, 0);
397     num_allocations_++;
398     return cpu_allocator();
399   }
400 
GetScopedAllocator(AllocatorAttributes attrs,int64_t)401   Allocator* GetScopedAllocator(AllocatorAttributes attrs,
402                                 int64_t /*step_id*/) override {
403     CHECK_GT(attrs.scope_id, 0);
404     num_scoped_allocations_++;
405     if (scope_allocated_) {
406       return nullptr;
407     } else {
408       scope_allocated_ = true;
409       return cpu_allocator();
410     }
411   }
412 
CopyTensorInSameDevice(const Tensor * input_tensor,Tensor * output_tensor,const DeviceContext * device_context,StatusCallback done)413   void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor,
414                               const DeviceContext* device_context,
415                               StatusCallback done) override {
416     CHECK(input_tensor->NumElements() == output_tensor->NumElements());
417     tensor::DeepCopy(*input_tensor, output_tensor);
418     done(Status::OK());
419   }
420 
421   // Return the count of calls to GetAllocator or GetScopedAllocator, depending
422   // on when scoped is false or true respectively.  For testing purposes.
num_allocations(bool scoped)423   int num_allocations(bool scoped) {
424     if (scoped) {
425       return num_scoped_allocations_;
426     } else {
427       return num_allocations_;
428     }
429   }
430 
431  private:
432   bool scope_allocated_;
433   int num_allocations_;
434   int num_scoped_allocations_;
435 };
436 
437 // Test that a kernel which has an output marked for allocation via
438 // ScopedAllocator, which calls allocate_temp and set_output, does the right
439 // thing.  In this case, the expected behavior is for allocate_temp to return
440 // a temporary buffer, and set_output to copy the contents of this temp buffer
441 // into the ScopedAllocator slice.
TEST_F(OpKernelTest,ScopedAllocationTest)442 TEST_F(OpKernelTest, ScopedAllocationTest) {
443   Env* env = Env::Default();
444   OpKernelContext::Params params;
445   auto sa_device = absl::make_unique<ScopedAllocatorDevice>(env);
446   params.device = sa_device.get();
447   Status status;
448   std::unique_ptr<OpKernel> op(CreateOpKernel(
449       DEVICE_CPU, params.device, cpu_allocator(),
450       CreateNodeDef("Test4", {DT_FLOAT}), TF_GRAPH_DEF_VERSION, &status));
451   EXPECT_TRUE(status.ok());
452   params.op_kernel = op.get();
453   AllocatorAttributes alloc_attrs;
454   alloc_attrs.scope_id = 1;
455   std::vector<AllocatorAttributes> output_alloc_attrs({alloc_attrs});
456   params.output_attr_array = output_alloc_attrs.data();
457   std::vector<int> forward_from({OpKernelContext::Params::kNeverForward});
458   params.forward_from_array = forward_from.data();
459   auto ctx = absl::make_unique<OpKernelContext>(&params);
460 
461   EXPECT_EQ(sa_device->num_allocations(false), 0);
462   EXPECT_EQ(sa_device->num_allocations(true), 0);
463   Tensor temp1;
464   TF_EXPECT_OK(
465       ctx->allocate_temp(DT_FLOAT, TensorShape({8}), &temp1, alloc_attrs));
466   EXPECT_EQ(sa_device->num_allocations(false), 1);
467   EXPECT_EQ(sa_device->num_allocations(true), 0);
468   Tensor temp2;
469   alloc_attrs.scope_id = -1;
470   TF_EXPECT_OK(
471       ctx->allocate_temp(DT_FLOAT, TensorShape({4}), &temp2, alloc_attrs));
472   EXPECT_EQ(sa_device->num_allocations(false), 2);
473   EXPECT_EQ(sa_device->num_allocations(true), 0);
474   ctx->set_output(0, temp1);
475   EXPECT_EQ(sa_device->num_allocations(false), 2);
476   EXPECT_EQ(sa_device->num_allocations(true), 1);
477 }
478 
479 class OpKernelBuilderTest : public ::testing::Test {
480  protected:
481   // Each attr is described by a "name|type|value".
CreateNodeDef(const string & op_type,const std::vector<string> & attrs)482   NodeDef CreateNodeDef(const string& op_type,
483                         const std::vector<string>& attrs) {
484     NodeDef node_def;
485     node_def.set_name(op_type + "-op");
486     node_def.set_op(op_type);
487     for (const string& attr_desc : attrs) {
488       std::vector<string> parts = str_util::Split(attr_desc, '|');
489       CHECK_EQ(parts.size(), 3);
490       AttrValue attr_value;
491       CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc;
492       node_def.mutable_attr()->insert(
493           AttrValueMap::value_type(parts[0], attr_value));
494     }
495     return node_def;
496   }
497 
ExpectSuccess(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,DataTypeSlice input_types={})498   std::unique_ptr<OpKernel> ExpectSuccess(const string& op_type,
499                                           const DeviceType& device_type,
500                                           const std::vector<string>& attrs,
501                                           DataTypeSlice input_types = {}) {
502     Status status;
503     NodeDef def = CreateNodeDef(op_type, attrs);
504     for (size_t i = 0; i < input_types.size(); ++i) {
505       def.add_input("a:0");
506     }
507 
508     Env* env = Env::Default();
509     DeviceBase device(env);
510 
511     // Test CreateOpKernel()
512     std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
513                                                 cpu_allocator(), def,
514                                                 TF_GRAPH_DEF_VERSION, &status));
515     EXPECT_TRUE(status.ok()) << status;
516     EXPECT_TRUE(op != nullptr);
517     if (op != nullptr) {
518       EXPECT_EQ(input_types.size(), op->num_inputs());
519       EXPECT_EQ(0, op->num_outputs());
520     }
521 
522     // Test SupportedDeviceTypesForNode()
523     PrioritizedDeviceTypeVector devices;
524     TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
525     bool found = false;
526     for (const auto& dt : devices) {
527       if (dt.first == device_type) {
528         found = true;
529       }
530     }
531     EXPECT_TRUE(found) << "Missing " << device_type << " from "
532                        << devices.size() << " devices.";
533 
534     // In case the caller wants to use the OpKernel
535     return op;
536   }
537 
ExpectFailure(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,error::Code code)538   void ExpectFailure(const string& op_type, const DeviceType& device_type,
539                      const std::vector<string>& attrs, error::Code code) {
540     Status status;
541     const NodeDef def = CreateNodeDef(op_type, attrs);
542     Env* env = Env::Default();
543     DeviceBase device(env);
544 
545     // Test CreateOpKernel().
546     std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
547                                                 cpu_allocator(), def,
548                                                 TF_GRAPH_DEF_VERSION, &status));
549     EXPECT_TRUE(op == nullptr);
550     EXPECT_FALSE(status.ok());
551     if (!status.ok()) {
552       LOG(INFO) << "Status message: " << status.error_message();
553       EXPECT_EQ(code, status.code());
554 
555       // Test SupportedDeviceTypesForNode().
556       PrioritizedDeviceTypeVector devices;
557       if (errors::IsNotFound(status)) {
558         TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
559         for (const auto& dt : devices) {
560           EXPECT_NE(dt.first, device_type);
561         }
562       } else {
563         Status status2 =
564             SupportedDeviceTypesForNode(DeviceTypes(), def, &devices);
565         EXPECT_EQ(status.code(), status2.code());
566       }
567     }
568   }
569 
GetKernelClassName(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,DataTypeSlice input_types={})570   string GetKernelClassName(const string& op_type,
571                             const DeviceType& device_type,
572                             const std::vector<string>& attrs,
573                             DataTypeSlice input_types = {}) {
574     NodeDef def = CreateNodeDef(op_type, attrs);
575     for (size_t i = 0; i < input_types.size(); ++i) {
576       def.add_input("a:0");
577     }
578 
579     const KernelDef* kernel_def = nullptr;
580     string kernel_class_name;
581     const Status status =
582         FindKernelDef(device_type, def, &kernel_def, &kernel_class_name);
583     if (status.ok()) {
584       return kernel_class_name;
585     } else if (errors::IsNotFound(status)) {
586       return "not found";
587     } else {
588       return status.ToString();
589     }
590   }
591 };
592 
593 REGISTER_OP("BuildCPU");
594 REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);
595 
TEST_F(OpKernelBuilderTest,BuilderCPU)596 TEST_F(OpKernelBuilderTest, BuilderCPU) {
597   ExpectSuccess("BuildCPU", DEVICE_CPU, {});
598   EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {}));
599   ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
600   EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {}));
601 }
602 
603 REGISTER_OP("BuildGPU");
604 REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel);
605 
TEST_F(OpKernelBuilderTest,BuilderGPU)606 TEST_F(OpKernelBuilderTest, BuilderGPU) {
607   ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND);
608   ExpectSuccess("BuildGPU", DEVICE_GPU, {});
609 }
610 
611 REGISTER_OP("BuildBoth");
612 REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel);
613 REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel);
614 
TEST_F(OpKernelBuilderTest,BuilderBoth)615 TEST_F(OpKernelBuilderTest, BuilderBoth) {
616   ExpectSuccess("BuildBoth", DEVICE_CPU, {});
617   ExpectSuccess("BuildBoth", DEVICE_GPU, {});
618 }
619 
620 REGISTER_OP("BuildTypeAttr").Attr("T: type");
621 REGISTER_KERNEL_BUILDER(
622     Name("BuildTypeAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"),
623     DummyKernel);
624 
TEST_F(OpKernelBuilderTest,BuilderTypeAttr)625 TEST_F(OpKernelBuilderTest, BuilderTypeAttr) {
626   ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"});
627   ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"},
628                 error::NOT_FOUND);
629   ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
630   ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"},
631                 error::INVALID_ARGUMENT);
632 }
633 
634 REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)");
635 REGISTER_KERNEL_BUILDER(
636     Name("BuildTypeListAttr").Device(DEVICE_CPU).TypeConstraint<bool>("T"),
637     DummyKernel);
638 
TEST_F(OpKernelBuilderTest,BuilderTypeListAttr)639 TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
640   ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
641   EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
642                                               {"T|list(type)|[]"}));
643 
644   ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
645   EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
646                                               {"T|list(type)|[]"}));
647 
648   ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
649                 {"T|list(type)|[DT_BOOL, DT_BOOL]"});
650 
651   ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
652                 error::NOT_FOUND);
653   EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
654                                             {"T|list(type)|[DT_FLOAT]"}));
655 
656   ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
657   EXPECT_TRUE(
658       absl::StrContains(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}),
659                         "Invalid argument: "));
660 
661   ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
662                 error::INVALID_ARGUMENT);
663 }
664 
665 REGISTER_OP("DuplicateKernel");
666 REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
667                         DummyKernel);
668 REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
669                         DummyKernel);
670 
TEST_F(OpKernelBuilderTest,DuplicateKernel)671 TEST_F(OpKernelBuilderTest, DuplicateKernel) {
672   const NodeDef ndef = CreateNodeDef("DuplicateKernel", {});
673   PrioritizedDeviceTypeVector devs;
674   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
675   ASSERT_FALSE(status.ok());
676   EXPECT_TRUE(absl::StrContains(
677       status.error_message(), "Multiple OpKernel registrations match NodeDef"));
678 
679   ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
680 }
681 
682 REGISTER_OP("DuplicateKernelForT").Attr("T: type");
683 REGISTER_KERNEL_BUILDER(
684     Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
685     DummyKernel);
686 REGISTER_KERNEL_BUILDER(
687     Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
688     DummyKernel);
689 
TEST_F(OpKernelBuilderTest,DuplicateKernelForT)690 TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
691   const NodeDef ndef =
692       CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"});
693   PrioritizedDeviceTypeVector devs;
694   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
695   ASSERT_FALSE(status.ok());
696   EXPECT_TRUE(absl::StrContains(
697       status.error_message(), "Multiple OpKernel registrations match NodeDef"));
698 
699   ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"},
700                 error::INVALID_ARGUMENT);
701   ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"},
702                 error::NOT_FOUND);
703 }
704 
705 REGISTER_OP("BadConstraint").Attr("dtype: type");
706 REGISTER_KERNEL_BUILDER(Name("BadConstraint")
707                             .Device(DEVICE_CPU)
708                             // Mistake: "T" should be "dtype".
709                             .TypeConstraint<float>("T"),
710                         DummyKernel);
711 
// A kernel constraint on an attr the NodeDef does not carry is reported as an
// error by device-type lookup and by kernel construction alike.
TEST_F(OpKernelBuilderTest, BadConstraint) {
  const NodeDef node_def = CreateNodeDef("BadConstraint", {});

  PrioritizedDeviceTypeVector supported_devices;
  const Status lookup_status =
      SupportedDeviceTypesForNode(DeviceTypes(), node_def, &supported_devices);
  ASSERT_FALSE(lookup_status.ok());
  EXPECT_TRUE(absl::StrContains(
      lookup_status.error_message(),
      "OpKernel 'BadConstraint' has constraint on attr 'T' not in NodeDef"));

  ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"},
                error::INVALID_ARGUMENT);
}
725 
726 REGISTER_OP("ListOut").Output("a: int32").Output("b: T").Attr("T: list(type)");
727 REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU),
728                         DummyKernel);
729 
// Builds a "ListOut" kernel whose list output "b" has dtypes
// [DT_FLOAT, DT_INT32], then checks that OpKernelContext::output_list()
// rejects unknown names and reports per-element expected dtypes for "b".
TEST_F(OpKernelBuilderTest, OpOutputList) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  DummyDevice device(env);
  params.device = &device;
  Status status;
  std::unique_ptr<OpKernel> op(CreateOpKernel(
      DEVICE_CPU, params.device, cpu_allocator(),
      CreateNodeDef("ListOut", {"T|list(type)|[DT_FLOAT, DT_INT32]"}),
      TF_GRAPH_DEF_VERSION, &status));
  // Fatal assertion: everything below uses `op`, which is presumably null when
  // construction fails, so continuing after a failure would crash instead of
  // reporting the error.
  ASSERT_TRUE(status.ok()) << status.ToString();
  params.op_kernel = op.get();
  gtl::InlinedVector<TensorValue, 4> inputs{};
  params.inputs = &inputs;
  auto ctx = absl::make_unique<OpKernelContext>(&params);

  EXPECT_EQ(DT_INT32, ctx->expected_output_dtype(0));
  OpOutputList out_list;
  EXPECT_FALSE(ctx->output_list("non_existent_output", &out_list).ok());
  ASSERT_TRUE(ctx->output_list("b", &out_list).ok());
  EXPECT_EQ(DT_FLOAT, out_list.expected_output_dtype(0));
  EXPECT_EQ(DT_INT32, out_list.expected_output_dtype(1));
}
753 
754 class GetAttrKernel : public ::tensorflow::OpKernel {
755  public:
GetAttrKernel(OpKernelConstruction * context)756   explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) {
757     string attr_name;
758     OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name));
759 
760     status.emplace_back("s", context->GetAttr(attr_name, &s));
761     status.emplace_back("s_list", context->GetAttr(attr_name, &s_list));
762     status.emplace_back("i", context->GetAttr(attr_name, &i));
763     status.emplace_back("i_list", context->GetAttr(attr_name, &i_list));
764     status.emplace_back("i32", context->GetAttr(attr_name, &i32));
765     status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list));
766     status.emplace_back("f", context->GetAttr(attr_name, &f));
767     status.emplace_back("f_list", context->GetAttr(attr_name, &f_list));
768     status.emplace_back("b", context->GetAttr(attr_name, &b));
769     status.emplace_back("b_list", context->GetAttr(attr_name, &b_list));
770     status.emplace_back("type", context->GetAttr(attr_name, &type));
771     status.emplace_back("type_list", context->GetAttr(attr_name, &type_list));
772     status.emplace_back("type_vector",
773                         context->GetAttr(attr_name, &type_vector));
774     status.emplace_back("shape_proto",
775                         context->GetAttr(attr_name, &shape_proto));
776     status.emplace_back("shape_proto_list",
777                         context->GetAttr(attr_name, &shape_proto_list));
778     status.emplace_back("shape", context->GetAttr(attr_name, &shape));
779     status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list));
780   }
Compute(::tensorflow::OpKernelContext * context)781   void Compute(::tensorflow::OpKernelContext* context) override {}
782 
ExpectOk(std::initializer_list<string> keys)783   void ExpectOk(std::initializer_list<string> keys) {
784     for (const auto& key_status : status) {
785       // Only the status for keys in "keys" should be ok().
786       bool in_keys = false;
787       for (const string& key : keys) {
788         if (key_status.first == key) {
789           in_keys = true;
790         }
791       }
792       EXPECT_EQ(in_keys, key_status.second.ok())
793           << "key_status: " << key_status.first << ", " << key_status.second;
794     }
795   }
796 
797   string s;
798   std::vector<string> s_list;
799   int64 i;
800   std::vector<int64> i_list;
801   int32 i32;
802   std::vector<int32> i32_list;
803   float f;
804   std::vector<float> f_list;
805   bool b;
806   std::vector<bool> b_list;
807   DataType type;
808   std::vector<DataType> type_list;
809   DataTypeVector type_vector;
810   TensorShapeProto shape_proto;
811   std::vector<TensorShapeProto> shape_proto_list;
812   TensorShape shape;
813   std::vector<TensorShape> shape_list;
814   std::vector<std::pair<string, Status>> status;
815 };
816 
817 class GetAttrTest : public OpKernelBuilderTest {};
818 
819 REGISTER_OP("GetAttrStringList")
820     .Attr("attr_name: string")
821     .Attr("a: list(string)");
822 REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU),
823                         GetAttrKernel);
824 
// A list(string) attr is readable only as std::vector<string>. Naming an attr
// the op does not have makes every read fail and leaves s_list empty.
TEST_F(GetAttrTest, StringList) {
  std::unique_ptr<OpKernel> op_kernel =
      ExpectSuccess("GetAttrStringList", DEVICE_CPU,
                    {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"});
  auto* kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({"s_list"});
  EXPECT_EQ(std::vector<string>({"foo", "bar"}), kernel->s_list);

  // 'b' is not an attr of GetAttrStringList, so nothing should succeed.
  op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU,
                            {"attr_name|string|'b'", "a|list(string)|['baz']"});
  kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({});
  EXPECT_TRUE(kernel->s_list.empty());
}
839 
840 REGISTER_OP("GetAttrInt")
841     .Attr("attr_name: string")
842     .Attr("a: int")
843     .Attr("b: list(int)");
844 REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel);
845 
// In-range int / list(int) values are readable as both int64 and int32.
// A magnitude of 2^33 is readable only as int64, and the int32 read must fail
// with INVALID_ARGUMENT and a descriptive message.
TEST_F(GetAttrTest, Int) {
  // Scalar attr "a", in range.
  std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
      "GetAttrInt", DEVICE_CPU,
      {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
  auto* kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({"i", "i32"});
  EXPECT_EQ(35, kernel->i);
  EXPECT_EQ(35, kernel->i32);

  // List attr "b", in range.
  op_kernel = ExpectSuccess(
      "GetAttrInt", DEVICE_CPU,
      {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
  kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({"i_list", "i32_list"});
  EXPECT_EQ(std::vector<int64>({-1, 2, -4}), kernel->i_list);
  EXPECT_EQ(std::vector<int32>({-1, 2, -4}), kernel->i32_list);

  // 8589934592 == 2^33, too big to fit in an int32.
  op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
                            {"attr_name|string|'a'", "a|int|8589934592",
                             "b|list(int)|[-8589934592]"});
  kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({"i"});  // no i32
  EXPECT_EQ(8589934592ll, kernel->i);
  for (const auto& entry : kernel->status) {
    if (entry.first != "i32") continue;
    EXPECT_EQ(error::INVALID_ARGUMENT, entry.second.code());
    EXPECT_EQ("Attr a has value 8589934592 out of range for an int32",
              entry.second.error_message());
  }

  // Same overflow, but through the list attr "b".
  op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
                            {"attr_name|string|'b'", "a|int|8589934592",
                             "b|list(int)|[-8589934592]"});
  kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({"i_list"});  // no i32_list
  EXPECT_EQ(std::vector<int64>({-8589934592ll}), kernel->i_list);
  for (const auto& entry : kernel->status) {
    if (entry.first != "i32_list") continue;
    EXPECT_EQ(error::INVALID_ARGUMENT, entry.second.code());
    EXPECT_EQ("Attr b has value -8589934592 out of range for an int32",
              entry.second.error_message());
  }
}
892 
893 REGISTER_OP("GetAttrShape")
894     .Attr("attr_name: string")
895     .Attr("a: shape")
896     .Attr("b: list(shape)");
897 REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel);
898 
// Shape attrs are readable both as TensorShapeProto and as TensorShape, in
// scalar and list form.
TEST_F(GetAttrTest, Shape) {
  // Scalar shape attr "a".
  std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
      "GetAttrShape", DEVICE_CPU,
      {"attr_name|string|'a'", "a|shape|{ dim { size: 3 } }",
       "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
  auto* kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({"shape", "shape_proto"});
  EXPECT_EQ(kernel->shape_proto.ShortDebugString(), "dim { size: 3 }");
  EXPECT_EQ("[3]", kernel->shape.DebugString());

  // List-of-shapes attr "b".
  op_kernel = ExpectSuccess(
      "GetAttrShape", DEVICE_CPU,
      {"attr_name|string|'b'", "a|shape|{ dim { size: 3 } }",
       "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
  kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({"shape_list", "shape_proto_list"});

  const auto& protos = kernel->shape_proto_list;
  ASSERT_EQ(2, protos.size());
  EXPECT_EQ(protos[0].ShortDebugString(), "dim { size: 2 }");
  EXPECT_EQ(protos[1].ShortDebugString(), "dim { size: 4 }");

  const auto& shapes = kernel->shape_list;
  ASSERT_EQ(2, shapes.size());
  EXPECT_EQ("[2]", shapes[0].DebugString());
  EXPECT_EQ("[4]", shapes[1].DebugString());
}
924 
925 REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type");
926 REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel);
927 
// A type attr is readable only as a DataType.
TEST_F(GetAttrTest, Type) {
  std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
      "GetAttrType", DEVICE_CPU, {"attr_name|string|'a'", "a|type|DT_FLOAT"});
  auto* kernel = static_cast<GetAttrKernel*>(op_kernel.get());

  kernel->ExpectOk({"type"});
  EXPECT_EQ(DT_FLOAT, kernel->type);
}
935 
936 REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)");
937 REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU),
938                         GetAttrKernel);
939 
// A list(type) attr is readable both as std::vector<DataType> and as
// DataTypeVector, with element order preserved.
TEST_F(GetAttrTest, TypeList) {
  std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
      "GetAttrTypeList", DEVICE_CPU,
      {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"});
  auto* kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({"type_list", "type_vector"});

  const auto& as_std_vector = kernel->type_list;
  ASSERT_EQ(2, as_std_vector.size());
  EXPECT_EQ(DT_INT32, as_std_vector[0]);
  EXPECT_EQ(DT_BOOL, as_std_vector[1]);

  const auto& as_type_vector = kernel->type_vector;
  ASSERT_EQ(2, as_type_vector.size());
  EXPECT_EQ(DT_INT32, as_type_vector[0]);
  EXPECT_EQ(DT_BOOL, as_type_vector[1]);
}
954 
955 class BaseKernel : public ::tensorflow::OpKernel {
956  public:
BaseKernel(OpKernelConstruction * context)957   explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {}
Compute(::tensorflow::OpKernelContext * context)958   void Compute(::tensorflow::OpKernelContext* context) override {}
959   virtual int Which() const = 0;
960 };
961 
962 template <int WHICH>
963 class LabeledKernel : public BaseKernel {
964  public:
965   using BaseKernel::BaseKernel;
Which() const966   int Which() const override { return WHICH; }
967 };
968 
969 class LabelTest : public OpKernelBuilderTest {};
970 
971 REGISTER_OP("LabeledKernel");
972 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU),
973                         LabeledKernel<0>);
974 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"),
975                         LabeledKernel<1>);
976 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
977                         LabeledKernel<2>);
978 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
979                         LabeledKernel<3>);
980 
// Without a "_kernel" label attr, the unlabeled registration is selected.
TEST_F(LabelTest, Default) {
  std::unique_ptr<OpKernel> op_kernel =
      ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
  const auto* labeled = static_cast<BaseKernel*>(op_kernel.get());
  EXPECT_EQ(0, labeled->Which());

  const auto class_name = GetKernelClassName("LabeledKernel", DEVICE_CPU, {});
  EXPECT_EQ("LabeledKernel<0>", class_name);
}
990 
// A "_kernel" attr of 'one' selects the registration labeled "one".
TEST_F(LabelTest, Specified) {
  std::unique_ptr<OpKernel> op_kernel =
      ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
  const auto* labeled = static_cast<BaseKernel*>(op_kernel.get());
  EXPECT_EQ(1, labeled->Which());

  const auto class_name = GetKernelClassName("LabeledKernel", DEVICE_CPU,
                                             {"_kernel|string|'one'"});
  EXPECT_EQ("LabeledKernel<1>", class_name);
}
999 
// Two registrations share the label "dupe", so requesting it is ambiguous and
// kernel construction fails with INVALID_ARGUMENT.
TEST_F(LabelTest, Duplicate) {
  ExpectFailure("LabeledKernel", DEVICE_CPU,
                {"_kernel|string|'dupe'"}, error::INVALID_ARGUMENT);
}
1004 
BM_InputRangeHelper(::testing::benchmark::State & state,const NodeDef & node_def,const char * input_name,int expected_start,int expected_stop)1005 void BM_InputRangeHelper(::testing::benchmark::State& state,
1006                          const NodeDef& node_def, const char* input_name,
1007                          int expected_start, int expected_stop) {
1008   Status status;
1009   auto device = absl::make_unique<DummyDevice>(Env::Default());
1010 
1011   std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
1012                                               cpu_allocator(), node_def,
1013                                               TF_GRAPH_DEF_VERSION, &status));
1014   TF_CHECK_OK(status);
1015 
1016   for (auto s : state) {
1017     int start;
1018     int stop;
1019     TF_CHECK_OK(op->InputRange(input_name, &start, &stop));
1020     EXPECT_EQ(expected_start, start);
1021     EXPECT_EQ(expected_stop, stop);
1022   }
1023 }
1024 
1025 REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel);
1026 REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel);
1027 REGISTER_KERNEL_BUILDER(Name("MatMul").Device(DEVICE_CPU), DummyKernel);
1028 
BM_ConcatInputRange(::testing::benchmark::State & state)1029 void BM_ConcatInputRange(::testing::benchmark::State& state) {
1030   // Create a ConcatV2 NodeDef with 4 inputs (plus the axis).
1031   NodeDef node_def;
1032   node_def.set_name("concat-op");
1033   node_def.set_op("ConcatV2");
1034   AttrValue attr_N;
1035   attr_N.set_i(4);
1036   AttrValue attr_T;
1037   attr_T.set_type(DT_FLOAT);
1038   AttrValue attr_Tidx;
1039   attr_Tidx.set_type(DT_INT32);
1040   node_def.mutable_attr()->insert({"N", attr_N});
1041   node_def.mutable_attr()->insert({"T", attr_T});
1042   node_def.mutable_attr()->insert({"Tidx", attr_Tidx});
1043   for (size_t i = 0; i < 5; ++i) {
1044     node_def.add_input(strings::StrCat("a:", i));
1045   }
1046 
1047   BM_InputRangeHelper(state, node_def, "values", 0, 4);
1048 }
1049 
BM_SelectInputRange(::testing::benchmark::State & state)1050 void BM_SelectInputRange(::testing::benchmark::State& state) {
1051   // Create a Select NodeDef with 3 inputs.
1052   NodeDef node_def;
1053   node_def.set_name("select-op");
1054   node_def.set_op("Select");
1055   AttrValue attr_T;
1056   attr_T.set_type(DT_FLOAT);
1057   node_def.mutable_attr()->insert({"T", attr_T});
1058   for (size_t i = 0; i < 3; ++i) {
1059     node_def.add_input(strings::StrCat("a:", i));
1060   }
1061 
1062   BM_InputRangeHelper(state, node_def, "condition", 0, 1);
1063 }
1064 
BM_TraceString(::testing::benchmark::State & state)1065 void BM_TraceString(::testing::benchmark::State& state) {
1066   const int verbose = state.range(0);
1067 
1068   // Create a MatMul NodeDef with 2 inputs.
1069   NodeDef node_def;
1070   node_def.set_name("gradient_tape/model_1/dense_1/MatMul_1");
1071   node_def.set_op("MatMul");
1072   AttrValue transpose_a, transpose_b, attr_t;
1073   attr_t.set_type(DT_FLOAT);
1074   node_def.mutable_attr()->insert({"T", attr_t});
1075   transpose_a.set_b(true);
1076   node_def.mutable_attr()->insert({"transpose_a", transpose_a});
1077   transpose_b.set_b(true);
1078   node_def.mutable_attr()->insert({"transpose_b", transpose_b});
1079   for (size_t i = 0; i < 2; ++i) {
1080     node_def.add_input(strings::StrCat("a:", i));
1081   }
1082 
1083   // Build OpKernel and OpKernelContext
1084   Status status;
1085   auto device = absl::make_unique<DummyDevice>(Env::Default());
1086   std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
1087                                               cpu_allocator(), node_def,
1088                                               TF_GRAPH_DEF_VERSION, &status));
1089   TF_CHECK_OK(status);
1090 
1091   OpKernelContext::Params params;
1092   params.device = device.get();
1093   params.op_kernel = op.get();
1094   Tensor a(DT_FLOAT, TensorShape({99000, 256}));
1095   Tensor b(DT_FLOAT, TensorShape({256, 256}));
1096   gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&a), TensorValue(&b)};
1097   params.inputs = &inputs;
1098   auto ctx = absl::make_unique<OpKernelContext>(&params);
1099 
1100   for (auto s : state) {
1101     auto trace = op->TraceString(*ctx, verbose);
1102   }
1103 }
1104 
1105 BENCHMARK(BM_ConcatInputRange);
1106 BENCHMARK(BM_SelectInputRange);
1107 BENCHMARK(BM_TraceString)->Arg(1)->Arg(0);
1108 
// Checks that GetAllRegisteredKernels() contains exactly one kernel for the
// "Test1" op registered earlier in this file, and that it is a CPU kernel.
TEST(RegisteredKernels, CanCallGetAllRegisteredKernels) {
  const auto kernel_list = GetAllRegisteredKernels();
  // Bind by reference: the generated kernel() accessor returns the repeated
  // field by const reference, and plain `auto` would deep-copy every
  // registered KernelDef.
  const auto& all_registered_kernels = kernel_list.kernel();
  auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };

  // Verify we can find the "Test1" op registered above
  auto test1_it = std::find_if(all_registered_kernels.begin(),
                               all_registered_kernels.end(), has_name_test1);
  ASSERT_NE(test1_it, all_registered_kernels.end());
  EXPECT_EQ(test1_it->device_type(), "CPU");

  // Verify there was just one kernel
  ++test1_it;
  EXPECT_EQ(
      std::find_if(test1_it, all_registered_kernels.end(), has_name_test1),
      all_registered_kernels.end());
}
1126 
1127 // Simple test just to check we can call LogAllRegisteredKernels
TEST(RegisteredKernels,CanLogAllRegisteredKernels)1128 TEST(RegisteredKernels, CanLogAllRegisteredKernels) {
1129   tensorflow::LogAllRegisteredKernels();
1130 }
1131 
// Filtering the registry by op name "Test1" yields exactly its CPU kernel.
TEST(RegisteredKernels, GetFilteredRegisteredKernels) {
  const auto kernel_list = GetFilteredRegisteredKernels(
      [](const KernelDef& kernel_def) { return kernel_def.op() == "Test1"; });
  ASSERT_EQ(kernel_list.kernel_size(), 1);

  const KernelDef& only_kernel = kernel_list.kernel(0);
  EXPECT_EQ(only_kernel.op(), "Test1");
  EXPECT_EQ(only_kernel.device_type(), "CPU");
}
1139 
// Looking up kernels for "Test1" by op name yields exactly its CPU kernel.
TEST(RegisteredKernels, GetRegisteredKernelsForOp) {
  const auto kernel_list = GetRegisteredKernelsForOp("Test1");
  ASSERT_EQ(kernel_list.kernel_size(), 1);

  const KernelDef& only_kernel = kernel_list.kernel(0);
  EXPECT_EQ(only_kernel.op(), "Test1");
  EXPECT_EQ(only_kernel.device_type(), "CPU");
}
1146 
// Test-only wrapper around TF_EXTRACT_KERNEL_NAME. That macro invokes the
// callback macro passed as its first argument with the extracted name first.
// The callback here simply expands to that name so it can be compared as a
// string.
#define EXTRACT_KERNEL_NAME_TO_STRING_IMPL(extracted_name, builder, ...) \
  extracted_name
#define EXTRACT_KERNEL_NAME_TO_STRING(kernel_builder) \
  TF_EXTRACT_KERNEL_NAME(EXTRACT_KERNEL_NAME_TO_STRING_IMPL, kernel_builder)
1152 
// The name extracted from a kernel-builder expression is usable in a constant
// expression and equals the name passed to Name().
TEST(RegisterKernelMacro, ExtractName) {
  static constexpr char const* kName = "Foo";
  static constexpr char const* kExtractedName =
      EXTRACT_KERNEL_NAME_TO_STRING(Name(kName).Label("Label"));
  EXPECT_STREQ(kName, kExtractedName);
}
1159 
1160 }  // namespace
1161 }  // namespace tensorflow
1162