1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_kernel.h"
17
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 #include "tensorflow/core/framework/allocator.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/fake_input.h"
25 #include "tensorflow/core/framework/node_def_builder.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/tensor_shape.pb.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/protobuf.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/platform/test_benchmark.h"
37 #include "tensorflow/core/public/version.h"
38
39 class DummyKernel : public tensorflow::OpKernel {
40 public:
DummyKernel(tensorflow::OpKernelConstruction * context)41 explicit DummyKernel(tensorflow::OpKernelConstruction* context)
42 : OpKernel(context) {}
Compute(tensorflow::OpKernelContext * context)43 void Compute(tensorflow::OpKernelContext* context) override {}
44 };
45
46 // Test that registration works outside a namespace.
47 REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8");
48 REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),
49 DummyKernel);
50
51 namespace foo {
52 bool match_signature_ = false;
53
54 // Test that registration works inside a different namespace.
55 class TestOp2 : public ::tensorflow::OpKernel {
56 public:
TestOp2(::tensorflow::OpKernelConstruction * context)57 explicit TestOp2(::tensorflow::OpKernelConstruction* context)
58 : OpKernel(context) {
59 ::tensorflow::Status status = context->MatchSignature(
60 {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32});
61 match_signature_ = status.ok();
62 context->SetStatus(status);
63 }
Compute(::tensorflow::OpKernelContext * context)64 void Compute(::tensorflow::OpKernelContext* context) override {}
65 };
66
67 REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type");
68 REGISTER_KERNEL_BUILDER(Name("Test2")
69 .Device(::tensorflow::DEVICE_GPU)
70 .HostMemory("i")
71 .HostMemory("o"),
72 TestOp2);
73 } // namespace foo
74
75 namespace tensorflow {
76
77 // Two operations with the same name but different devices.
78 REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type");
79
80 class TestOp3Cpu : public tensorflow::OpKernel {
81 public:
TestOp3Cpu(OpKernelConstruction * context)82 explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)83 void Compute(OpKernelContext* context) override {}
84 };
85
86 REGISTER_KERNEL_BUILDER(
87 Name("Test3").Device(DEVICE_CPU).TypeConstraint<int8>("T"), TestOp3Cpu);
88
89 namespace {
90
91 class TestOp3Gpu : public tensorflow::OpKernel {
92 public:
TestOp3Gpu(OpKernelConstruction * context)93 explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)94 void Compute(OpKernelContext* context) override {}
95 };
96
97 REGISTER_KERNEL_BUILDER(
98 Name("Test3").Device(DEVICE_GPU).TypeConstraint<float>("T"), TestOp3Cpu);
99
100 // An Op registered for both
101 REGISTER_OP("Test4").Input("i: float").Output("o: float");
102 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel);
103 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel);
104
DeviceTypes()105 static std::vector<DeviceType> DeviceTypes() {
106 return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)};
107 }
108
109 class OpKernelTest : public ::testing::Test {
110 public:
OpKernelTest()111 OpKernelTest() : device_(Env::Default()) {}
112
113 protected:
CreateNodeDef(const string & op_type,const DataTypeVector & inputs)114 NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs) {
115 NodeDefBuilder builder(op_type + "-op", op_type);
116 for (DataType dt : inputs) {
117 builder.Input(FakeInput(dt));
118 }
119 NodeDef node_def;
120 TF_CHECK_OK(builder.Finalize(&node_def));
121 return node_def;
122 }
123
ExpectEqual(const string & what,const DataTypeVector & expected,const DataTypeVector & observed)124 void ExpectEqual(const string& what, const DataTypeVector& expected,
125 const DataTypeVector& observed) {
126 EXPECT_EQ(expected.size(), observed.size()) << what;
127 const size_t size = std::min(expected.size(), observed.size());
128 for (size_t i = 0; i < size; ++i) {
129 bool match = TypesCompatible(expected[i], observed[i]);
130 EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i]
131 << ", observed: " << observed[i];
132 }
133 }
134
ExpectSuccess(const string & op_type,DeviceType device_type,const DataTypeVector & inputs,const DataTypeVector & outputs)135 void ExpectSuccess(const string& op_type, DeviceType device_type,
136 const DataTypeVector& inputs,
137 const DataTypeVector& outputs) {
138 Status status;
139 std::unique_ptr<OpKernel> op(CreateOpKernel(
140 std::move(device_type), &device_, cpu_allocator(),
141 CreateNodeDef(op_type, inputs), TF_GRAPH_DEF_VERSION, &status));
142 EXPECT_TRUE(status.ok()) << status;
143 EXPECT_TRUE(op != nullptr);
144 if (op != nullptr) {
145 ExpectEqual("inputs", op->input_types(), inputs);
146 ExpectEqual("outputs", op->output_types(), outputs);
147 }
148 }
149
ExpectFailure(const string & ascii_node_def,DeviceType device_type,error::Code code)150 void ExpectFailure(const string& ascii_node_def, DeviceType device_type,
151 error::Code code) {
152 NodeDef node_def;
153 protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def);
154 Status status;
155 std::unique_ptr<OpKernel> op(
156 CreateOpKernel(std::move(device_type), &device_, cpu_allocator(),
157 node_def, TF_GRAPH_DEF_VERSION, &status));
158 EXPECT_TRUE(op == nullptr);
159 EXPECT_FALSE(status.ok());
160 if (!status.ok()) {
161 LOG(INFO) << "Status message: " << status.error_message();
162 EXPECT_EQ(code, status.code());
163 }
164 }
165
166 private:
167 DeviceBase device_;
168 };
169
// Test1 builds on CPU. A ref-typed float is also accepted for input `a`.
TEST_F(OpKernelTest, SuccessCpu) {
  const DataTypeVector expected_outputs = {DT_UINT8};
  ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, expected_outputs);
  ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32},
                expected_outputs);
}
174
// Test2 on GPU with an int32 -> int32 signature passes TestOp2's
// MatchSignature() check.
TEST_F(OpKernelTest, SuccessGpu) {
  const DataTypeVector int32_only = {DT_INT32};
  foo::match_signature_ = false;  // Only the constructor may set this.
  ExpectSuccess("Test2", DEVICE_GPU, int32_only, int32_only);
  EXPECT_TRUE(foo::match_signature_);
}
180
// Test3 resolves on each device for the dtype that device's kernel is
// constrained to. The op has no outputs.
TEST_F(OpKernelTest, SuccessBothCpuAndGpu) {
  const DataTypeVector no_outputs;
  ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, no_outputs);
  ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, no_outputs);
}
185
// Test1 only has a CPU kernel, so CPU is the only supported device.
TEST_F(OpKernelTest, CpuTypeRegistered) {
  NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  DeviceTypeVector devs;
  TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
  // ASSERT, not EXPECT: devs[0] below must not be read from an empty vector.
  ASSERT_EQ(1, devs.size());
  EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
}
193
// SupportedDeviceTypesForNode() honors per-device type constraints. Sizes are
// ASSERTed before indexing so a short result cannot cause out-of-bounds reads.
TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
  {
    // Try a node def of an op that is registered for a specific type
    // only on CPU.
    NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8});
    DeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    ASSERT_EQ(1, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
  }
  {
    // Try a node def of an op that is registered for a specific type
    // only on GPU.
    NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT});
    DeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    ASSERT_EQ(1, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
  }
  {
    // Try a node def of an op that is only registered for other types.
    NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING});
    DeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(0, devs.size());
  }

  {
    // Try a node def of an op that is registered for both. The result follows
    // the order of DeviceTypes() (GPU first).
    NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT});
    DeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    ASSERT_EQ(2, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]);
  }
}
231
// Every flavor of "no matching kernel" is reported as NOT_FOUND.
TEST_F(OpKernelTest, NotFound) {
  const error::Code kNotFound = error::NOT_FOUND;

  // The op type exists, but its kernels are registered for another device.
  ExpectFailure(CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}).DebugString(),
                DEVICE_GPU, kNotFound);
  ExpectFailure(CreateNodeDef("Test3", {DT_INT8, DT_INT8}).DebugString(),
                DEVICE_GPU, kNotFound);
  ExpectFailure(CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}).DebugString(),
                DEVICE_CPU, kNotFound);

  // The device is right, but no kernel accepts these types.
  ExpectFailure(CreateNodeDef("Test3", {DT_INT32, DT_INT32}).DebugString(),
                DEVICE_GPU, kNotFound);

  // No op with this type name exists at all.
  const string unknown_op = "name: 'NF' op: 'Testnotfound'";
  ExpectFailure(unknown_op, DEVICE_CPU, kNotFound);
  ExpectFailure(unknown_op, DEVICE_GPU, kNotFound);
}
251
// Test1 declares two inputs; supplying zero or one is INVALID_ARGUMENT.
TEST_F(OpKernelTest, TooFewInputs) {
  NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});

  node_def.clear_input();  // No inputs at all.
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, error::INVALID_ARGUMENT);

  node_def.add_input("a");  // Only one of the two.
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, error::INVALID_ARGUMENT);
}
260
// A third input on the two-input Test1 op is INVALID_ARGUMENT.
TEST_F(OpKernelTest, TooManyInputs) {
  NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  node_def.add_input("c");
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, error::INVALID_ARGUMENT);
}
267
// A float Test2 node does not match TestOp2's int32 -> int32 signature. The
// constructor's failure surfaces as INVALID_ARGUMENT and clears the flag.
TEST_F(OpKernelTest, MatchSignatureFailes) {
  foo::match_signature_ = true;  // The failing constructor must reset this.
  const NodeDef float_node = CreateNodeDef("Test2", {DT_FLOAT});
  ExpectFailure(float_node.DebugString(), DEVICE_GPU, error::INVALID_ARGUMENT);
  EXPECT_FALSE(foo::match_signature_);
}
275
276 class DummyDevice : public DeviceBase {
277 public:
DummyDevice(Env * env,bool save)278 DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
RequiresRecordingAccessedTensors() const279 bool RequiresRecordingAccessedTensors() const override { return save_; }
GetAllocator(AllocatorAttributes)280 Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
281 return cpu_allocator();
282 }
283
284 private:
285 bool save_;
286 };
287
// With access recording disabled, allocate_temp() leaves no tensor references
// for retrieve_accessed_tensors() to return.
TEST_F(OpKernelTest, SaveTempFalse) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = false;
  // Owned by unique_ptr so cleanup also happens when an ASSERT returns early.
  std::unique_ptr<DummyDevice> device(
      new DummyDevice(env, params.record_tensor_accesses));
  params.device = device.get();
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  // Do not build a context around a kernel that failed to construct.
  TF_ASSERT_OK(status);
  ASSERT_TRUE(op != nullptr);
  params.op_kernel = op.get();
  // Declared last so it is destroyed before the kernel and device it uses.
  std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));

  Tensor t;
  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));

  TensorReferenceVector referenced_tensors;
  ctx->retrieve_accessed_tensors(&referenced_tensors);
  EXPECT_EQ(0, referenced_tensors.size());
}
312
// With access recording enabled, the single allocate_temp() call is reported
// as exactly one referenced tensor, which the test then releases.
TEST_F(OpKernelTest, SaveTempTrue) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = true;
  // Owned by unique_ptr so cleanup also happens when an ASSERT returns early.
  std::unique_ptr<DummyDevice> device(
      new DummyDevice(env, params.record_tensor_accesses));
  params.device = device.get();
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  // Do not build a context around a kernel that failed to construct.
  TF_ASSERT_OK(status);
  ASSERT_TRUE(op != nullptr);
  params.op_kernel = op.get();
  // Declared last so it is destroyed before the kernel and device it uses.
  std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));

  Tensor t;
  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));

  TensorReferenceVector referenced_tensors;
  ctx->retrieve_accessed_tensors(&referenced_tensors);
  EXPECT_EQ(1, referenced_tensors.size());
  // Drop the references handed back by retrieve_accessed_tensors().
  for (auto& ref : referenced_tensors) {
    ref.Unref();
  }
}
340
// input_dtype() maps Test1's declared input names to their dtypes and rejects
// names the op does not declare.
TEST_F(OpKernelTest, InputDtype) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = false;
  // Owned by unique_ptr so the device and context are released even when one
  // of the ASSERTs below returns early.
  std::unique_ptr<DummyDevice> device(
      new DummyDevice(env, params.record_tensor_accesses));
  params.device = device.get();
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  // Do not build a context around a kernel that failed to construct.
  TF_ASSERT_OK(status);
  ASSERT_TRUE(op != nullptr);
  params.op_kernel = op.get();
  Tensor a(DT_FLOAT, TensorShape({}));
  Tensor b(DT_INT32, TensorShape({}));
  Tensor c(DT_UINT8, TensorShape({}));
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&a), TensorValue(&b),
                                            TensorValue(&c)};
  params.inputs = &inputs;
  // Declared after the inputs so it is destroyed before them.
  std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));

  DataType dtype;
  EXPECT_FALSE(ctx->input_dtype("non_existent_input", &dtype).ok());
  ASSERT_TRUE(ctx->input_dtype("a", &dtype).ok());
  EXPECT_EQ(dtype, DT_FLOAT);
  ASSERT_TRUE(ctx->input_dtype("b", &dtype).ok());
  EXPECT_EQ(dtype, DT_INT32);
}
370
371 class OpKernelBuilderTest : public ::testing::Test {
372 protected:
373 // Each attr is described by a "name|type|value".
CreateNodeDef(const string & op_type,const std::vector<string> & attrs)374 NodeDef CreateNodeDef(const string& op_type,
375 const std::vector<string>& attrs) {
376 NodeDef node_def;
377 node_def.set_name(op_type + "-op");
378 node_def.set_op(op_type);
379 for (const string& attr_desc : attrs) {
380 std::vector<string> parts = str_util::Split(attr_desc, '|');
381 CHECK_EQ(parts.size(), 3);
382 AttrValue attr_value;
383 CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc;
384 node_def.mutable_attr()->insert(
385 AttrValueMap::value_type(parts[0], attr_value));
386 }
387 return node_def;
388 }
389
ExpectSuccess(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,DataTypeSlice input_types={})390 std::unique_ptr<OpKernel> ExpectSuccess(const string& op_type,
391 const DeviceType& device_type,
392 const std::vector<string>& attrs,
393 DataTypeSlice input_types = {}) {
394 Status status;
395 NodeDef def = CreateNodeDef(op_type, attrs);
396 for (size_t i = 0; i < input_types.size(); ++i) {
397 def.add_input("a:0");
398 }
399
400 Env* env = Env::Default();
401 DeviceBase device(env);
402
403 // Test CreateOpKernel()
404 std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
405 cpu_allocator(), def,
406 TF_GRAPH_DEF_VERSION, &status));
407 EXPECT_TRUE(status.ok()) << status;
408 EXPECT_TRUE(op != nullptr);
409 if (op != nullptr) {
410 EXPECT_EQ(input_types.size(), op->num_inputs());
411 EXPECT_EQ(0, op->num_outputs());
412 }
413
414 // Test SupportedDeviceTypesForNode()
415 DeviceTypeVector devices;
416 TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
417 bool found = false;
418 for (const DeviceType& dt : devices) {
419 if (dt == device_type) {
420 found = true;
421 }
422 }
423 EXPECT_TRUE(found) << "Missing " << device_type << " from "
424 << devices.size() << " devices.";
425
426 // In case the caller wants to use the OpKernel
427 return op;
428 }
429
ExpectFailure(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,error::Code code)430 void ExpectFailure(const string& op_type, const DeviceType& device_type,
431 const std::vector<string>& attrs, error::Code code) {
432 Status status;
433 const NodeDef def = CreateNodeDef(op_type, attrs);
434 Env* env = Env::Default();
435 DeviceBase device(env);
436
437 // Test CreateOpKernel().
438 std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
439 cpu_allocator(), def,
440 TF_GRAPH_DEF_VERSION, &status));
441 EXPECT_TRUE(op == nullptr);
442 EXPECT_FALSE(status.ok());
443 if (!status.ok()) {
444 LOG(INFO) << "Status message: " << status.error_message();
445 EXPECT_EQ(code, status.code());
446
447 // Test SupportedDeviceTypesForNode().
448 DeviceTypeVector devices;
449 if (errors::IsNotFound(status)) {
450 TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
451 for (const DeviceType& dt : devices) {
452 EXPECT_NE(dt, device_type);
453 }
454 } else {
455 Status status2 =
456 SupportedDeviceTypesForNode(DeviceTypes(), def, &devices);
457 EXPECT_EQ(status.code(), status2.code());
458 }
459 }
460 }
461
GetKernelClassName(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,DataTypeSlice input_types={})462 string GetKernelClassName(const string& op_type,
463 const DeviceType& device_type,
464 const std::vector<string>& attrs,
465 DataTypeSlice input_types = {}) {
466 NodeDef def = CreateNodeDef(op_type, attrs);
467 for (size_t i = 0; i < input_types.size(); ++i) {
468 def.add_input("a:0");
469 }
470
471 const KernelDef* kernel_def = nullptr;
472 string kernel_class_name;
473 const Status status =
474 FindKernelDef(device_type, def, &kernel_def, &kernel_class_name);
475 if (status.ok()) {
476 return kernel_class_name;
477 } else if (errors::IsNotFound(status)) {
478 return "not found";
479 } else {
480 return status.ToString();
481 }
482 }
483 };
484
485 REGISTER_OP("BuildCPU");
486 REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);
487
TEST_F(OpKernelBuilderTest,BuilderCPU)488 TEST_F(OpKernelBuilderTest, BuilderCPU) {
489 ExpectSuccess("BuildCPU", DEVICE_CPU, {});
490 EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {}));
491 ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
492 EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {}));
493 }
494
495 REGISTER_OP("BuildGPU");
496 REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel);
497
TEST_F(OpKernelBuilderTest,BuilderGPU)498 TEST_F(OpKernelBuilderTest, BuilderGPU) {
499 ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND);
500 ExpectSuccess("BuildGPU", DEVICE_GPU, {});
501 }
502
503 REGISTER_OP("BuildBoth");
504 REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel);
505 REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel);
506
TEST_F(OpKernelBuilderTest,BuilderBoth)507 TEST_F(OpKernelBuilderTest, BuilderBoth) {
508 ExpectSuccess("BuildBoth", DEVICE_CPU, {});
509 ExpectSuccess("BuildBoth", DEVICE_GPU, {});
510 }
511
512 REGISTER_OP("BuildTypeAttr").Attr("T: type");
513 REGISTER_KERNEL_BUILDER(
514 Name("BuildTypeAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"),
515 DummyKernel);
516
TEST_F(OpKernelBuilderTest,BuilderTypeAttr)517 TEST_F(OpKernelBuilderTest, BuilderTypeAttr) {
518 ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"});
519 ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"},
520 error::NOT_FOUND);
521 ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
522 ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"},
523 error::INVALID_ARGUMENT);
524 }
525
526 REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)");
527 REGISTER_KERNEL_BUILDER(
528 Name("BuildTypeListAttr").Device(DEVICE_CPU).TypeConstraint<bool>("T"),
529 DummyKernel);
530
TEST_F(OpKernelBuilderTest,BuilderTypeListAttr)531 TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
532 ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
533 EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
534 {"T|list(type)|[]"}));
535
536 ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
537 EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
538 {"T|list(type)|[]"}));
539
540 ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
541 {"T|list(type)|[DT_BOOL, DT_BOOL]"});
542
543 ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
544 error::NOT_FOUND);
545 EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
546 {"T|list(type)|[DT_FLOAT]"}));
547
548 ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
549 EXPECT_TRUE(
550 StringPiece(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}))
551 .contains("Invalid argument: "));
552
553 ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
554 error::INVALID_ARGUMENT);
555 }
556
557 REGISTER_OP("DuplicateKernel");
558 REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
559 DummyKernel);
560 REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
561 DummyKernel);
562
TEST_F(OpKernelBuilderTest,DuplicateKernel)563 TEST_F(OpKernelBuilderTest, DuplicateKernel) {
564 const NodeDef ndef = CreateNodeDef("DuplicateKernel", {});
565 DeviceTypeVector devs;
566 Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
567 ASSERT_FALSE(status.ok());
568 EXPECT_TRUE(StringPiece(status.error_message())
569 .contains("Multiple OpKernel registrations match NodeDef"));
570
571 ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
572 }
573
574 REGISTER_OP("DuplicateKernelForT").Attr("T: type");
575 REGISTER_KERNEL_BUILDER(
576 Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
577 DummyKernel);
578 REGISTER_KERNEL_BUILDER(
579 Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
580 DummyKernel);
581
TEST_F(OpKernelBuilderTest,DuplicateKernelForT)582 TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
583 const NodeDef ndef =
584 CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"});
585 DeviceTypeVector devs;
586 Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
587 ASSERT_FALSE(status.ok());
588 EXPECT_TRUE(StringPiece(status.error_message())
589 .contains("Multiple OpKernel registrations match NodeDef"));
590
591 ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"},
592 error::INVALID_ARGUMENT);
593 ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"},
594 error::NOT_FOUND);
595 }
596
597 REGISTER_OP("BadConstraint").Attr("dtype: type");
598 REGISTER_KERNEL_BUILDER(Name("BadConstraint")
599 .Device(DEVICE_CPU)
600 // Mistake: "T" should be "dtype".
601 .TypeConstraint<float>("T"),
602 DummyKernel);
603
TEST_F(OpKernelBuilderTest,BadConstraint)604 TEST_F(OpKernelBuilderTest, BadConstraint) {
605 const NodeDef ndef = CreateNodeDef("BadConstraint", {});
606 DeviceTypeVector devs;
607 Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
608 ASSERT_FALSE(status.ok());
609 EXPECT_TRUE(StringPiece(status.error_message())
610 .contains("OpKernel 'BadConstraint' has constraint on attr "
611 "'T' not in NodeDef"));
612
613 ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"},
614 error::INVALID_ARGUMENT);
615 }
616
617 REGISTER_OP("ListOut").Output("a: int32").Output("b: T").Attr("T: list(type)");
618 REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU),
619 DummyKernel);
620
TEST_F(OpKernelBuilderTest,OpOutputList)621 TEST_F(OpKernelBuilderTest, OpOutputList) {
622 Env* env = Env::Default();
623 OpKernelContext::Params params;
624 params.record_tensor_accesses = false;
625 std::unique_ptr<DummyDevice> device(
626 new DummyDevice(env, params.record_tensor_accesses));
627 params.device = device.get();
628 Status status;
629 std::unique_ptr<OpKernel> op(CreateOpKernel(
630 DEVICE_CPU, params.device, cpu_allocator(),
631 CreateNodeDef("ListOut", {"T|list(type)|[DT_FLOAT, DT_INT32]"}),
632 TF_GRAPH_DEF_VERSION, &status));
633 EXPECT_TRUE(status.ok()) << status.ToString();
634 params.op_kernel = op.get();
635 gtl::InlinedVector<TensorValue, 4> inputs{};
636 params.inputs = &inputs;
637 std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(¶ms));
638
639 EXPECT_EQ(DT_INT32, ctx->expected_output_dtype(0));
640 OpOutputList out_list;
641 EXPECT_FALSE(ctx->output_list("non_existent_output", &out_list).ok());
642 ASSERT_TRUE(ctx->output_list("b", &out_list).ok());
643 EXPECT_EQ(DT_FLOAT, out_list.expected_output_dtype(0));
644 EXPECT_EQ(DT_INT32, out_list.expected_output_dtype(1));
645 }
646
647 class GetAttrKernel : public ::tensorflow::OpKernel {
648 public:
GetAttrKernel(OpKernelConstruction * context)649 explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) {
650 string attr_name;
651 OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name));
652
653 status.emplace_back("s", context->GetAttr(attr_name, &s));
654 status.emplace_back("s_list", context->GetAttr(attr_name, &s_list));
655 status.emplace_back("i", context->GetAttr(attr_name, &i));
656 status.emplace_back("i_list", context->GetAttr(attr_name, &i_list));
657 status.emplace_back("i32", context->GetAttr(attr_name, &i32));
658 status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list));
659 status.emplace_back("f", context->GetAttr(attr_name, &f));
660 status.emplace_back("f_list", context->GetAttr(attr_name, &f_list));
661 status.emplace_back("b", context->GetAttr(attr_name, &b));
662 status.emplace_back("b_list", context->GetAttr(attr_name, &b_list));
663 status.emplace_back("type", context->GetAttr(attr_name, &type));
664 status.emplace_back("type_list", context->GetAttr(attr_name, &type_list));
665 status.emplace_back("type_vector",
666 context->GetAttr(attr_name, &type_vector));
667 status.emplace_back("shape_proto",
668 context->GetAttr(attr_name, &shape_proto));
669 status.emplace_back("shape_proto_list",
670 context->GetAttr(attr_name, &shape_proto_list));
671 status.emplace_back("shape", context->GetAttr(attr_name, &shape));
672 status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list));
673 }
Compute(::tensorflow::OpKernelContext * context)674 void Compute(::tensorflow::OpKernelContext* context) override {}
675
ExpectOk(std::initializer_list<string> keys)676 void ExpectOk(std::initializer_list<string> keys) {
677 for (const auto& key_status : status) {
678 // Only the status for keys in "keys" should be ok().
679 bool in_keys = false;
680 for (const string& key : keys) {
681 if (key_status.first == key) {
682 in_keys = true;
683 }
684 }
685 EXPECT_EQ(in_keys, key_status.second.ok())
686 << "key_status: " << key_status.first << ", " << key_status.second;
687 }
688 }
689
690 string s;
691 std::vector<string> s_list;
692 int64 i;
693 std::vector<int64> i_list;
694 int32 i32;
695 std::vector<int32> i32_list;
696 float f;
697 std::vector<float> f_list;
698 bool b;
699 std::vector<bool> b_list;
700 DataType type;
701 std::vector<DataType> type_list;
702 DataTypeVector type_vector;
703 TensorShapeProto shape_proto;
704 std::vector<TensorShapeProto> shape_proto_list;
705 TensorShape shape;
706 std::vector<TensorShape> shape_list;
707 std::vector<std::pair<string, Status>> status;
708 };
709
710 class GetAttrTest : public OpKernelBuilderTest {};
711
712 REGISTER_OP("GetAttrStringList")
713 .Attr("attr_name: string")
714 .Attr("a: list(string)");
715 REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU),
716 GetAttrKernel);
717
TEST_F(GetAttrTest,StringList)718 TEST_F(GetAttrTest, StringList) {
719 std::unique_ptr<OpKernel> op_kernel =
720 ExpectSuccess("GetAttrStringList", DEVICE_CPU,
721 {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"});
722 auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
723 get_attr_kernel->ExpectOk({"s_list"});
724 EXPECT_EQ(std::vector<string>({"foo", "bar"}), get_attr_kernel->s_list);
725
726 op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU,
727 {"attr_name|string|'b'", "a|list(string)|['baz']"});
728 get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
729 get_attr_kernel->ExpectOk({});
730 EXPECT_TRUE(get_attr_kernel->s_list.empty());
731 }
732
733 REGISTER_OP("GetAttrInt")
734 .Attr("attr_name: string")
735 .Attr("a: int")
736 .Attr("b: list(int)");
737 REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel);
738
TEST_F(GetAttrTest,Int)739 TEST_F(GetAttrTest, Int) {
740 std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
741 "GetAttrInt", DEVICE_CPU,
742 {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
743 auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
744 get_attr_kernel->ExpectOk({"i", "i32"});
745 EXPECT_EQ(35, get_attr_kernel->i);
746 EXPECT_EQ(35, get_attr_kernel->i32);
747
748 op_kernel = ExpectSuccess(
749 "GetAttrInt", DEVICE_CPU,
750 {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
751 get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
752 get_attr_kernel->ExpectOk({"i_list", "i32_list"});
753 EXPECT_EQ(std::vector<int64>({-1, 2, -4}), get_attr_kernel->i_list);
754 EXPECT_EQ(std::vector<int32>({-1, 2, -4}), get_attr_kernel->i32_list);
755
756 // 8589934592 == 2^33, too big to fit in an int32
757 op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
758 {"attr_name|string|'a'", "a|int|8589934592",
759 "b|list(int)|[-8589934592]"});
760 get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
761 get_attr_kernel->ExpectOk({"i"}); // no i32
762 EXPECT_EQ(8589934592ll, get_attr_kernel->i);
763 for (const auto& key_status : get_attr_kernel->status) {
764 if (key_status.first == "i32") {
765 EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
766 EXPECT_EQ("Attr a has value 8589934592 out of range for an int32",
767 key_status.second.error_message());
768 }
769 }
770
771 op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
772 {"attr_name|string|'b'", "a|int|8589934592",
773 "b|list(int)|[-8589934592]"});
774 get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
775 get_attr_kernel->ExpectOk({"i_list"}); // no i32_list
776 EXPECT_EQ(std::vector<int64>({-8589934592ll}), get_attr_kernel->i_list);
777 for (const auto& key_status : get_attr_kernel->status) {
778 if (key_status.first == "i32_list") {
779 EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
780 EXPECT_EQ("Attr b has value -8589934592 out of range for an int32",
781 key_status.second.error_message());
782 }
783 }
784 }
785
786 REGISTER_OP("GetAttrShape")
787 .Attr("attr_name: string")
788 .Attr("a: shape")
789 .Attr("b: list(shape)");
790 REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel);
791
// Checks that a `shape` attr and a `list(shape)` attr can both be read back,
// each as a TensorShape and as a TensorShapeProto.
TEST_F(GetAttrTest, Shape) {
  const std::string single_shape_attr = "a|shape|{ dim { size: 3 } }";
  const std::string shape_list_attr =
      "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]";

  // First pass: read the single shape attr "a".
  std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
      "GetAttrShape", DEVICE_CPU,
      {"attr_name|string|'a'", single_shape_attr, shape_list_attr});
  auto* kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({"shape", "shape_proto"});
  EXPECT_EQ("dim { size: 3 }", kernel->shape_proto.ShortDebugString());
  EXPECT_EQ("[3]", kernel->shape.DebugString());

  // Second pass: read the shape-list attr "b".
  op_kernel = ExpectSuccess(
      "GetAttrShape", DEVICE_CPU,
      {"attr_name|string|'b'", single_shape_attr, shape_list_attr});
  kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({"shape_list", "shape_proto_list"});

  ASSERT_EQ(2, kernel->shape_proto_list.size());
  EXPECT_EQ("dim { size: 2 }", kernel->shape_proto_list[0].ShortDebugString());
  EXPECT_EQ("dim { size: 4 }", kernel->shape_proto_list[1].ShortDebugString());

  ASSERT_EQ(2, kernel->shape_list.size());
  EXPECT_EQ("[2]", kernel->shape_list[0].DebugString());
  EXPECT_EQ("[4]", kernel->shape_list[1].DebugString());
}
817
818 REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type");
819 REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel);
820
// Checks that a single `type` attr set to DT_FLOAT is read back unchanged.
TEST_F(GetAttrTest, Type) {
  std::unique_ptr<OpKernel> op_kernel =
      ExpectSuccess("GetAttrType", DEVICE_CPU,
                    {"attr_name|string|'a'", "a|type|DT_FLOAT"});
  GetAttrKernel* const kernel = static_cast<GetAttrKernel*>(op_kernel.get());

  kernel->ExpectOk({"type"});
  EXPECT_EQ(DT_FLOAT, kernel->type);
}
828
829 REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)");
830 REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU),
831 GetAttrKernel);
832
// Checks that a `list(type)` attr is read back, in order, into both the
// "type_list" and "type_vector" outputs of GetAttrKernel.
TEST_F(GetAttrTest, TypeList) {
  std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
      "GetAttrTypeList", DEVICE_CPU,
      {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"});
  GetAttrKernel* const kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  kernel->ExpectOk({"type_list", "type_vector"});

  ASSERT_EQ(2, kernel->type_list.size());
  EXPECT_EQ(DT_INT32, kernel->type_list[0]);
  EXPECT_EQ(DT_BOOL, kernel->type_list[1]);

  ASSERT_EQ(2, kernel->type_vector.size());
  EXPECT_EQ(DT_INT32, kernel->type_vector[0]);
  EXPECT_EQ(DT_BOOL, kernel->type_vector[1]);
}
847
848 class BaseKernel : public ::tensorflow::OpKernel {
849 public:
BaseKernel(OpKernelConstruction * context)850 explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {}
Compute(::tensorflow::OpKernelContext * context)851 void Compute(::tensorflow::OpKernelContext* context) override {}
852 virtual int Which() const = 0;
853 };
854
855 template <int WHICH>
856 class LabeledKernel : public BaseKernel {
857 public:
858 using BaseKernel::BaseKernel;
Which() const859 int Which() const override { return WHICH; }
860 };
861
862 class LabelTest : public OpKernelBuilderTest {};
863
864 REGISTER_OP("LabeledKernel");
865 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU),
866 LabeledKernel<0>);
867 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"),
868 LabeledKernel<1>);
869 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
870 LabeledKernel<2>);
871 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
872 LabeledKernel<3>);
873
// Without a "_kernel" attr, the unlabeled registration (LabeledKernel<0>)
// must be selected.
TEST_F(LabelTest, Default) {
  std::unique_ptr<OpKernel> op_kernel =
      ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
  const auto* labeled = static_cast<BaseKernel*>(op_kernel.get());
  EXPECT_EQ(0, labeled->Which());

  const auto class_name = GetKernelClassName("LabeledKernel", DEVICE_CPU, {});
  EXPECT_EQ("LabeledKernel<0>", class_name);
}
883
// Setting "_kernel" to 'one' must select the kernel registered with that
// label (LabeledKernel<1>).
TEST_F(LabelTest, Specified) {
  const std::vector<std::string> attrs = {"_kernel|string|'one'"};

  std::unique_ptr<OpKernel> op_kernel =
      ExpectSuccess("LabeledKernel", DEVICE_CPU, attrs);
  const auto* labeled = static_cast<BaseKernel*>(op_kernel.get());
  EXPECT_EQ(1, labeled->Which());

  EXPECT_EQ("LabeledKernel<1>",
            GetKernelClassName("LabeledKernel", DEVICE_CPU, attrs));
}
892
// Two kernels are registered under the label "dupe", so selecting that label
// must fail with INVALID_ARGUMENT.
TEST_F(LabelTest, Duplicate) {
  const std::vector<std::string> attrs = {"_kernel|string|'dupe'"};
  ExpectFailure("LabeledKernel", DEVICE_CPU, attrs, error::INVALID_ARGUMENT);
}
897
BM_InputRangeHelper(int iters,const NodeDef & node_def,const char * input_name,int expected_start,int expected_stop)898 void BM_InputRangeHelper(int iters, const NodeDef& node_def,
899 const char* input_name, int expected_start,
900 int expected_stop) {
901 Status status;
902 std::unique_ptr<DummyDevice> device(new DummyDevice(Env::Default(), false));
903
904 std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
905 cpu_allocator(), node_def,
906 TF_GRAPH_DEF_VERSION, &status));
907 TF_CHECK_OK(status);
908
909 testing::StartTiming();
910 for (int i = 0; i < iters; ++i) {
911 int start;
912 int stop;
913 TF_CHECK_OK(op->InputRange(input_name, &start, &stop));
914 EXPECT_EQ(expected_start, start);
915 EXPECT_EQ(expected_stop, stop);
916 }
917 testing::StopTiming();
918 }
919
920 REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel);
921 REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel);
922
BM_ConcatInputRange(int iters)923 void BM_ConcatInputRange(int iters) {
924 testing::StopTiming();
925
926 // Create a ConcatV2 NodeDef with 4 inputs (plus the axis).
927 NodeDef node_def;
928 node_def.set_name("concat-op");
929 node_def.set_op("ConcatV2");
930 AttrValue attr_N;
931 attr_N.set_i(4);
932 AttrValue attr_T;
933 attr_T.set_type(DT_FLOAT);
934 AttrValue attr_Tidx;
935 attr_Tidx.set_type(DT_INT32);
936 node_def.mutable_attr()->insert({"N", attr_N});
937 node_def.mutable_attr()->insert({"T", attr_T});
938 node_def.mutable_attr()->insert({"Tidx", attr_Tidx});
939 for (size_t i = 0; i < 5; ++i) {
940 node_def.add_input(strings::StrCat("a:", i));
941 }
942
943 BM_InputRangeHelper(iters, node_def, "values", 0, 4);
944 }
945
BM_SelectInputRange(int iters)946 void BM_SelectInputRange(int iters) {
947 testing::StopTiming();
948
949 // Create a Select NodeDef with 3 inputs.
950 NodeDef node_def;
951 node_def.set_name("select-op");
952 node_def.set_op("Select");
953 AttrValue attr_T;
954 attr_T.set_type(DT_FLOAT);
955 node_def.mutable_attr()->insert({"T", attr_T});
956 for (size_t i = 0; i < 3; ++i) {
957 node_def.add_input(strings::StrCat("a:", i));
958 }
959
960 BM_InputRangeHelper(iters, node_def, "condition", 0, 1);
961 }
962
963 BENCHMARK(BM_ConcatInputRange);
964 BENCHMARK(BM_SelectInputRange);
965
966 } // namespace
967 } // namespace tensorflow
968