/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "tensorflow/c/kernels.h"

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <memory>
#include <string>
#include <utility>

#include "absl/container/inlined_vector.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

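// Per-kernel state shared by the C-API callbacks below. MyCreateFunc,
// MyComputeFunc and MyDeleteFunc together implement a kernel registered
// through TF_KernelBuilder and record which callbacks have run.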
struct MyCustomKernel {
  bool created;
  bool compute_called;
};

static bool delete_called = false;

static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
  struct MyCustomKernel* s = new struct MyCustomKernel;
  s->created = true;
  s->compute_called = false;

  // Exercise attribute reads.
  TF_DataType type;
  TF_Status* status = TF_NewStatus();
  TF_OpKernelConstruction_GetAttrType(ctx, "SomeDataTypeAttr", &type, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  EXPECT_EQ(TF_FLOAT, type);
  TF_DeleteStatus(status);

  // Exercise kernel NodeDef name read
  TF_StringView name_string_view = TF_OpKernelConstruction_GetName(ctx);
  std::string node_name = "SomeNodeName";
  std::string candidate_node_name =
      std::string(name_string_view.data, name_string_view.len);
  EXPECT_EQ(node_name, candidate_node_name);
  return s;
}

static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) {
  struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
  s->compute_called = true;
  if (ctx != nullptr) {
    EXPECT_EQ(43, TF_StepId(ctx));
  }
}

static void MyDeleteFunc(void* kernel) {
  struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
  EXPECT_TRUE(s->created);
  EXPECT_TRUE(s->compute_called);
  delete_called = true;
  delete s;
}

namespace tensorflow {

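// Builds a NodeDef with two inputs and a DT_FLOAT "SomeDataTypeAttr"
// attribute for the given op, then instantiates the registered kernel for
// `device_name` via CreateOpKernel.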
static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
                                               const char* op_name,
                                               const char* node_name,
                                               Status* status) {
  NodeDef def;
  def.set_op(op_name);
  def.set_name(node_name);
  def.set_device(device_name);
  def.add_input("input1");
  def.add_input("input2");

  AttrValue v;
  v.set_type(DataType::DT_FLOAT);
  (*def.mutable_attr())["SomeDataTypeAttr"] = v;

  return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1,
                        status);
}

// Tests registration of a single C kernel and checks that calls through the
// C/C++ boundary are being made.
TEST(TestKernel, TestRegisterKernelBuilder) {
  const char* node_name = "SomeNodeName";
  const char* op_name = "FooOp";
  const char* device_name = "FakeDeviceName1";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Attr("SomeDataTypeAttr: type");

  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);

  {
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    KernelList list;
    list.ParseFromArray(buf->data, buf->length);
    ASSERT_EQ(1, list.kernel_size());
    ASSERT_EQ(device_name, list.kernel(0).device_type());
    TF_DeleteBuffer(buf);
    TF_DeleteStatus(status);
  }

  {
    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());
    kernel->Compute(nullptr);
  }

  ASSERT_TRUE(delete_called);
}

// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases.
// Registers two ops, each with a single attribute called 'Attr'.
// The attribute in one op will have a type 'type', the other
// will have list(type).
#define ATTR_TEST_REGISTER_OP(name, type)                     \
  REGISTER_OP("TestKernelAttr" #name)                         \
      .Attr("Attr: " #type)                                   \
      .SetShapeFn(tensorflow::shape_inference::UnknownShape); \
  REGISTER_OP("TestKernelAttr" #name "List")                  \
      .Attr("Attr: list(" #type ")")                          \
      .SetShapeFn(tensorflow::shape_inference::UnknownShape)
ATTR_TEST_REGISTER_OP(String, string);
ATTR_TEST_REGISTER_OP(Int, int);
ATTR_TEST_REGISTER_OP(Float, float);
ATTR_TEST_REGISTER_OP(Bool, bool);
ATTR_TEST_REGISTER_OP(Type, type);
#undef ATTR_TEST_REGISTER_OP

// Helper macros for the TF_OpKernelConstruction_GetAttr* tests.
#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \
  do {                                                                     \
    int32_t list_size, total_size;                                         \
    TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size,        \
                                        &total_size, status);              \
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);            \
    EXPECT_EQ(expected_list_size, list_size);                              \
    EXPECT_EQ(expected_total_size, total_size);                            \
  } while (0)

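// Signature of the create functions used by the TestKernelAttr fixture below.
// The fixture registers a kernel whose create function reads the attribute
// under test, then builds and runs the kernel so that the create, compute and
// delete callbacks all execute.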
typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*);
class TestKernelAttr : public ::testing::Test {
 public:
  TestKernelAttr() {}
  ~TestKernelAttr() override {}

  std::unique_ptr<OpKernel> GetFakeKernelWithAttr(const char* op_name,
                                                  AttrValue v, Status* status) {
    NodeDef def;
    def.set_op(op_name);
    def.set_name("FakeNode");
    def.set_device("FakeDevice");
    (*def.mutable_attr())["Attr"] = v;
    return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1,
                          status);
  }

  void CreateAndCallKernelWithAttr(MyCreateFuncWithAttr MyCreateFuncAttr,
                                   const char* op_name, AttrValue& v) {
    TF_KernelBuilder* builder = TF_NewKernelBuilder(
        op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc);
    {
      TF_Status* status = TF_NewStatus();
      TF_RegisterKernelBuilder("FakeNode", builder, status);
      EXPECT_EQ(TF_OK, TF_GetCode(status));
      TF_DeleteStatus(status);
    }
    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernelWithAttr(op_name, v, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());
    kernel->Compute(nullptr);

    ASSERT_TRUE(delete_called);
  }
};

TEST_F(TestKernelAttr, String) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    std::unique_ptr<char[]> val(new char[5]);
    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ 5);
    TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(),
                                          /*max_length*/ 5, status);

    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ("bunny", string(static_cast<const char*>(val.get()), 5));
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  v.set_s("bunny");
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrString", v);
}

TEST_F(TestKernelAttr, StringList) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    std::vector<string> list = {"bugs", "bunny", "duck"};
    int list_total_size = 0;
    for (const auto& s : list) {
      list_total_size += s.size();
    }

    TF_Status* status = TF_NewStatus();
    std::unique_ptr<char*[]> values(new char*[list.size()]);
    std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
    std::unique_ptr<char[]> storage(new char[list_total_size]);
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(),
                   /*expected_total_size*/ list_total_size);
    TF_OpKernelConstruction_GetAttrStringList(
        ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(),
        list_total_size, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    for (size_t i = 0; i < list.size(); ++i) {
      EXPECT_EQ(list[i].size(), lens[i]) << i;
      EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
          << i;
    }
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  std::string attr_in[] = {"bugs", "bunny", "duck"};
  SetAttrValue(gtl::ArraySlice<std::string>(attr_in, 3), &v);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrStringList", v);
}

TEST_F(TestKernelAttr, Int) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    int64_t val;
    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ(1234, val);
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  v.set_i(1234);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrInt", v);
}

TEST_F(TestKernelAttr, IntList) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    const int64_t list[] = {1, 2, 3, 4};
    const size_t list_size = TF_ARRAYSIZE(list);
    int64_t values[list_size];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size,
                                             status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_TRUE(
        std::equal(std::begin(list), std::end(list), std::begin(values)));
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  int64 attr_in[] = {1, 2, 3, 4};
  SetAttrValue(gtl::ArraySlice<int64>(attr_in, 4), &v);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrIntList", v);
}

TEST_F(TestKernelAttr, Float) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    float val;
    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_FLOAT_EQ(2.718, val);
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  v.set_f(2.718);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrFloat", v);
}

TEST_F(TestKernelAttr, FloatList) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    const float list[] = {1.414, 2.718, 3.1415};
    const size_t list_size = TF_ARRAYSIZE(list);
    float values[list_size];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size,
                                             status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_TRUE(
        std::equal(std::begin(list), std::end(list), std::begin(values)));
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  float attr_in[] = {1.414, 2.718, 3.1415};
  SetAttrValue(gtl::ArraySlice<float>(attr_in, 3), &v);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrFloatList", v);
}

TEST_F(TestKernelAttr, Bool) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    unsigned char val;
    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ(1, val);
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  v.set_b(true);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrBool", v);
}

TEST_F(TestKernelAttr, BoolList) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    const unsigned char list[] = {1, 0, 1, 0};
    const size_t list_size = TF_ARRAYSIZE(list);
    unsigned char values[list_size];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size,
                                            status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_TRUE(
        std::equal(std::begin(list), std::end(list), std::begin(values)));
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  bool attr_in[] = {true, false, true, false};
  SetAttrValue(gtl::ArraySlice<bool>(attr_in, 4), &v);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrBoolList", v);
}

TEST_F(TestKernelAttr, Type) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    TF_DataType val;
    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ(TF_FLOAT, val);
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  v.set_type(DT_FLOAT);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrType", v);
}

TEST_F(TestKernelAttr, TypeList) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
    const size_t list_size = TF_ARRAYSIZE(list);
    TF_DataType values[list_size];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size,
                                            status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_TRUE(
        std::equal(std::begin(list), std::end(list), std::begin(values)));
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  DataType attr_in[] = {DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128};
  SetAttrValue(gtl::ArraySlice<DataType>(attr_in, 4), &v);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrTypeList", v);
}
#undef EXPECT_TF_SIZE

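// Minimal device stub that exposes the process CPU allocator; used to run
// kernels without a real device.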
class DummyDevice : public DeviceBase {
 public:
  explicit DummyDevice(Env* env) : DeviceBase(env) {}
  Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
    return cpu_allocator();
  }
};

TEST(TestKernel, TestInputAndOutputCount) {
  const char* node_name = "InputOutputCounterKernel";
  const char* op_name = "BarOp";
  const char* device_name = "FakeDeviceName2";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Attr("SomeDataTypeAttr: type");

  static int num_inputs = 0;
  static int num_outputs = 0;

  // A kernel whose Compute function has a side-effect of updating num_inputs
  // and num_outputs. Various functions on TF_OpKernelContext are also
  // exercised.
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    num_inputs = TF_NumInputs(ctx);
    num_outputs = TF_NumOutputs(ctx);

    TF_Tensor* input = nullptr;
    TF_Status* s = TF_NewStatus();
    TF_GetInput(ctx, 0, &input, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s);
    EXPECT_EQ(123, *static_cast<tensorflow::uint8*>(TF_TensorData(input)));
    TF_GetInput(ctx, -1, &input, s);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
    TF_GetInput(ctx, 3, &input, s);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));

    // Copy the input tensor to output.
    TF_SetOutput(ctx, 0, input, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));

    TF_SetOutput(ctx, 24, input, s);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));

    EXPECT_EQ(TF_UINT8, TF_ExpectedOutputDataType(ctx, 0));

    TF_DeleteStatus(s);
    if (input != nullptr) {
      TF_DeleteTensor(input);
    }
  };

  TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
                                                  my_compute_func, nullptr);

  {
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);
  }

  {
    OpKernelContext::Params p;
    DummyDevice dummy_device(nullptr);
    p.device = &dummy_device;
    p.step_id = 43;

    Tensor t(tensorflow::uint8(123));

    gtl::InlinedVector<TensorValue, 4> inputs;
    // Simulate 2 inputs
    inputs.emplace_back(&t);
    inputs.emplace_back();
    p.inputs = &inputs;

    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());

    p.op_kernel = kernel.get();
    OpKernelContext ctx(&p);
    kernel->Compute(&ctx);

    ASSERT_EQ(2, num_inputs);
    ASSERT_EQ(1, num_outputs);
    ASSERT_EQ(123, ctx.mutable_output(0)->scalar<tensorflow::uint8>()());
  }
}

TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) {
  TF_DeleteKernelBuilder(nullptr);
}

TEST(TestKernel, TestTypeConstraint) {
  const char* node_name = "SomeNodeName";
  const char* op_name = "TypeOp";
  const char* device_name = "FakeDeviceName1";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Attr("T: type");

  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);
  TF_Status* status = TF_NewStatus();
  TF_KernelBuilder_TypeConstraint(builder, "T", TF_DataType::TF_INT32, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  TF_RegisterKernelBuilder(node_name, builder, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  KernelList list;
  list.ParseFromArray(buf->data, buf->length);
  const auto expected_str = R"str(kernel {
  op: "TypeOp"
  device_type: "FakeDeviceName1"
  constraint {
    name: "T"
    allowed_values {
      list {
        type: DT_INT32
      }
    }
  }
}
)str";
  ASSERT_EQ(expected_str, list.DebugString());

  TF_DeleteBuffer(buf);
  TF_DeleteStatus(status);
  TF_DeleteKernelBuilder(builder);
  ASSERT_TRUE(delete_called);
}

TEST(TestKernel, TestHostMemory) {
  const char* node_name = "SomeNodeName";
  const char* op_name = "HostMemoryOp";
  const char* device_name = "FakeDeviceName1";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Attr("T: type");

  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);
  TF_KernelBuilder_HostMemory(builder, "input2");
  TF_KernelBuilder_HostMemory(builder, "output1");
  TF_Status* status = TF_NewStatus();
  TF_RegisterKernelBuilder(node_name, builder, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  KernelList list;
  list.ParseFromArray(buf->data, buf->length);
  const auto expected_str = R"str(kernel {
  op: "HostMemoryOp"
  device_type: "FakeDeviceName1"
  host_memory_arg: "input2"
  host_memory_arg: "output1"
}
)str";
  ASSERT_EQ(expected_str, list.DebugString());

  TF_DeleteBuffer(buf);
  TF_DeleteStatus(status);
  TF_DeleteKernelBuilder(builder);
  ASSERT_TRUE(delete_called);
}

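// Fixture for running C kernels through OpsTestBase. SetupOp registers the
// kernel for the configured device (GPU when built with CUDA or ROCm, CPU
// otherwise) and initializes the op.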
class DeviceKernelOpTest : public OpsTestBase {
 protected:
  void SetupOp(const char* op_name, const char* node_name,
               void (*compute_func)(void*, TF_OpKernelContext*)) {
    TF_KernelBuilder* builder = TF_NewKernelBuilder(
        op_name, device_name_, nullptr, compute_func, nullptr);
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    std::unique_ptr<Device> device(
        DeviceFactory::NewDevice(device_name_, {}, "/job:a/replica:0/task:0"));
    OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
#endif
    TF_ASSERT_OK(NodeDefBuilder(op_name, op_name).Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  const char* device_name_ = tensorflow::DEVICE_GPU;
#else
  const char* device_name_ = tensorflow::DEVICE_CPU;
#endif
};

// Validates that the tensor has shape and type corresponding to
// dims and dtype.
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
                     TF_DataType dtype);

// Copies data of length tensor_size_bytes from values to tensor.
template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
                     TF_OpKernelContext* ctx);

REGISTER_OP("StreamOp").Output("output1: float");

TEST_F(DeviceKernelOpTest, TestStream) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    SP_Stream stream = TF_GetStream(ctx, s);
    // The stream is always null if the device is not a pluggable device. More
    // test cases will be added when the pluggable device mechanism is
    // supported.
    EXPECT_EQ(stream, nullptr);
    EXPECT_NE(TF_OK, TF_GetCode(s));
    TF_DeleteStatus(s);
  };

  SetupOp("StreamOp", "StreamOp", my_compute_func);
  TF_ASSERT_OK(RunOpKernel());
}

REGISTER_OP("AllocateOutputOp1").Output("output1: float");

TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    // Allocate output
    TF_Status* s = TF_NewStatus();
    int64_t dim = 1;
    size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
    TF_Tensor* output = TF_AllocateOutput(
        /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
        /*num_dims=*/1, /*len=*/tensor_size_bytes, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, &dim, 1, TF_FLOAT);

    // Set output to 3
    float values[1] = {3.0f};
    set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateOutputOp1", "AllocateOutput1", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
            output->DebugString(100));
}

REGISTER_OP("AllocateOutputOp0").Output("output1: float");

TEST_F(DeviceKernelOpTest, TestAllocateEmptyOutput) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    // Allocate empty output
    int64_t dim = 0;
    TF_Tensor* output = TF_AllocateOutput(
        /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
        /*num_dims=*/1, /*len=*/0, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, &dim, 1, TF_FLOAT);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateOutputOp0", "AllocateOutput0", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [0] values: >",
            output->DebugString(100));
}

REGISTER_OP("AllocateOutputOp2x3").Output("output1: float");

TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    // Allocate 2x3 output
    int64_t dim[2] = {2, 3};
    size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT) * 6;
    TF_Tensor* output = TF_AllocateOutput(
        /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim,
        /*num_dims=*/2, /*len=*/tensor_size_bytes, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, dim, 2, TF_FLOAT);

    // Set output to [1 2 3 4 5 6]
    float values[6] = {1, 2, 3, 4, 5, 6};
    set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateOutputOp2x3", "AllocateOutput2x3", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
            output->DebugString(100));
}

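// The TF_AllocateTemp tests below allocate a temporary tensor with allocator
// attributes chosen for the build configuration (device memory for GPU
// builds, host memory otherwise), fill it, and publish it with TF_SetOutput.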
REGISTER_OP("AllocateTempOp1").Output("output1: float");

TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    // Allocate scalar TF_Tensor
    TF_Status* s = TF_NewStatus();
    int64_t dim = 1;
    TF_AllocatorAttributes alloc_attrs;
    alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    alloc_attrs.on_host = 0;
#else
    alloc_attrs.on_host = 1;
#endif
    TF_Tensor* output = TF_AllocateTemp(
        /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
        /*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
    size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, &dim, 1, TF_FLOAT);

    // Set TF_Tensor value to 3
    float values[1] = {3.0f};
    set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
    TF_SetOutput(ctx, 0, output, s);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateTempOp1", "AllocateTemp1", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
            output->DebugString(100));
}

REGISTER_OP("AllocateTempOp0").Output("output1: float");

TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    // Allocate empty TF_Tensor
    int64_t dim = 0;
    TF_AllocatorAttributes alloc_attrs;
    alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    alloc_attrs.on_host = 0;
#else
    alloc_attrs.on_host = 1;
#endif
    TF_Tensor* output = TF_AllocateTemp(
        /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
        /*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, &dim, 1, TF_FLOAT);
    TF_SetOutput(ctx, 0, output, s);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateTempOp0", "AllocateTemp0", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [0] values: >",
            output->DebugString(100));
}

REGISTER_OP("AllocateTempOp2x3").Output("output1: float");

TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
    // Allocate 2x3 TF_Tensor
    int64_t dim[2] = {2, 3};
    TF_AllocatorAttributes alloc_attrs;
    alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    alloc_attrs.on_host = 0;
#else
    alloc_attrs.on_host = 1;
#endif
    TF_Tensor* output = TF_AllocateTemp(
        /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/dim,
        /*num_dims=*/2, /*allocator_attributes*/ &alloc_attrs, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, dim, 2, TF_FLOAT);

    // Set TF_Tensor values to [1 2 3 4 5 6]
    float values[6] = {1, 2, 3, 4, 5, 6};
    set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
    TF_SetOutput(ctx, 0, output, s);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateTempOp2x3", "AllocateTempOp2x3", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
            output->DebugString(100));
}

TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
  const char* node_name = "TestForwardInputOrAllocateOutputKernel";
  const char* op_name = "BazOp";
  const char* device_name = "FakeDeviceName";

  REGISTER_OP(op_name)
      .Input("input1: float")
      .Input("input2: float")
      .Output("output1: float")
      .Attr("SomeDataTypeAttr: type");

  // A kernel whose Compute function forwards a scalar input to the output.
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    int candidate_input_indices[1] = {0};
    int forwarded_input;
    int64_t output_dims[1] = {};
    TF_Tensor* output = TF_ForwardInputOrAllocateOutput(
        /*context=*/ctx, candidate_input_indices,
        /*num_candidate_input_indices=*/1,
        /*output_index=*/0, output_dims, /*output_num_dims=*/0,
        &forwarded_input, /*status=*/s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    EXPECT_EQ(forwarded_input, 0);
    EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
    EXPECT_EQ(0, TF_NumDims(output));
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
                                                  my_compute_func, nullptr);

  {
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);
  }

  {
    OpKernelContext::Params p;
    DummyDevice dummy_device(nullptr);
    p.device = &dummy_device;
    AllocatorAttributes alloc_attrs;
    p.output_attr_array = &alloc_attrs;

    Tensor t(123.0f);

    gtl::InlinedVector<TensorValue, 4> inputs;
    // GetFakeKernel requires a NodeDef with two inputs
    inputs.emplace_back(&t);
    inputs.emplace_back();
    p.inputs = &inputs;

    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());

    p.op_kernel = kernel.get();
    OpKernelContext ctx(&p);
    kernel->Compute(&ctx);
    ASSERT_EQ(123, ctx.mutable_output(0)->scalar<float>()());
  }
}

void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
                     TF_DataType dtype) {
  EXPECT_EQ(dtype, TF_TensorType(tensor));
  EXPECT_EQ(num_dims, TF_NumDims(tensor));
  for (int i = 0; i < num_dims; ++i) {
    EXPECT_EQ(dims[i], TF_Dim(tensor, i));
  }
}

template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
                     TF_OpKernelContext* ctx) {
  T* data = reinterpret_cast<T*>(TF_TensorData(tensor));
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
  cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values,
                                                tensor_size_bytes);
#else
  memcpy(data, values, tensor_size_bytes);
#endif
}
}  // namespace tensorflow