• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_kernel.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 #include "tensorflow/core/framework/allocator.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/fake_input.h"
25 #include "tensorflow/core/framework/node_def_builder.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/tensor_shape.pb.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/protobuf.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/platform/test_benchmark.h"
37 #include "tensorflow/core/public/version.h"
38 
39 class DummyKernel : public tensorflow::OpKernel {
40  public:
DummyKernel(tensorflow::OpKernelConstruction * context)41   explicit DummyKernel(tensorflow::OpKernelConstruction* context)
42       : OpKernel(context) {}
Compute(tensorflow::OpKernelContext * context)43   void Compute(tensorflow::OpKernelContext* context) override {}
44 };
45 
46 // Test that registration works outside a namespace.
47 REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8");
48 REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),
49                         DummyKernel);
50 
51 namespace foo {
52 bool match_signature_ = false;
53 
54 // Test that registration works inside a different namespace.
55 class TestOp2 : public ::tensorflow::OpKernel {
56  public:
TestOp2(::tensorflow::OpKernelConstruction * context)57   explicit TestOp2(::tensorflow::OpKernelConstruction* context)
58       : OpKernel(context) {
59     ::tensorflow::Status status = context->MatchSignature(
60         {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32});
61     match_signature_ = status.ok();
62     context->SetStatus(status);
63   }
Compute(::tensorflow::OpKernelContext * context)64   void Compute(::tensorflow::OpKernelContext* context) override {}
65 };
66 
67 REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type");
68 REGISTER_KERNEL_BUILDER(Name("Test2")
69                             .Device(::tensorflow::DEVICE_GPU)
70                             .HostMemory("i")
71                             .HostMemory("o"),
72                         TestOp2);
73 }  // namespace foo
74 
75 namespace tensorflow {
76 
77 // Two operations with the same name but different devices.
78 REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type");
79 
80 class TestOp3Cpu : public tensorflow::OpKernel {
81  public:
TestOp3Cpu(OpKernelConstruction * context)82   explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)83   void Compute(OpKernelContext* context) override {}
84 };
85 
86 REGISTER_KERNEL_BUILDER(
87     Name("Test3").Device(DEVICE_CPU).TypeConstraint<int8>("T"), TestOp3Cpu);
88 
89 namespace {
90 
91 class TestOp3Gpu : public tensorflow::OpKernel {
92  public:
TestOp3Gpu(OpKernelConstruction * context)93   explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)94   void Compute(OpKernelContext* context) override {}
95 };
96 
97 REGISTER_KERNEL_BUILDER(
98     Name("Test3").Device(DEVICE_GPU).TypeConstraint<float>("T"), TestOp3Cpu);
99 
100 // An Op registered for both
101 REGISTER_OP("Test4").Input("i: float").Output("o: float");
102 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel);
103 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel);
104 
DeviceTypes()105 static std::vector<DeviceType> DeviceTypes() {
106   return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)};
107 }
108 
109 class OpKernelTest : public ::testing::Test {
110  public:
OpKernelTest()111   OpKernelTest() : device_(Env::Default()) {}
112 
113  protected:
CreateNodeDef(const string & op_type,const DataTypeVector & inputs)114   NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs) {
115     NodeDefBuilder builder(op_type + "-op", op_type);
116     for (DataType dt : inputs) {
117       builder.Input(FakeInput(dt));
118     }
119     NodeDef node_def;
120     TF_CHECK_OK(builder.Finalize(&node_def));
121     return node_def;
122   }
123 
ExpectEqual(const string & what,const DataTypeVector & expected,const DataTypeVector & observed)124   void ExpectEqual(const string& what, const DataTypeVector& expected,
125                    const DataTypeVector& observed) {
126     EXPECT_EQ(expected.size(), observed.size()) << what;
127     const size_t size = std::min(expected.size(), observed.size());
128     for (size_t i = 0; i < size; ++i) {
129       bool match = TypesCompatible(expected[i], observed[i]);
130       EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i]
131                          << ", observed: " << observed[i];
132     }
133   }
134 
ExpectSuccess(const string & op_type,DeviceType device_type,const DataTypeVector & inputs,const DataTypeVector & outputs)135   void ExpectSuccess(const string& op_type, DeviceType device_type,
136                      const DataTypeVector& inputs,
137                      const DataTypeVector& outputs) {
138     Status status;
139     std::unique_ptr<OpKernel> op(CreateOpKernel(
140         std::move(device_type), &device_, cpu_allocator(),
141         CreateNodeDef(op_type, inputs), TF_GRAPH_DEF_VERSION, &status));
142     EXPECT_TRUE(status.ok()) << status;
143     EXPECT_TRUE(op != nullptr);
144     if (op != nullptr) {
145       ExpectEqual("inputs", op->input_types(), inputs);
146       ExpectEqual("outputs", op->output_types(), outputs);
147     }
148   }
149 
ExpectFailure(const string & ascii_node_def,DeviceType device_type,error::Code code)150   void ExpectFailure(const string& ascii_node_def, DeviceType device_type,
151                      error::Code code) {
152     NodeDef node_def;
153     protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def);
154     Status status;
155     std::unique_ptr<OpKernel> op(
156         CreateOpKernel(std::move(device_type), &device_, cpu_allocator(),
157                        node_def, TF_GRAPH_DEF_VERSION, &status));
158     EXPECT_TRUE(op == nullptr);
159     EXPECT_FALSE(status.ok());
160     if (!status.ok()) {
161       LOG(INFO) << "Status message: " << status.error_message();
162       EXPECT_EQ(code, status.code());
163     }
164   }
165 
166  private:
167   DeviceBase device_;
168 };
169 
TEST_F(OpKernelTest,SuccessCpu)170 TEST_F(OpKernelTest, SuccessCpu) {
171   ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, {DT_UINT8});
172   ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32}, {DT_UINT8});
173 }
174 
TEST_F(OpKernelTest,SuccessGpu)175 TEST_F(OpKernelTest, SuccessGpu) {
176   foo::match_signature_ = false;
177   ExpectSuccess("Test2", DEVICE_GPU, {DT_INT32}, {DT_INT32});
178   EXPECT_TRUE(foo::match_signature_);
179 }
180 
TEST_F(OpKernelTest,SuccessBothCpuAndGpu)181 TEST_F(OpKernelTest, SuccessBothCpuAndGpu) {
182   ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, {});
183   ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, {});
184 }
185 
TEST_F(OpKernelTest,CpuTypeRegistered)186 TEST_F(OpKernelTest, CpuTypeRegistered) {
187   NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
188   DeviceTypeVector devs;
189   TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
190   EXPECT_EQ(1, devs.size());
191   EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
192 }
193 
TEST_F(OpKernelTest,CpuAndGpuTypeRegistered)194 TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
195   {
196     // Try a node def of an op that is registered for a specific type
197     // only on CPU.
198     NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8});
199     DeviceTypeVector devs;
200     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
201     EXPECT_EQ(1, devs.size());
202     EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
203   }
204   {
205     // Try a node def of an op that is registered for a specific type
206     // only on GPU.
207     NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT});
208     DeviceTypeVector devs;
209     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
210     EXPECT_EQ(1, devs.size());
211     EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
212   }
213   {
214     // Try a node def of an op that is only registered for other types.
215     NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING});
216     DeviceTypeVector devs;
217     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
218     EXPECT_EQ(0, devs.size());
219   }
220 
221   {
222     // Try a node def of an op that is registered for both.
223     NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT});
224     DeviceTypeVector devs;
225     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
226     EXPECT_EQ(2, devs.size());
227     EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
228     EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]);
229   }
230 }
231 
// Every (NodeDef, device) pair below has no matching kernel.
TEST_F(OpKernelTest, NotFound) {
  const std::vector<std::pair<string, DeviceType>> attempts = {
      // Something with that op type name exists, but only with a
      // different DeviceType.
      {CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}).DebugString(),
       DeviceType(DEVICE_GPU)},
      {CreateNodeDef("Test3", {DT_INT8, DT_INT8}).DebugString(),
       DeviceType(DEVICE_GPU)},
      {CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}).DebugString(),
       DeviceType(DEVICE_CPU)},
      // No kernel with that signature registered.
      {CreateNodeDef("Test3", {DT_INT32, DT_INT32}).DebugString(),
       DeviceType(DEVICE_GPU)},
      // Nothing with that op type name exists.
      {"name: 'NF' op: 'Testnotfound'", DeviceType(DEVICE_CPU)},
      {"name: 'NF' op: 'Testnotfound'", DeviceType(DEVICE_GPU)},
  };
  for (const auto& attempt : attempts) {
    ExpectFailure(attempt.first, attempt.second, error::NOT_FOUND);
  }
}

// Test1 declares two inputs; supplying none, then only one, is rejected.
TEST_F(OpKernelTest, TooFewInputs) {
  NodeDef def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  def.clear_input();
  ExpectFailure(def.DebugString(), DEVICE_CPU, error::INVALID_ARGUMENT);
  def.add_input("a");
  ExpectFailure(def.DebugString(), DEVICE_CPU, error::INVALID_ARGUMENT);
}

// A third input on a two-input op is rejected.
TEST_F(OpKernelTest, TooManyInputs) {
  NodeDef def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  def.add_input("c");
  ExpectFailure(def.DebugString(), DEVICE_CPU, error::INVALID_ARGUMENT);
}

// A float input violates TestOp2's int32-only signature check, which makes
// kernel creation fail and clears foo::match_signature_.
TEST_F(OpKernelTest, MatchSignatureFailes) {
  foo::match_signature_ = true;
  const string float_node = CreateNodeDef("Test2", {DT_FLOAT}).DebugString();
  ExpectFailure(float_node, DEVICE_GPU, error::INVALID_ARGUMENT);
  EXPECT_FALSE(foo::match_signature_);
}
275 
276 class DummyDevice : public DeviceBase {
277  public:
DummyDevice(Env * env,bool save)278   DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
RequiresRecordingAccessedTensors() const279   bool RequiresRecordingAccessedTensors() const override { return save_; }
GetAllocator(AllocatorAttributes)280   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
281     return cpu_allocator();
282   }
283 
284  private:
285   bool save_;
286 };
287 
// With tensor-access recording disabled, allocate_temp() leaves no recorded
// accesses behind.
TEST_F(OpKernelTest, SaveTempFalse) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = false;
  // Owned by smart pointers so nothing leaks if an assertion returns early.
  std::unique_ptr<DummyDevice> device(
      new DummyDevice(env, params.record_tensor_accesses));
  params.device = device.get();
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  // The context below needs a valid kernel, so stop if creation failed.
  ASSERT_TRUE(status.ok()) << status;
  params.op_kernel = op.get();
  std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));

  Tensor t;
  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));

  TensorReferenceVector referenced_tensors;
  ctx->retrieve_accessed_tensors(&referenced_tensors);
  EXPECT_EQ(0, referenced_tensors.size());
}
312 
// With tensor-access recording enabled, allocate_temp() records exactly one
// accessed tensor.
TEST_F(OpKernelTest, SaveTempTrue) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = true;
  // Owned by smart pointers so nothing leaks if an assertion returns early.
  std::unique_ptr<DummyDevice> device(
      new DummyDevice(env, params.record_tensor_accesses));
  params.device = device.get();
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  // The context below needs a valid kernel, so stop if creation failed.
  ASSERT_TRUE(status.ok()) << status;
  params.op_kernel = op.get();
  std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));

  Tensor t;
  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));

  TensorReferenceVector referenced_tensors;
  ctx->retrieve_accessed_tensors(&referenced_tensors);
  EXPECT_EQ(1, referenced_tensors.size());
  // Release the references handed back by retrieve_accessed_tensors().
  for (auto& ref : referenced_tensors) {
    ref.Unref();
  }
}
340 
// input_dtype() resolves input names from the op definition and rejects
// unknown names.
TEST_F(OpKernelTest, InputDtype) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = false;
  // Owned by smart pointers: the ASSERTs below may return early, which
  // previously leaked the context and the device.
  std::unique_ptr<DummyDevice> device(
      new DummyDevice(env, params.record_tensor_accesses));
  params.device = device.get();
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  // The context below needs a valid kernel, so stop if creation failed.
  ASSERT_TRUE(status.ok()) << status;
  params.op_kernel = op.get();
  Tensor a(DT_FLOAT, TensorShape({}));
  Tensor b(DT_INT32, TensorShape({}));
  Tensor c(DT_UINT8, TensorShape({}));
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&a), TensorValue(&b),
                                            TensorValue(&c)};
  params.inputs = &inputs;
  std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));

  DataType dtype;
  EXPECT_FALSE(ctx->input_dtype("non_existent_input", &dtype).ok());
  ASSERT_TRUE(ctx->input_dtype("a", &dtype).ok());
  EXPECT_EQ(dtype, DT_FLOAT);
  ASSERT_TRUE(ctx->input_dtype("b", &dtype).ok());
  EXPECT_EQ(dtype, DT_INT32);
}
370 
371 class OpKernelBuilderTest : public ::testing::Test {
372  protected:
373   // Each attr is described by a "name|type|value".
CreateNodeDef(const string & op_type,const std::vector<string> & attrs)374   NodeDef CreateNodeDef(const string& op_type,
375                         const std::vector<string>& attrs) {
376     NodeDef node_def;
377     node_def.set_name(op_type + "-op");
378     node_def.set_op(op_type);
379     for (const string& attr_desc : attrs) {
380       std::vector<string> parts = str_util::Split(attr_desc, '|');
381       CHECK_EQ(parts.size(), 3);
382       AttrValue attr_value;
383       CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc;
384       node_def.mutable_attr()->insert(
385           AttrValueMap::value_type(parts[0], attr_value));
386     }
387     return node_def;
388   }
389 
ExpectSuccess(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,DataTypeSlice input_types={})390   std::unique_ptr<OpKernel> ExpectSuccess(const string& op_type,
391                                           const DeviceType& device_type,
392                                           const std::vector<string>& attrs,
393                                           DataTypeSlice input_types = {}) {
394     Status status;
395     NodeDef def = CreateNodeDef(op_type, attrs);
396     for (size_t i = 0; i < input_types.size(); ++i) {
397       def.add_input("a:0");
398     }
399 
400     Env* env = Env::Default();
401     DeviceBase device(env);
402 
403     // Test CreateOpKernel()
404     std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
405                                                 cpu_allocator(), def,
406                                                 TF_GRAPH_DEF_VERSION, &status));
407     EXPECT_TRUE(status.ok()) << status;
408     EXPECT_TRUE(op != nullptr);
409     if (op != nullptr) {
410       EXPECT_EQ(input_types.size(), op->num_inputs());
411       EXPECT_EQ(0, op->num_outputs());
412     }
413 
414     // Test SupportedDeviceTypesForNode()
415     DeviceTypeVector devices;
416     TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
417     bool found = false;
418     for (const DeviceType& dt : devices) {
419       if (dt == device_type) {
420         found = true;
421       }
422     }
423     EXPECT_TRUE(found) << "Missing " << device_type << " from "
424                        << devices.size() << " devices.";
425 
426     // In case the caller wants to use the OpKernel
427     return op;
428   }
429 
ExpectFailure(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,error::Code code)430   void ExpectFailure(const string& op_type, const DeviceType& device_type,
431                      const std::vector<string>& attrs, error::Code code) {
432     Status status;
433     const NodeDef def = CreateNodeDef(op_type, attrs);
434     Env* env = Env::Default();
435     DeviceBase device(env);
436 
437     // Test CreateOpKernel().
438     std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
439                                                 cpu_allocator(), def,
440                                                 TF_GRAPH_DEF_VERSION, &status));
441     EXPECT_TRUE(op == nullptr);
442     EXPECT_FALSE(status.ok());
443     if (!status.ok()) {
444       LOG(INFO) << "Status message: " << status.error_message();
445       EXPECT_EQ(code, status.code());
446 
447       // Test SupportedDeviceTypesForNode().
448       DeviceTypeVector devices;
449       if (errors::IsNotFound(status)) {
450         TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
451         for (const DeviceType& dt : devices) {
452           EXPECT_NE(dt, device_type);
453         }
454       } else {
455         Status status2 =
456             SupportedDeviceTypesForNode(DeviceTypes(), def, &devices);
457         EXPECT_EQ(status.code(), status2.code());
458       }
459     }
460   }
461 
GetKernelClassName(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,DataTypeSlice input_types={})462   string GetKernelClassName(const string& op_type,
463                             const DeviceType& device_type,
464                             const std::vector<string>& attrs,
465                             DataTypeSlice input_types = {}) {
466     NodeDef def = CreateNodeDef(op_type, attrs);
467     for (size_t i = 0; i < input_types.size(); ++i) {
468       def.add_input("a:0");
469     }
470 
471     const KernelDef* kernel_def = nullptr;
472     string kernel_class_name;
473     const Status status =
474         FindKernelDef(device_type, def, &kernel_def, &kernel_class_name);
475     if (status.ok()) {
476       return kernel_class_name;
477     } else if (errors::IsNotFound(status)) {
478       return "not found";
479     } else {
480       return status.ToString();
481     }
482   }
483 };
484 
485 REGISTER_OP("BuildCPU");
486 REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);
487 
TEST_F(OpKernelBuilderTest,BuilderCPU)488 TEST_F(OpKernelBuilderTest, BuilderCPU) {
489   ExpectSuccess("BuildCPU", DEVICE_CPU, {});
490   EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {}));
491   ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
492   EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {}));
493 }
494 
495 REGISTER_OP("BuildGPU");
496 REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel);
497 
TEST_F(OpKernelBuilderTest,BuilderGPU)498 TEST_F(OpKernelBuilderTest, BuilderGPU) {
499   ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND);
500   ExpectSuccess("BuildGPU", DEVICE_GPU, {});
501 }
502 
503 REGISTER_OP("BuildBoth");
504 REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel);
505 REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel);
506 
TEST_F(OpKernelBuilderTest,BuilderBoth)507 TEST_F(OpKernelBuilderTest, BuilderBoth) {
508   ExpectSuccess("BuildBoth", DEVICE_CPU, {});
509   ExpectSuccess("BuildBoth", DEVICE_GPU, {});
510 }
511 
512 REGISTER_OP("BuildTypeAttr").Attr("T: type");
513 REGISTER_KERNEL_BUILDER(
514     Name("BuildTypeAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"),
515     DummyKernel);
516 
TEST_F(OpKernelBuilderTest,BuilderTypeAttr)517 TEST_F(OpKernelBuilderTest, BuilderTypeAttr) {
518   ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"});
519   ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"},
520                 error::NOT_FOUND);
521   ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
522   ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"},
523                 error::INVALID_ARGUMENT);
524 }
525 
526 REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)");
527 REGISTER_KERNEL_BUILDER(
528     Name("BuildTypeListAttr").Device(DEVICE_CPU).TypeConstraint<bool>("T"),
529     DummyKernel);
530 
TEST_F(OpKernelBuilderTest,BuilderTypeListAttr)531 TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
532   ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
533   EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
534                                               {"T|list(type)|[]"}));
535 
536   ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
537   EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
538                                               {"T|list(type)|[]"}));
539 
540   ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
541                 {"T|list(type)|[DT_BOOL, DT_BOOL]"});
542 
543   ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
544                 error::NOT_FOUND);
545   EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
546                                             {"T|list(type)|[DT_FLOAT]"}));
547 
548   ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
549   EXPECT_TRUE(
550       StringPiece(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}))
551           .contains("Invalid argument: "));
552 
553   ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
554                 error::INVALID_ARGUMENT);
555 }
556 
557 REGISTER_OP("DuplicateKernel");
558 REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
559                         DummyKernel);
560 REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
561                         DummyKernel);
562 
TEST_F(OpKernelBuilderTest,DuplicateKernel)563 TEST_F(OpKernelBuilderTest, DuplicateKernel) {
564   const NodeDef ndef = CreateNodeDef("DuplicateKernel", {});
565   DeviceTypeVector devs;
566   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
567   ASSERT_FALSE(status.ok());
568   EXPECT_TRUE(StringPiece(status.error_message())
569                   .contains("Multiple OpKernel registrations match NodeDef"));
570 
571   ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
572 }
573 
574 REGISTER_OP("DuplicateKernelForT").Attr("T: type");
575 REGISTER_KERNEL_BUILDER(
576     Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
577     DummyKernel);
578 REGISTER_KERNEL_BUILDER(
579     Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
580     DummyKernel);
581 
TEST_F(OpKernelBuilderTest,DuplicateKernelForT)582 TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
583   const NodeDef ndef =
584       CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"});
585   DeviceTypeVector devs;
586   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
587   ASSERT_FALSE(status.ok());
588   EXPECT_TRUE(StringPiece(status.error_message())
589                   .contains("Multiple OpKernel registrations match NodeDef"));
590 
591   ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"},
592                 error::INVALID_ARGUMENT);
593   ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"},
594                 error::NOT_FOUND);
595 }
596 
597 REGISTER_OP("BadConstraint").Attr("dtype: type");
598 REGISTER_KERNEL_BUILDER(Name("BadConstraint")
599                             .Device(DEVICE_CPU)
600                             // Mistake: "T" should be "dtype".
601                             .TypeConstraint<float>("T"),
602                         DummyKernel);
603 
TEST_F(OpKernelBuilderTest,BadConstraint)604 TEST_F(OpKernelBuilderTest, BadConstraint) {
605   const NodeDef ndef = CreateNodeDef("BadConstraint", {});
606   DeviceTypeVector devs;
607   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
608   ASSERT_FALSE(status.ok());
609   EXPECT_TRUE(StringPiece(status.error_message())
610                   .contains("OpKernel 'BadConstraint' has constraint on attr "
611                             "'T' not in NodeDef"));
612 
613   ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"},
614                 error::INVALID_ARGUMENT);
615 }
616 
617 REGISTER_OP("ListOut").Output("a: int32").Output("b: T").Attr("T: list(type)");
618 REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU),
619                         DummyKernel);
620 
TEST_F(OpKernelBuilderTest,OpOutputList)621 TEST_F(OpKernelBuilderTest, OpOutputList) {
622   Env* env = Env::Default();
623   OpKernelContext::Params params;
624   params.record_tensor_accesses = false;
625   std::unique_ptr<DummyDevice> device(
626       new DummyDevice(env, params.record_tensor_accesses));
627   params.device = device.get();
628   Status status;
629   std::unique_ptr<OpKernel> op(CreateOpKernel(
630       DEVICE_CPU, params.device, cpu_allocator(),
631       CreateNodeDef("ListOut", {"T|list(type)|[DT_FLOAT, DT_INT32]"}),
632       TF_GRAPH_DEF_VERSION, &status));
633   EXPECT_TRUE(status.ok()) << status.ToString();
634   params.op_kernel = op.get();
635   gtl::InlinedVector<TensorValue, 4> inputs{};
636   params.inputs = &inputs;
637   std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));
638 
639   EXPECT_EQ(DT_INT32, ctx->expected_output_dtype(0));
640   OpOutputList out_list;
641   EXPECT_FALSE(ctx->output_list("non_existent_output", &out_list).ok());
642   ASSERT_TRUE(ctx->output_list("b", &out_list).ok());
643   EXPECT_EQ(DT_FLOAT, out_list.expected_output_dtype(0));
644   EXPECT_EQ(DT_INT32, out_list.expected_output_dtype(1));
645 }
646 
647 class GetAttrKernel : public ::tensorflow::OpKernel {
648  public:
GetAttrKernel(OpKernelConstruction * context)649   explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) {
650     string attr_name;
651     OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name));
652 
653     status.emplace_back("s", context->GetAttr(attr_name, &s));
654     status.emplace_back("s_list", context->GetAttr(attr_name, &s_list));
655     status.emplace_back("i", context->GetAttr(attr_name, &i));
656     status.emplace_back("i_list", context->GetAttr(attr_name, &i_list));
657     status.emplace_back("i32", context->GetAttr(attr_name, &i32));
658     status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list));
659     status.emplace_back("f", context->GetAttr(attr_name, &f));
660     status.emplace_back("f_list", context->GetAttr(attr_name, &f_list));
661     status.emplace_back("b", context->GetAttr(attr_name, &b));
662     status.emplace_back("b_list", context->GetAttr(attr_name, &b_list));
663     status.emplace_back("type", context->GetAttr(attr_name, &type));
664     status.emplace_back("type_list", context->GetAttr(attr_name, &type_list));
665     status.emplace_back("type_vector",
666                         context->GetAttr(attr_name, &type_vector));
667     status.emplace_back("shape_proto",
668                         context->GetAttr(attr_name, &shape_proto));
669     status.emplace_back("shape_proto_list",
670                         context->GetAttr(attr_name, &shape_proto_list));
671     status.emplace_back("shape", context->GetAttr(attr_name, &shape));
672     status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list));
673   }
Compute(::tensorflow::OpKernelContext * context)674   void Compute(::tensorflow::OpKernelContext* context) override {}
675 
ExpectOk(std::initializer_list<string> keys)676   void ExpectOk(std::initializer_list<string> keys) {
677     for (const auto& key_status : status) {
678       // Only the status for keys in "keys" should be ok().
679       bool in_keys = false;
680       for (const string& key : keys) {
681         if (key_status.first == key) {
682           in_keys = true;
683         }
684       }
685       EXPECT_EQ(in_keys, key_status.second.ok())
686           << "key_status: " << key_status.first << ", " << key_status.second;
687     }
688   }
689 
690   string s;
691   std::vector<string> s_list;
692   int64 i;
693   std::vector<int64> i_list;
694   int32 i32;
695   std::vector<int32> i32_list;
696   float f;
697   std::vector<float> f_list;
698   bool b;
699   std::vector<bool> b_list;
700   DataType type;
701   std::vector<DataType> type_list;
702   DataTypeVector type_vector;
703   TensorShapeProto shape_proto;
704   std::vector<TensorShapeProto> shape_proto_list;
705   TensorShape shape;
706   std::vector<TensorShape> shape_list;
707   std::vector<std::pair<string, Status>> status;
708 };
709 
710 class GetAttrTest : public OpKernelBuilderTest {};
711 
712 REGISTER_OP("GetAttrStringList")
713     .Attr("attr_name: string")
714     .Attr("a: list(string)");
715 REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU),
716                         GetAttrKernel);
717 
TEST_F(GetAttrTest,StringList)718 TEST_F(GetAttrTest, StringList) {
719   std::unique_ptr<OpKernel> op_kernel =
720       ExpectSuccess("GetAttrStringList", DEVICE_CPU,
721                     {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"});
722   auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
723   get_attr_kernel->ExpectOk({"s_list"});
724   EXPECT_EQ(std::vector<string>({"foo", "bar"}), get_attr_kernel->s_list);
725 
726   op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU,
727                             {"attr_name|string|'b'", "a|list(string)|['baz']"});
728   get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
729   get_attr_kernel->ExpectOk({});
730   EXPECT_TRUE(get_attr_kernel->s_list.empty());
731 }
732 
733 REGISTER_OP("GetAttrInt")
734     .Attr("attr_name: string")
735     .Attr("a: int")
736     .Attr("b: list(int)");
737 REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel);
738 
TEST_F(GetAttrTest,Int)739 TEST_F(GetAttrTest, Int) {
740   std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
741       "GetAttrInt", DEVICE_CPU,
742       {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
743   auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
744   get_attr_kernel->ExpectOk({"i", "i32"});
745   EXPECT_EQ(35, get_attr_kernel->i);
746   EXPECT_EQ(35, get_attr_kernel->i32);
747 
748   op_kernel = ExpectSuccess(
749       "GetAttrInt", DEVICE_CPU,
750       {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
751   get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
752   get_attr_kernel->ExpectOk({"i_list", "i32_list"});
753   EXPECT_EQ(std::vector<int64>({-1, 2, -4}), get_attr_kernel->i_list);
754   EXPECT_EQ(std::vector<int32>({-1, 2, -4}), get_attr_kernel->i32_list);
755 
756   // 8589934592 == 2^33, too big to fit in an int32
757   op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
758                             {"attr_name|string|'a'", "a|int|8589934592",
759                              "b|list(int)|[-8589934592]"});
760   get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
761   get_attr_kernel->ExpectOk({"i"});  // no i32
762   EXPECT_EQ(8589934592ll, get_attr_kernel->i);
763   for (const auto& key_status : get_attr_kernel->status) {
764     if (key_status.first == "i32") {
765       EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
766       EXPECT_EQ("Attr a has value 8589934592 out of range for an int32",
767                 key_status.second.error_message());
768     }
769   }
770 
771   op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
772                             {"attr_name|string|'b'", "a|int|8589934592",
773                              "b|list(int)|[-8589934592]"});
774   get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
775   get_attr_kernel->ExpectOk({"i_list"});  // no i32_list
776   EXPECT_EQ(std::vector<int64>({-8589934592ll}), get_attr_kernel->i_list);
777   for (const auto& key_status : get_attr_kernel->status) {
778     if (key_status.first == "i32_list") {
779       EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
780       EXPECT_EQ("Attr b has value -8589934592 out of range for an int32",
781                 key_status.second.error_message());
782     }
783   }
784 }
785 
786 REGISTER_OP("GetAttrShape")
787     .Attr("attr_name: string")
788     .Attr("a: shape")
789     .Attr("b: list(shape)");
790 REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel);
791 
TEST_F(GetAttrTest,Shape)792 TEST_F(GetAttrTest, Shape) {
793   std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
794       "GetAttrShape", DEVICE_CPU,
795       {"attr_name|string|'a'", "a|shape|{ dim { size: 3 } }",
796        "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
797   auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
798   get_attr_kernel->ExpectOk({"shape", "shape_proto"});
799   EXPECT_EQ(get_attr_kernel->shape_proto.ShortDebugString(), "dim { size: 3 }");
800   EXPECT_EQ("[3]", get_attr_kernel->shape.DebugString());
801 
802   op_kernel = ExpectSuccess(
803       "GetAttrShape", DEVICE_CPU,
804       {"attr_name|string|'b'", "a|shape|{ dim { size: 3 } }",
805        "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
806   get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
807   get_attr_kernel->ExpectOk({"shape_list", "shape_proto_list"});
808   ASSERT_EQ(2, get_attr_kernel->shape_proto_list.size());
809   EXPECT_EQ(get_attr_kernel->shape_proto_list[0].ShortDebugString(),
810             "dim { size: 2 }");
811   EXPECT_EQ(get_attr_kernel->shape_proto_list[1].ShortDebugString(),
812             "dim { size: 4 }");
813   ASSERT_EQ(2, get_attr_kernel->shape_list.size());
814   EXPECT_EQ("[2]", get_attr_kernel->shape_list[0].DebugString());
815   EXPECT_EQ("[4]", get_attr_kernel->shape_list[1].DebugString());
816 }
817 
818 REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type");
819 REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel);
820 
TEST_F(GetAttrTest,Type)821 TEST_F(GetAttrTest, Type) {
822   std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
823       "GetAttrType", DEVICE_CPU, {"attr_name|string|'a'", "a|type|DT_FLOAT"});
824   auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
825   get_attr_kernel->ExpectOk({"type"});
826   EXPECT_EQ(DT_FLOAT, get_attr_kernel->type);
827 }
828 
829 REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)");
830 REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU),
831                         GetAttrKernel);
832 
TEST_F(GetAttrTest,TypeList)833 TEST_F(GetAttrTest, TypeList) {
834   std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
835       "GetAttrTypeList", DEVICE_CPU,
836       {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"});
837   auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
838 
839   get_attr_kernel->ExpectOk({"type_list", "type_vector"});
840   ASSERT_EQ(2, get_attr_kernel->type_list.size());
841   EXPECT_EQ(DT_INT32, get_attr_kernel->type_list[0]);
842   EXPECT_EQ(DT_BOOL, get_attr_kernel->type_list[1]);
843   ASSERT_EQ(2, get_attr_kernel->type_vector.size());
844   EXPECT_EQ(DT_INT32, get_attr_kernel->type_vector[0]);
845   EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]);
846 }
847 
848 class BaseKernel : public ::tensorflow::OpKernel {
849  public:
BaseKernel(OpKernelConstruction * context)850   explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {}
Compute(::tensorflow::OpKernelContext * context)851   void Compute(::tensorflow::OpKernelContext* context) override {}
852   virtual int Which() const = 0;
853 };
854 
855 template <int WHICH>
856 class LabeledKernel : public BaseKernel {
857  public:
858   using BaseKernel::BaseKernel;
Which() const859   int Which() const override { return WHICH; }
860 };
861 
862 class LabelTest : public OpKernelBuilderTest {};
863 
864 REGISTER_OP("LabeledKernel");
865 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU),
866                         LabeledKernel<0>);
867 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"),
868                         LabeledKernel<1>);
869 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
870                         LabeledKernel<2>);
871 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
872                         LabeledKernel<3>);
873 
// Without a _kernel attr, the unlabeled LabeledKernel<0> registration wins.
TEST_F(LabelTest, Default) {
  std::unique_ptr<OpKernel> kernel_holder =
      ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
  const BaseKernel* kernel = static_cast<BaseKernel*>(kernel_holder.get());
  EXPECT_EQ(0, kernel->Which());

  const auto class_name = GetKernelClassName("LabeledKernel", DEVICE_CPU, {});
  EXPECT_EQ("LabeledKernel<0>", class_name);
}
883 
// Setting _kernel = "one" selects the LabeledKernel<1> registration.
TEST_F(LabelTest, Specified) {
  std::unique_ptr<OpKernel> kernel_holder =
      ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
  const BaseKernel* kernel = static_cast<BaseKernel*>(kernel_holder.get());
  EXPECT_EQ(1, kernel->Which());

  const auto class_name = GetKernelClassName("LabeledKernel", DEVICE_CPU,
                                             {"_kernel|string|'one'"});
  EXPECT_EQ("LabeledKernel<1>", class_name);
}
892 
// Two kernels are registered under the "dupe" label, so asking for it must be
// rejected with INVALID_ARGUMENT.
TEST_F(LabelTest, Duplicate) {
  ExpectFailure("LabeledKernel", DEVICE_CPU,
                {"_kernel|string|'dupe'"},
                error::INVALID_ARGUMENT);
}
897 
BM_InputRangeHelper(int iters,const NodeDef & node_def,const char * input_name,int expected_start,int expected_stop)898 void BM_InputRangeHelper(int iters, const NodeDef& node_def,
899                          const char* input_name, int expected_start,
900                          int expected_stop) {
901   Status status;
902   std::unique_ptr<DummyDevice> device(new DummyDevice(Env::Default(), false));
903 
904   std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
905                                               cpu_allocator(), node_def,
906                                               TF_GRAPH_DEF_VERSION, &status));
907   TF_CHECK_OK(status);
908 
909   testing::StartTiming();
910   for (int i = 0; i < iters; ++i) {
911     int start;
912     int stop;
913     TF_CHECK_OK(op->InputRange(input_name, &start, &stop));
914     EXPECT_EQ(expected_start, start);
915     EXPECT_EQ(expected_stop, stop);
916   }
917   testing::StopTiming();
918 }
919 
920 REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel);
921 REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel);
922 
BM_ConcatInputRange(int iters)923 void BM_ConcatInputRange(int iters) {
924   testing::StopTiming();
925 
926   // Create a ConcatV2 NodeDef with 4 inputs (plus the axis).
927   NodeDef node_def;
928   node_def.set_name("concat-op");
929   node_def.set_op("ConcatV2");
930   AttrValue attr_N;
931   attr_N.set_i(4);
932   AttrValue attr_T;
933   attr_T.set_type(DT_FLOAT);
934   AttrValue attr_Tidx;
935   attr_Tidx.set_type(DT_INT32);
936   node_def.mutable_attr()->insert({"N", attr_N});
937   node_def.mutable_attr()->insert({"T", attr_T});
938   node_def.mutable_attr()->insert({"Tidx", attr_Tidx});
939   for (size_t i = 0; i < 5; ++i) {
940     node_def.add_input(strings::StrCat("a:", i));
941   }
942 
943   BM_InputRangeHelper(iters, node_def, "values", 0, 4);
944 }
945 
BM_SelectInputRange(int iters)946 void BM_SelectInputRange(int iters) {
947   testing::StopTiming();
948 
949   // Create a Select NodeDef with 3 inputs.
950   NodeDef node_def;
951   node_def.set_name("select-op");
952   node_def.set_op("Select");
953   AttrValue attr_T;
954   attr_T.set_type(DT_FLOAT);
955   node_def.mutable_attr()->insert({"T", attr_T});
956   for (size_t i = 0; i < 3; ++i) {
957     node_def.add_input(strings::StrCat("a:", i));
958   }
959 
960   BM_InputRangeHelper(iters, node_def, "condition", 0, 1);
961 }
962 
963 BENCHMARK(BM_ConcatInputRange);
964 BENCHMARK(BM_SelectInputRange);
965 
966 }  // namespace
967 }  // namespace tensorflow
968