1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_kernel.h"
17
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "tensorflow/core/framework/allocator.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/attr_value_util.h"
25 #include "tensorflow/core/framework/fake_input.h"
26 #include "tensorflow/core/framework/node_def_builder.h"
27 #include "tensorflow/core/framework/op.h"
28 #include "tensorflow/core/framework/tensor_shape.pb.h"
29 #include "tensorflow/core/framework/tensor_util.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/graph/graph.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/protobuf.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/test_benchmark.h"
40 #include "tensorflow/core/public/version.h"
41 #include "tensorflow/core/util/device_name_utils.h"
42
43 class DummyKernel : public tensorflow::OpKernel {
44 public:
DummyKernel(tensorflow::OpKernelConstruction * context)45 explicit DummyKernel(tensorflow::OpKernelConstruction* context)
46 : OpKernel(context) {}
Compute(tensorflow::OpKernelContext * context)47 void Compute(tensorflow::OpKernelContext* context) override {}
48 };
49
// Test that registration works outside a namespace.
REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8");
// CPU-only kernel for Test1; several tests below rely on Test1 having no
// GPU registration.
REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),
                        DummyKernel);
54
namespace foo {
// Set by TestOp2's constructor so tests can observe whether MatchSignature
// succeeded during kernel construction.
bool match_signature_ = false;

// Test that registration works inside a different namespace.
class TestOp2 : public ::tensorflow::OpKernel {
 public:
  // Validates the {int32} -> {int32} signature, records the result in
  // match_signature_, and propagates the status to the construction context
  // (so a mismatch fails kernel creation).
  explicit TestOp2(::tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {
    ::tensorflow::Status status = context->MatchSignature(
        {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32});
    match_signature_ = status.ok();
    context->SetStatus(status);
  }
  void Compute(::tensorflow::OpKernelContext* context) override {}
};

REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type");
// GPU kernel whose input and output both live in host memory.
REGISTER_KERNEL_BUILDER(Name("Test2")
                            .Device(::tensorflow::DEVICE_GPU)
                            .HostMemory("i")
                            .HostMemory("o"),
                        TestOp2);
}  // namespace foo
78
namespace tensorflow {

// Two operations with the same name but different devices.
REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type");

class TestOp3Cpu : public tensorflow::OpKernel {
 public:
  explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {}
};

// CPU kernel for Test3, constrained to T == int8.
REGISTER_KERNEL_BUILDER(
    Name("Test3").Device(DEVICE_CPU).TypeConstraint<int8>("T"), TestOp3Cpu);
92
93 namespace {
94
95 class TestOp3Gpu : public tensorflow::OpKernel {
96 public:
TestOp3Gpu(OpKernelConstruction * context)97 explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)98 void Compute(OpKernelContext* context) override {}
99 };
100
101 REGISTER_KERNEL_BUILDER(
102 Name("Test3").Device(DEVICE_GPU).TypeConstraint<float>("T"), TestOp3Cpu);
103
// An Op registered for both
// CPU and GPU with the same (dummy) kernel class; used by tests that expect
// two supported device types.
REGISTER_OP("Test4").Input("i: float").Output("o: float");
REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel);
108
// Kernels with different priorities.
REGISTER_OP("Test5").Input("a: T").Input("b: T").Attr("T: type");

// Op with no kernel registration at all; used by KernelNotRegistered.
REGISTER_OP("OpWithoutKernel").Input("a: T").Input("b: T").Attr("T: type");

class TestOp5Cpu : public tensorflow::OpKernel {
 public:
  explicit TestOp5Cpu(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {}
};

// CPU kernel with the higher priority (2) — should be preferred over GPU.
REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_CPU).Priority(2),
                        TestOp5Cpu);

class TestOp5Gpu : public tensorflow::OpKernel {
 public:
  explicit TestOp5Gpu(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {}
};

// GPU kernel with the lower priority (1).
REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_GPU).Priority(1),
                        TestOp5Gpu);
131
DeviceTypes()132 static std::vector<DeviceType> DeviceTypes() {
133 return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)};
134 }
135
136 class OpKernelTest : public ::testing::Test {
137 public:
OpKernelTest()138 OpKernelTest() : device_(Env::Default()) {}
139
140 protected:
CreateNodeDef(const string & op_type,const DataTypeVector & inputs,const string & device="")141 NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs,
142 const string& device = "") {
143 NodeDefBuilder builder(op_type + "-op", op_type);
144 for (DataType dt : inputs) {
145 builder.Input(FakeInput(dt));
146 }
147 builder.Device(device);
148 NodeDef node_def;
149 TF_CHECK_OK(builder.Finalize(&node_def));
150 return node_def;
151 }
152
ExpectEqual(const string & what,const DataTypeVector & expected,const DataTypeVector & observed)153 void ExpectEqual(const string& what, const DataTypeVector& expected,
154 const DataTypeVector& observed) {
155 EXPECT_EQ(expected.size(), observed.size()) << what;
156 const size_t size = std::min(expected.size(), observed.size());
157 for (size_t i = 0; i < size; ++i) {
158 bool match = TypesCompatible(expected[i], observed[i]);
159 EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i]
160 << ", observed: " << observed[i];
161 }
162 }
163
ExpectSuccess(const string & op_type,DeviceType device_type,const DataTypeVector & inputs,const DataTypeVector & outputs)164 void ExpectSuccess(const string& op_type, DeviceType device_type,
165 const DataTypeVector& inputs,
166 const DataTypeVector& outputs) {
167 Status status;
168 std::unique_ptr<OpKernel> op(CreateOpKernel(
169 std::move(device_type), &device_, cpu_allocator(),
170 CreateNodeDef(op_type, inputs), TF_GRAPH_DEF_VERSION, &status));
171 EXPECT_TRUE(status.ok()) << status;
172 EXPECT_TRUE(op != nullptr);
173 if (op != nullptr) {
174 ExpectEqual("inputs", op->input_types(), inputs);
175 ExpectEqual("outputs", op->output_types(), outputs);
176 }
177 }
178
ExpectFailure(const string & ascii_node_def,DeviceType device_type,error::Code code)179 void ExpectFailure(const string& ascii_node_def, DeviceType device_type,
180 error::Code code) {
181 NodeDef node_def;
182 protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def);
183 Status status;
184 std::unique_ptr<OpKernel> op(
185 CreateOpKernel(std::move(device_type), &device_, cpu_allocator(),
186 node_def, TF_GRAPH_DEF_VERSION, &status));
187 EXPECT_TRUE(op == nullptr);
188 EXPECT_FALSE(status.ok());
189 if (!status.ok()) {
190 LOG(INFO) << "Status message: " << status.error_message();
191 EXPECT_EQ(code, status.code());
192 }
193 }
194
195 private:
196 DeviceBase device_;
197 };
198
// Test1 has a CPU kernel; creation succeeds, and a ref input (DT_FLOAT_REF)
// is accepted as compatible with the declared DT_FLOAT input.
TEST_F(OpKernelTest, SuccessCpu) {
  ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, {DT_UINT8});
  ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32}, {DT_UINT8});
}
203
// Test2's GPU kernel checks its {int32}->{int32} signature in its
// constructor; successful creation must flip foo::match_signature_ to true.
TEST_F(OpKernelTest, SuccessGpu) {
  foo::match_signature_ = false;
  ExpectSuccess("Test2", DEVICE_GPU, {DT_INT32}, {DT_INT32});
  EXPECT_TRUE(foo::match_signature_);
}
209
// Test3 has type-constrained kernels on both devices: int8 on CPU and
// float on GPU.
TEST_F(OpKernelTest, SuccessBothCpuAndGpu) {
  ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, {});
  ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, {});
}
214
// Test1 only has a CPU kernel, so CPU must be the single supported device.
TEST_F(OpKernelTest, CpuTypeRegistered) {
  NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  PrioritizedDeviceTypeVector supported;
  TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), node_def,
                                           &supported));
  EXPECT_EQ(1, supported.size());
  EXPECT_EQ(DeviceType(DEVICE_CPU), supported[0].first);
}
222
// SupportedDeviceTypesForNode() for an op with no registered kernels:
// when the NodeDef is assigned to a different address space than the local
// device, every candidate device type is returned (the kernel may exist on
// the remote end); when it is in the local address space, none are.
TEST_F(OpKernelTest, KernelNotRegistered) {
  // Fix: these were `const string&` bindings to temporaries, which work only
  // via reference lifetime extension; bind plain values instead.
  const string local_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string remote_device = "/job:worker/replica:0/task:0/device";
  {
    // Try a node def of an op which does not have kernel. And the requested
    // device in NodeDef is on a different address space than the local device.
    NodeDef ndef =
        CreateNodeDef("OpWithoutKernel", {DT_STRING, DT_STRING}, remote_device);
    PrioritizedDeviceTypeVector devs;
    DeviceNameUtils::ParsedName local_device_name;
    DeviceNameUtils::ParseFullName(local_device, &local_device_name);
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs,
                                             &local_device_name));
    EXPECT_EQ(2, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1].first);
  }

  {
    // Try a node def of an op which does not have kernel. And the requested
    // device in NodeDef is on the same address space as the local device.
    NodeDef ndef =
        CreateNodeDef("OpWithoutKernel", {DT_STRING, DT_STRING}, local_device);
    PrioritizedDeviceTypeVector devs;
    DeviceNameUtils::ParsedName local_device_name;
    DeviceNameUtils::ParseFullName(local_device, &local_device_name);
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs,
                                             &local_device_name));
    EXPECT_EQ(0, devs.size());
  }
}
254
// Exercises SupportedDeviceTypesForNode() across type constraints and
// kernel priorities.
TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
  {
    // Try a node def of an op that is registered for a specific type
    // only on CPU.
    NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(1, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
  }
  {
    // Try a node def of an op that is registered for a specific type
    // only on GPU.
    NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(1, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
  }
  {
    // Try a node def of an op that is only registered for other types.
    NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(0, devs.size());
  }

  {
    // Try a node def of an op that is registered for both.
    // Default (priority-less) registration keeps the GPU-first order of
    // DeviceTypes().
    NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(2, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1].first);
  }

  {
    // Try a node def of an op where kernels have priorities.
    // CPU (priority 2) must outrank GPU (priority 1) despite DeviceTypes()
    // listing GPU first.
    NodeDef ndef = CreateNodeDef("Test5", {DT_STRING, DT_STRING});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(2, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
    EXPECT_EQ(2, devs[0].second);
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[1].first);
    EXPECT_EQ(1, devs[1].second);
  }
}
304
// Kernel lookup must report NOT_FOUND for ops/devices/type-signatures with
// no matching registration.
TEST_F(OpKernelTest, NotFound) {
  const auto not_found = error::NOT_FOUND;
  // Something with that op type name exists, but only with a
  // different DeviceType.
  ExpectFailure(CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}).DebugString(),
                DEVICE_GPU, not_found);
  ExpectFailure(CreateNodeDef("Test3", {DT_INT8, DT_INT8}).DebugString(),
                DEVICE_GPU, not_found);
  ExpectFailure(CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}).DebugString(),
                DEVICE_CPU, not_found);

  // No kernel with that signature registered.
  ExpectFailure(CreateNodeDef("Test3", {DT_INT32, DT_INT32}).DebugString(),
                DEVICE_GPU, not_found);

  // Nothing with that op type name exists.
  ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_CPU, not_found);
  ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_GPU, not_found);
}
324
// A NodeDef with fewer inputs than the op declares must be rejected with
// INVALID_ARGUMENT (checked with zero inputs and with one of two).
TEST_F(OpKernelTest, TooFewInputs) {
  const auto invalid = error::INVALID_ARGUMENT;
  NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  node_def.clear_input();
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
  node_def.add_input("a");
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
}
333
// A NodeDef with more inputs than the op declares must be rejected with
// INVALID_ARGUMENT.
TEST_F(OpKernelTest, TooManyInputs) {
  const auto invalid = error::INVALID_ARGUMENT;
  NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  node_def.add_input("c");
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
}
340
// MatchSignature in TestOp2's constructor must fail for a float input,
// failing kernel creation and flipping foo::match_signature_ back to false.
// Fix: renamed from "MatchSignatureFailes" to correct the typo; the test
// name is not referenced anywhere else in this file.
TEST_F(OpKernelTest, MatchSignatureFails) {
  const auto invalid = error::INVALID_ARGUMENT;
  foo::match_signature_ = true;
  ExpectFailure(CreateNodeDef("Test2", {DT_FLOAT}).DebugString(), DEVICE_GPU,
                invalid);
  EXPECT_FALSE(foo::match_signature_);
}
348
// DeviceBase stub whose allocator is always the process CPU allocator;
// used to build OpKernelContexts for the dtype/output-list tests below.
class DummyDevice : public DeviceBase {
 public:
  explicit DummyDevice(Env* env) : DeviceBase(env) {}
  Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
    return cpu_allocator();
  }
};
356
// OpKernelContext::input_dtype(name, ...) must resolve declared input names
// to their dtypes and fail for unknown names.
TEST_F(OpKernelTest, InputDtype) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  DummyDevice device(env);
  params.device = &device;
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  EXPECT_TRUE(status.ok());
  params.op_kernel = op.get();
  // Scalar tensors matching Test1's inputs (plus an extra uint8 tensor).
  Tensor a(DT_FLOAT, TensorShape({}));
  Tensor b(DT_INT32, TensorShape({}));
  Tensor c(DT_UINT8, TensorShape({}));
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&a), TensorValue(&b),
                                            TensorValue(&c)};
  params.inputs = &inputs;
  auto ctx = absl::make_unique<OpKernelContext>(&params);

  DataType dtype;
  EXPECT_FALSE(ctx->input_dtype("non_existent_input", &dtype).ok());
  ASSERT_TRUE(ctx->input_dtype("a", &dtype).ok());
  EXPECT_EQ(dtype, DT_FLOAT);
  ASSERT_TRUE(ctx->input_dtype("b", &dtype).ok());
  EXPECT_EQ(dtype, DT_INT32);
}
384
385 // A mock device that mimics the behavior of scoped allocator upon calling
386 // GetAllocator with a positive scope_id.
387 class ScopedAllocatorDevice : public DeviceBase {
388 public:
ScopedAllocatorDevice(Env * env)389 explicit ScopedAllocatorDevice(Env* env)
390 : DeviceBase(env),
391 scope_allocated_(false),
392 num_allocations_(0),
393 num_scoped_allocations_(0) {}
394
GetAllocator(AllocatorAttributes attrs)395 Allocator* GetAllocator(AllocatorAttributes attrs) override {
396 CHECK_LE(attrs.scope_id, 0);
397 num_allocations_++;
398 return cpu_allocator();
399 }
400
GetScopedAllocator(AllocatorAttributes attrs,int64_t)401 Allocator* GetScopedAllocator(AllocatorAttributes attrs,
402 int64_t /*step_id*/) override {
403 CHECK_GT(attrs.scope_id, 0);
404 num_scoped_allocations_++;
405 if (scope_allocated_) {
406 return nullptr;
407 } else {
408 scope_allocated_ = true;
409 return cpu_allocator();
410 }
411 }
412
CopyTensorInSameDevice(const Tensor * input_tensor,Tensor * output_tensor,const DeviceContext * device_context,StatusCallback done)413 void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor,
414 const DeviceContext* device_context,
415 StatusCallback done) override {
416 CHECK(input_tensor->NumElements() == output_tensor->NumElements());
417 tensor::DeepCopy(*input_tensor, output_tensor);
418 done(Status::OK());
419 }
420
421 // Return the count of calls to GetAllocator or GetScopedAllocator, depending
422 // on when scoped is false or true respectively. For testing purposes.
num_allocations(bool scoped)423 int num_allocations(bool scoped) {
424 if (scoped) {
425 return num_scoped_allocations_;
426 } else {
427 return num_allocations_;
428 }
429 }
430
431 private:
432 bool scope_allocated_;
433 int num_allocations_;
434 int num_scoped_allocations_;
435 };
436
// Test that a kernel which has an output marked for allocation via
// ScopedAllocator, which calls allocate_temp and set_output, does the right
// thing. In this case, the expected behavior is for allocate_temp to return
// a temporary buffer, and set_output to copy the contents of this temp buffer
// into the ScopedAllocator slice.
TEST_F(OpKernelTest, ScopedAllocationTest) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  auto sa_device = absl::make_unique<ScopedAllocatorDevice>(env);
  params.device = sa_device.get();
  Status status;
  std::unique_ptr<OpKernel> op(CreateOpKernel(
      DEVICE_CPU, params.device, cpu_allocator(),
      CreateNodeDef("Test4", {DT_FLOAT}), TF_GRAPH_DEF_VERSION, &status));
  EXPECT_TRUE(status.ok());
  params.op_kernel = op.get();
  // Mark output 0 as scope-allocated (scope_id > 0) and never forwarded.
  AllocatorAttributes alloc_attrs;
  alloc_attrs.scope_id = 1;
  std::vector<AllocatorAttributes> output_alloc_attrs({alloc_attrs});
  params.output_attr_array = output_alloc_attrs.data();
  std::vector<int> forward_from({OpKernelContext::Params::kNeverForward});
  params.forward_from_array = forward_from.data();
  auto ctx = absl::make_unique<OpKernelContext>(&params);

  EXPECT_EQ(sa_device->num_allocations(false), 0);
  EXPECT_EQ(sa_device->num_allocations(true), 0);
  // allocate_temp with a positive scope_id still goes through the regular
  // allocator (scoped allocation is reserved for outputs).
  Tensor temp1;
  TF_EXPECT_OK(
      ctx->allocate_temp(DT_FLOAT, TensorShape({8}), &temp1, alloc_attrs));
  EXPECT_EQ(sa_device->num_allocations(false), 1);
  EXPECT_EQ(sa_device->num_allocations(true), 0);
  Tensor temp2;
  alloc_attrs.scope_id = -1;
  TF_EXPECT_OK(
      ctx->allocate_temp(DT_FLOAT, TensorShape({4}), &temp2, alloc_attrs));
  EXPECT_EQ(sa_device->num_allocations(false), 2);
  EXPECT_EQ(sa_device->num_allocations(true), 0);
  // set_output on the scope-allocated output triggers exactly one scoped
  // allocation (the copy into the ScopedAllocator slice).
  ctx->set_output(0, temp1);
  EXPECT_EQ(sa_device->num_allocations(false), 2);
  EXPECT_EQ(sa_device->num_allocations(true), 1);
}
478
// Fixture for testing kernel registration/lookup driven by NodeDef attrs.
class OpKernelBuilderTest : public ::testing::Test {
 protected:
  // Each attr is described by a "name|type|value".
  NodeDef CreateNodeDef(const string& op_type,
                        const std::vector<string>& attrs) {
    NodeDef node_def;
    node_def.set_name(op_type + "-op");
    node_def.set_op(op_type);
    for (const string& attr_desc : attrs) {
      std::vector<string> parts = str_util::Split(attr_desc, '|');
      CHECK_EQ(parts.size(), 3);
      AttrValue attr_value;
      CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc;
      node_def.mutable_attr()->insert(
          AttrValueMap::value_type(parts[0], attr_value));
    }
    return node_def;
  }

  // Expects CreateOpKernel() to succeed and SupportedDeviceTypesForNode()
  // to include `device_type`. Returns the created kernel so callers can
  // inspect it further.
  std::unique_ptr<OpKernel> ExpectSuccess(const string& op_type,
                                          const DeviceType& device_type,
                                          const std::vector<string>& attrs,
                                          DataTypeSlice input_types = {}) {
    Status status;
    NodeDef def = CreateNodeDef(op_type, attrs);
    // One placeholder input edge per expected input type.
    for (size_t i = 0; i < input_types.size(); ++i) {
      def.add_input("a:0");
    }

    Env* env = Env::Default();
    DeviceBase device(env);

    // Test CreateOpKernel()
    std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
                                                cpu_allocator(), def,
                                                TF_GRAPH_DEF_VERSION, &status));
    EXPECT_TRUE(status.ok()) << status;
    EXPECT_TRUE(op != nullptr);
    if (op != nullptr) {
      EXPECT_EQ(input_types.size(), op->num_inputs());
      EXPECT_EQ(0, op->num_outputs());
    }

    // Test SupportedDeviceTypesForNode()
    PrioritizedDeviceTypeVector devices;
    TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
    bool found = false;
    for (const auto& dt : devices) {
      if (dt.first == device_type) {
        found = true;
      }
    }
    EXPECT_TRUE(found) << "Missing " << device_type << " from "
                       << devices.size() << " devices.";

    // In case the caller wants to use the OpKernel
    return op;
  }

  // Expects CreateOpKernel() to fail with `code`, and checks that
  // SupportedDeviceTypesForNode() is consistent with that failure: for
  // NOT_FOUND, `device_type` must be absent from the supported list; for
  // any other code, the lookup must fail with the same code.
  void ExpectFailure(const string& op_type, const DeviceType& device_type,
                     const std::vector<string>& attrs, error::Code code) {
    Status status;
    const NodeDef def = CreateNodeDef(op_type, attrs);
    Env* env = Env::Default();
    DeviceBase device(env);

    // Test CreateOpKernel().
    std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
                                                cpu_allocator(), def,
                                                TF_GRAPH_DEF_VERSION, &status));
    EXPECT_TRUE(op == nullptr);
    EXPECT_FALSE(status.ok());
    if (!status.ok()) {
      LOG(INFO) << "Status message: " << status.error_message();
      EXPECT_EQ(code, status.code());

      // Test SupportedDeviceTypesForNode().
      PrioritizedDeviceTypeVector devices;
      if (errors::IsNotFound(status)) {
        TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
        for (const auto& dt : devices) {
          EXPECT_NE(dt.first, device_type);
        }
      } else {
        Status status2 =
            SupportedDeviceTypesForNode(DeviceTypes(), def, &devices);
        EXPECT_EQ(status.code(), status2.code());
      }
    }
  }

  // Resolves the kernel class registered for (op_type, device_type, attrs):
  // returns the class name, "not found" for NOT_FOUND, or the stringified
  // error status for any other failure.
  string GetKernelClassName(const string& op_type,
                            const DeviceType& device_type,
                            const std::vector<string>& attrs,
                            DataTypeSlice input_types = {}) {
    NodeDef def = CreateNodeDef(op_type, attrs);
    for (size_t i = 0; i < input_types.size(); ++i) {
      def.add_input("a:0");
    }

    const KernelDef* kernel_def = nullptr;
    string kernel_class_name;
    const Status status =
        FindKernelDef(device_type, def, &kernel_def, &kernel_class_name);
    if (status.ok()) {
      return kernel_class_name;
    } else if (errors::IsNotFound(status)) {
      return "not found";
    } else {
      return status.ToString();
    }
  }
};
592
// Op registered only on CPU.
REGISTER_OP("BuildCPU");
REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);

// Lookup succeeds on CPU and reports NOT_FOUND on GPU.
TEST_F(OpKernelBuilderTest, BuilderCPU) {
  ExpectSuccess("BuildCPU", DEVICE_CPU, {});
  EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {}));
  ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
  EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {}));
}
602
// Op registered only on GPU — the mirror image of BuildCPU.
REGISTER_OP("BuildGPU");
REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel);

TEST_F(OpKernelBuilderTest, BuilderGPU) {
  ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND);
  ExpectSuccess("BuildGPU", DEVICE_GPU, {});
}
610
// Op registered on both CPU and GPU; lookup succeeds on either device.
REGISTER_OP("BuildBoth");
REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel);

TEST_F(OpKernelBuilderTest, BuilderBoth) {
  ExpectSuccess("BuildBoth", DEVICE_CPU, {});
  ExpectSuccess("BuildBoth", DEVICE_GPU, {});
}
619
// Op with a type attr; the kernel is constrained to T == float.
REGISTER_OP("BuildTypeAttr").Attr("T: type");
REGISTER_KERNEL_BUILDER(
    Name("BuildTypeAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    DummyKernel);

TEST_F(OpKernelBuilderTest, BuilderTypeAttr) {
  ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"});
  // Wrong type for the constraint -> NOT_FOUND.
  ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"},
                error::NOT_FOUND);
  // Missing attr or attr of the wrong kind -> INVALID_ARGUMENT.
  ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
  ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"},
                error::INVALID_ARGUMENT);
}
633
634 REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)");
635 REGISTER_KERNEL_BUILDER(
636 Name("BuildTypeListAttr").Device(DEVICE_CPU).TypeConstraint<bool>("T"),
637 DummyKernel);
638
TEST_F(OpKernelBuilderTest,BuilderTypeListAttr)639 TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
640 ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
641 EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
642 {"T|list(type)|[]"}));
643
644 ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
645 EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
646 {"T|list(type)|[]"}));
647
648 ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
649 {"T|list(type)|[DT_BOOL, DT_BOOL]"});
650
651 ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
652 error::NOT_FOUND);
653 EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
654 {"T|list(type)|[DT_FLOAT]"}));
655
656 ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
657 EXPECT_TRUE(
658 absl::StrContains(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}),
659 "Invalid argument: "));
660
661 ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
662 error::INVALID_ARGUMENT);
663 }
664
// Two identical registrations for the same (op, device): lookup must report
// the ambiguity rather than silently pick one.
REGISTER_OP("DuplicateKernel");
REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
                        DummyKernel);
REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
                        DummyKernel);

TEST_F(OpKernelBuilderTest, DuplicateKernel) {
  const NodeDef ndef = CreateNodeDef("DuplicateKernel", {});
  PrioritizedDeviceTypeVector devs;
  Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
  ASSERT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(
      status.error_message(), "Multiple OpKernel registrations match NodeDef"));

  ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
}
681
// Duplicate registrations that only collide for a particular attr value:
// ambiguous for T=float, NOT_FOUND for T=bool.
REGISTER_OP("DuplicateKernelForT").Attr("T: type");
REGISTER_KERNEL_BUILDER(
    Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    DummyKernel);
REGISTER_KERNEL_BUILDER(
    Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    DummyKernel);

TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
  const NodeDef ndef =
      CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"});
  PrioritizedDeviceTypeVector devs;
  Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
  ASSERT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(
      status.error_message(), "Multiple OpKernel registrations match NodeDef"));

  ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"},
                error::INVALID_ARGUMENT);
  ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"},
                error::NOT_FOUND);
}
704
// A kernel whose type constraint names an attr the op does not declare;
// lookup must fail with a descriptive error.
REGISTER_OP("BadConstraint").Attr("dtype: type");
REGISTER_KERNEL_BUILDER(Name("BadConstraint")
                            .Device(DEVICE_CPU)
                            // Mistake: "T" should be "dtype".
                            .TypeConstraint<float>("T"),
                        DummyKernel);

TEST_F(OpKernelBuilderTest, BadConstraint) {
  const NodeDef ndef = CreateNodeDef("BadConstraint", {});
  PrioritizedDeviceTypeVector devs;
  Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
  ASSERT_FALSE(status.ok());
  EXPECT_TRUE(
      absl::StrContains(status.error_message(),
                        "OpKernel 'BadConstraint' has constraint on attr "
                        "'T' not in NodeDef"));

  ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"},
                error::INVALID_ARGUMENT);
}
725
// Op with a fixed output "a" followed by a variable-length output list "b".
REGISTER_OP("ListOut").Output("a: int32").Output("b: T").Attr("T: list(type)");
REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU),
                        DummyKernel);

// OpKernelContext::output_list(name, ...) must resolve a named output list
// to its element dtypes and fail for unknown names.
TEST_F(OpKernelBuilderTest, OpOutputList) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  DummyDevice device(env);
  params.device = &device;
  Status status;
  std::unique_ptr<OpKernel> op(CreateOpKernel(
      DEVICE_CPU, params.device, cpu_allocator(),
      CreateNodeDef("ListOut", {"T|list(type)|[DT_FLOAT, DT_INT32]"}),
      TF_GRAPH_DEF_VERSION, &status));
  EXPECT_TRUE(status.ok()) << status.ToString();
  params.op_kernel = op.get();
  gtl::InlinedVector<TensorValue, 4> inputs{};
  params.inputs = &inputs;
  auto ctx = absl::make_unique<OpKernelContext>(&params);

  // Output 0 is the scalar "a: int32"; "b" holds [float, int32].
  EXPECT_EQ(DT_INT32, ctx->expected_output_dtype(0));
  OpOutputList out_list;
  EXPECT_FALSE(ctx->output_list("non_existent_output", &out_list).ok());
  ASSERT_TRUE(ctx->output_list("b", &out_list).ok());
  EXPECT_EQ(DT_FLOAT, out_list.expected_output_dtype(0));
  EXPECT_EQ(DT_INT32, out_list.expected_output_dtype(1));
}
753
// Kernel that, at construction, reads the attr named by "attr_name" into
// every supported GetAttr target type and records the per-type Status, so
// tests can verify exactly which conversions succeed for a given attr.
class GetAttrKernel : public ::tensorflow::OpKernel {
 public:
  explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) {
    string attr_name;
    OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name));

    // Attempt every typed GetAttr overload on the same attr; only the
    // overloads matching the attr's declared type should return OK.
    status.emplace_back("s", context->GetAttr(attr_name, &s));
    status.emplace_back("s_list", context->GetAttr(attr_name, &s_list));
    status.emplace_back("i", context->GetAttr(attr_name, &i));
    status.emplace_back("i_list", context->GetAttr(attr_name, &i_list));
    status.emplace_back("i32", context->GetAttr(attr_name, &i32));
    status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list));
    status.emplace_back("f", context->GetAttr(attr_name, &f));
    status.emplace_back("f_list", context->GetAttr(attr_name, &f_list));
    status.emplace_back("b", context->GetAttr(attr_name, &b));
    status.emplace_back("b_list", context->GetAttr(attr_name, &b_list));
    status.emplace_back("type", context->GetAttr(attr_name, &type));
    status.emplace_back("type_list", context->GetAttr(attr_name, &type_list));
    status.emplace_back("type_vector",
                        context->GetAttr(attr_name, &type_vector));
    status.emplace_back("shape_proto",
                        context->GetAttr(attr_name, &shape_proto));
    status.emplace_back("shape_proto_list",
                        context->GetAttr(attr_name, &shape_proto_list));
    status.emplace_back("shape", context->GetAttr(attr_name, &shape));
    status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list));
  }
  void Compute(::tensorflow::OpKernelContext* context) override {}

  // Asserts that exactly the entries named in `keys` have an OK status;
  // every other GetAttr attempt must have failed.
  void ExpectOk(std::initializer_list<string> keys) {
    for (const auto& key_status : status) {
      // Only the status for keys in "keys" should be ok().
      bool in_keys = false;
      for (const string& key : keys) {
        if (key_status.first == key) {
          in_keys = true;
        }
      }
      EXPECT_EQ(in_keys, key_status.second.ok())
          << "key_status: " << key_status.first << ", " << key_status.second;
    }
  }

  // Targets for each GetAttr overload exercised above.
  string s;
  std::vector<string> s_list;
  int64 i;
  std::vector<int64> i_list;
  int32 i32;
  std::vector<int32> i32_list;
  float f;
  std::vector<float> f_list;
  bool b;
  std::vector<bool> b_list;
  DataType type;
  std::vector<DataType> type_list;
  DataTypeVector type_vector;
  TensorShapeProto shape_proto;
  std::vector<TensorShapeProto> shape_proto_list;
  TensorShape shape;
  std::vector<TensorShape> shape_list;
  // (key, GetAttr status) for every attempted conversion, in attempt order.
  std::vector<std::pair<string, Status>> status;
};
816
// Fixture for the GetAttr tests below; inherits the ExpectSuccess /
// ExpectFailure helpers from OpKernelBuilderTest.
class GetAttrTest : public OpKernelBuilderTest {};
818
// Op with a list(string) attr "a" plus the "attr_name" selector attr that
// tells GetAttrKernel which attr to read.
REGISTER_OP("GetAttrStringList")
    .Attr("attr_name: string")
    .Attr("a: list(string)");
REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU),
                        GetAttrKernel);
824
TEST_F(GetAttrTest,StringList)825 TEST_F(GetAttrTest, StringList) {
826 std::unique_ptr<OpKernel> op_kernel =
827 ExpectSuccess("GetAttrStringList", DEVICE_CPU,
828 {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"});
829 auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
830 get_attr_kernel->ExpectOk({"s_list"});
831 EXPECT_EQ(std::vector<string>({"foo", "bar"}), get_attr_kernel->s_list);
832
833 op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU,
834 {"attr_name|string|'b'", "a|list(string)|['baz']"});
835 get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
836 get_attr_kernel->ExpectOk({});
837 EXPECT_TRUE(get_attr_kernel->s_list.empty());
838 }
839
// Op with a scalar int attr "a" and a list(int) attr "b" for the int
// conversion tests (int64/int32, scalar/list, range checking).
REGISTER_OP("GetAttrInt")
    .Attr("attr_name: string")
    .Attr("a: int")
    .Attr("b: list(int)");
REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel);
845
TEST_F(GetAttrTest,Int)846 TEST_F(GetAttrTest, Int) {
847 std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
848 "GetAttrInt", DEVICE_CPU,
849 {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
850 auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
851 get_attr_kernel->ExpectOk({"i", "i32"});
852 EXPECT_EQ(35, get_attr_kernel->i);
853 EXPECT_EQ(35, get_attr_kernel->i32);
854
855 op_kernel = ExpectSuccess(
856 "GetAttrInt", DEVICE_CPU,
857 {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
858 get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
859 get_attr_kernel->ExpectOk({"i_list", "i32_list"});
860 EXPECT_EQ(std::vector<int64>({-1, 2, -4}), get_attr_kernel->i_list);
861 EXPECT_EQ(std::vector<int32>({-1, 2, -4}), get_attr_kernel->i32_list);
862
863 // 8589934592 == 2^33, too big to fit in an int32
864 op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
865 {"attr_name|string|'a'", "a|int|8589934592",
866 "b|list(int)|[-8589934592]"});
867 get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
868 get_attr_kernel->ExpectOk({"i"}); // no i32
869 EXPECT_EQ(8589934592ll, get_attr_kernel->i);
870 for (const auto& key_status : get_attr_kernel->status) {
871 if (key_status.first == "i32") {
872 EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
873 EXPECT_EQ("Attr a has value 8589934592 out of range for an int32",
874 key_status.second.error_message());
875 }
876 }
877
878 op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
879 {"attr_name|string|'b'", "a|int|8589934592",
880 "b|list(int)|[-8589934592]"});
881 get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
882 get_attr_kernel->ExpectOk({"i_list"}); // no i32_list
883 EXPECT_EQ(std::vector<int64>({-8589934592ll}), get_attr_kernel->i_list);
884 for (const auto& key_status : get_attr_kernel->status) {
885 if (key_status.first == "i32_list") {
886 EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
887 EXPECT_EQ("Attr b has value -8589934592 out of range for an int32",
888 key_status.second.error_message());
889 }
890 }
891 }
892
// Op with a scalar shape attr "a" and a list(shape) attr "b" for the
// TensorShape / TensorShapeProto accessor tests.
REGISTER_OP("GetAttrShape")
    .Attr("attr_name: string")
    .Attr("a: shape")
    .Attr("b: list(shape)");
REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel);
898
TEST_F(GetAttrTest,Shape)899 TEST_F(GetAttrTest, Shape) {
900 std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
901 "GetAttrShape", DEVICE_CPU,
902 {"attr_name|string|'a'", "a|shape|{ dim { size: 3 } }",
903 "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
904 auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
905 get_attr_kernel->ExpectOk({"shape", "shape_proto"});
906 EXPECT_EQ(get_attr_kernel->shape_proto.ShortDebugString(), "dim { size: 3 }");
907 EXPECT_EQ("[3]", get_attr_kernel->shape.DebugString());
908
909 op_kernel = ExpectSuccess(
910 "GetAttrShape", DEVICE_CPU,
911 {"attr_name|string|'b'", "a|shape|{ dim { size: 3 } }",
912 "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
913 get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
914 get_attr_kernel->ExpectOk({"shape_list", "shape_proto_list"});
915 ASSERT_EQ(2, get_attr_kernel->shape_proto_list.size());
916 EXPECT_EQ(get_attr_kernel->shape_proto_list[0].ShortDebugString(),
917 "dim { size: 2 }");
918 EXPECT_EQ(get_attr_kernel->shape_proto_list[1].ShortDebugString(),
919 "dim { size: 4 }");
920 ASSERT_EQ(2, get_attr_kernel->shape_list.size());
921 EXPECT_EQ("[2]", get_attr_kernel->shape_list[0].DebugString());
922 EXPECT_EQ("[4]", get_attr_kernel->shape_list[1].DebugString());
923 }
924
// Op with a single type attr for the DataType accessor test.
REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type");
REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel);
927
TEST_F(GetAttrTest,Type)928 TEST_F(GetAttrTest, Type) {
929 std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
930 "GetAttrType", DEVICE_CPU, {"attr_name|string|'a'", "a|type|DT_FLOAT"});
931 auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
932 get_attr_kernel->ExpectOk({"type"});
933 EXPECT_EQ(DT_FLOAT, get_attr_kernel->type);
934 }
935
// Op with a list(type) attr for the type-list accessor tests.
REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)");
REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU),
                        GetAttrKernel);
939
TEST_F(GetAttrTest,TypeList)940 TEST_F(GetAttrTest, TypeList) {
941 std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
942 "GetAttrTypeList", DEVICE_CPU,
943 {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"});
944 auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
945
946 get_attr_kernel->ExpectOk({"type_list", "type_vector"});
947 ASSERT_EQ(2, get_attr_kernel->type_list.size());
948 EXPECT_EQ(DT_INT32, get_attr_kernel->type_list[0]);
949 EXPECT_EQ(DT_BOOL, get_attr_kernel->type_list[1]);
950 ASSERT_EQ(2, get_attr_kernel->type_vector.size());
951 EXPECT_EQ(DT_INT32, get_attr_kernel->type_vector[0]);
952 EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]);
953 }
954
// Minimal abstract kernel used by the label-registration tests below;
// subclasses report which registration produced them via Which().
class BaseKernel : public ::tensorflow::OpKernel {
 public:
  explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(::tensorflow::OpKernelContext* context) override {}
  // Returns the template argument of the LabeledKernel<> subclass.
  virtual int Which() const = 0;
};
961
// Concrete BaseKernel identified by its compile-time template argument, so
// tests can tell which registration was selected.
template <int WHICH>
class LabeledKernel : public BaseKernel {
 public:
  using BaseKernel::BaseKernel;
  int Which() const override { return WHICH; }
};
968
// Fixture for the kernel-label selection tests below.
class LabelTest : public OpKernelBuilderTest {};
970
REGISTER_OP("LabeledKernel");
// One unlabeled registration (the default) and one labeled "one"; the two
// registrations sharing the label "dupe" are deliberate, to test that lookup
// with an ambiguous label fails.
REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU),
                        LabeledKernel<0>);
REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"),
                        LabeledKernel<1>);
REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
                        LabeledKernel<2>);
REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
                        LabeledKernel<3>);
980
TEST_F(LabelTest,Default)981 TEST_F(LabelTest, Default) {
982 std::unique_ptr<OpKernel> op_kernel =
983 ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
984 auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
985 EXPECT_EQ(0, get_labeled_kernel->Which());
986
987 EXPECT_EQ("LabeledKernel<0>",
988 GetKernelClassName("LabeledKernel", DEVICE_CPU, {}));
989 }
990
TEST_F(LabelTest,Specified)991 TEST_F(LabelTest, Specified) {
992 std::unique_ptr<OpKernel> op_kernel =
993 ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
994 auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
995 EXPECT_EQ(1, get_labeled_kernel->Which());
996 EXPECT_EQ("LabeledKernel<1>", GetKernelClassName("LabeledKernel", DEVICE_CPU,
997 {"_kernel|string|'one'"}));
998 }
999
TEST_F(LabelTest,Duplicate)1000 TEST_F(LabelTest, Duplicate) {
1001 ExpectFailure("LabeledKernel", DEVICE_CPU, {"_kernel|string|'dupe'"},
1002 error::INVALID_ARGUMENT);
1003 }
1004
BM_InputRangeHelper(::testing::benchmark::State & state,const NodeDef & node_def,const char * input_name,int expected_start,int expected_stop)1005 void BM_InputRangeHelper(::testing::benchmark::State& state,
1006 const NodeDef& node_def, const char* input_name,
1007 int expected_start, int expected_stop) {
1008 Status status;
1009 auto device = absl::make_unique<DummyDevice>(Env::Default());
1010
1011 std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
1012 cpu_allocator(), node_def,
1013 TF_GRAPH_DEF_VERSION, &status));
1014 TF_CHECK_OK(status);
1015
1016 for (auto s : state) {
1017 int start;
1018 int stop;
1019 TF_CHECK_OK(op->InputRange(input_name, &start, &stop));
1020 EXPECT_EQ(expected_start, start);
1021 EXPECT_EQ(expected_stop, stop);
1022 }
1023 }
1024
// Register no-op kernels for real op names so the benchmarks below can
// instantiate them without pulling in the real kernel implementations.
REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("MatMul").Device(DEVICE_CPU), DummyKernel);
1028
BM_ConcatInputRange(::testing::benchmark::State & state)1029 void BM_ConcatInputRange(::testing::benchmark::State& state) {
1030 // Create a ConcatV2 NodeDef with 4 inputs (plus the axis).
1031 NodeDef node_def;
1032 node_def.set_name("concat-op");
1033 node_def.set_op("ConcatV2");
1034 AttrValue attr_N;
1035 attr_N.set_i(4);
1036 AttrValue attr_T;
1037 attr_T.set_type(DT_FLOAT);
1038 AttrValue attr_Tidx;
1039 attr_Tidx.set_type(DT_INT32);
1040 node_def.mutable_attr()->insert({"N", attr_N});
1041 node_def.mutable_attr()->insert({"T", attr_T});
1042 node_def.mutable_attr()->insert({"Tidx", attr_Tidx});
1043 for (size_t i = 0; i < 5; ++i) {
1044 node_def.add_input(strings::StrCat("a:", i));
1045 }
1046
1047 BM_InputRangeHelper(state, node_def, "values", 0, 4);
1048 }
1049
BM_SelectInputRange(::testing::benchmark::State & state)1050 void BM_SelectInputRange(::testing::benchmark::State& state) {
1051 // Create a Select NodeDef with 3 inputs.
1052 NodeDef node_def;
1053 node_def.set_name("select-op");
1054 node_def.set_op("Select");
1055 AttrValue attr_T;
1056 attr_T.set_type(DT_FLOAT);
1057 node_def.mutable_attr()->insert({"T", attr_T});
1058 for (size_t i = 0; i < 3; ++i) {
1059 node_def.add_input(strings::StrCat("a:", i));
1060 }
1061
1062 BM_InputRangeHelper(state, node_def, "condition", 0, 1);
1063 }
1064
BM_TraceString(::testing::benchmark::State & state)1065 void BM_TraceString(::testing::benchmark::State& state) {
1066 const int verbose = state.range(0);
1067
1068 // Create a MatMul NodeDef with 2 inputs.
1069 NodeDef node_def;
1070 node_def.set_name("gradient_tape/model_1/dense_1/MatMul_1");
1071 node_def.set_op("MatMul");
1072 AttrValue transpose_a, transpose_b, attr_t;
1073 attr_t.set_type(DT_FLOAT);
1074 node_def.mutable_attr()->insert({"T", attr_t});
1075 transpose_a.set_b(true);
1076 node_def.mutable_attr()->insert({"transpose_a", transpose_a});
1077 transpose_b.set_b(true);
1078 node_def.mutable_attr()->insert({"transpose_b", transpose_b});
1079 for (size_t i = 0; i < 2; ++i) {
1080 node_def.add_input(strings::StrCat("a:", i));
1081 }
1082
1083 // Build OpKernel and OpKernelContext
1084 Status status;
1085 auto device = absl::make_unique<DummyDevice>(Env::Default());
1086 std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
1087 cpu_allocator(), node_def,
1088 TF_GRAPH_DEF_VERSION, &status));
1089 TF_CHECK_OK(status);
1090
1091 OpKernelContext::Params params;
1092 params.device = device.get();
1093 params.op_kernel = op.get();
1094 Tensor a(DT_FLOAT, TensorShape({99000, 256}));
1095 Tensor b(DT_FLOAT, TensorShape({256, 256}));
1096 gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&a), TensorValue(&b)};
1097 params.inputs = &inputs;
1098 auto ctx = absl::make_unique<OpKernelContext>(¶ms);
1099
1100 for (auto s : state) {
1101 auto trace = op->TraceString(*ctx, verbose);
1102 }
1103 }
1104
// Register the benchmarks; BM_TraceString runs once verbose (1) and once not.
BENCHMARK(BM_ConcatInputRange);
BENCHMARK(BM_SelectInputRange);
BENCHMARK(BM_TraceString)->Arg(1)->Arg(0);
1108
TEST(RegisteredKernels,CanCallGetAllRegisteredKernels)1109 TEST(RegisteredKernels, CanCallGetAllRegisteredKernels) {
1110 auto kernel_list = GetAllRegisteredKernels();
1111 auto all_registered_kernels = kernel_list.kernel();
1112 auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };
1113
1114 // Verify we can find the "Test1" op registered above
1115 auto test1_it = std::find_if(all_registered_kernels.begin(),
1116 all_registered_kernels.end(), has_name_test1);
1117 ASSERT_NE(test1_it, all_registered_kernels.end());
1118 EXPECT_EQ(test1_it->device_type(), "CPU");
1119
1120 // Verify there was just one kernel
1121 ++test1_it;
1122 EXPECT_EQ(
1123 std::find_if(test1_it, all_registered_kernels.end(), has_name_test1),
1124 all_registered_kernels.end());
1125 }
1126
// Simple test just to check we can call LogAllRegisteredKernels
// (smoke test only: success means it does not crash).
TEST(RegisteredKernels, CanLogAllRegisteredKernels) {
  tensorflow::LogAllRegisteredKernels();
}
1131
TEST(RegisteredKernels,GetFilteredRegisteredKernels)1132 TEST(RegisteredKernels, GetFilteredRegisteredKernels) {
1133 auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };
1134 auto kernel_list = GetFilteredRegisteredKernels(has_name_test1);
1135 ASSERT_EQ(kernel_list.kernel_size(), 1);
1136 EXPECT_EQ(kernel_list.kernel(0).op(), "Test1");
1137 EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
1138 }
1139
TEST(RegisteredKernels,GetRegisteredKernelsForOp)1140 TEST(RegisteredKernels, GetRegisteredKernelsForOp) {
1141 auto kernel_list = GetRegisteredKernelsForOp("Test1");
1142 ASSERT_EQ(kernel_list.kernel_size(), 1);
1143 EXPECT_EQ(kernel_list.kernel(0).op(), "Test1");
1144 EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
1145 }
1146
// EXTRACT_KERNEL_NAME_TO_STRING wraps TF_EXTRACT_KERNEL_NAME for testing
// (it involves quite a bit of macro-magic).
// The _IMPL macro discards the rewritten kernel_builder expression and any
// trailing arguments, keeping only the extracted name string.
#define EXTRACT_KERNEL_NAME_TO_STRING_IMPL(name, kernel_builder, ...) name
#define EXTRACT_KERNEL_NAME_TO_STRING(kernel_builder) \
  TF_EXTRACT_KERNEL_NAME(EXTRACT_KERNEL_NAME_TO_STRING_IMPL, kernel_builder)
1152
TEST(RegisterKernelMacro,ExtractName)1153 TEST(RegisterKernelMacro, ExtractName) {
1154 static constexpr char const* kName = "Foo";
1155 static constexpr char const* kExtractedName =
1156 EXTRACT_KERNEL_NAME_TO_STRING(Name(kName).Label("Label"));
1157 EXPECT_THAT(kExtractedName, ::testing::StrEq(kName));
1158 }
1159
1160 } // namespace
1161 } // namespace tensorflow
1162