1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/python/testdata/test_registerer.h"
16
17 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
18 #include "tensorflow/lite/kernels/kernel_util.h"
19
20 namespace tflite {
21
22 namespace {
23
// Number of TF_TestRegisterer invocations since the counter was last read.
// The enclosing anonymous namespace already gives this internal linkage.
int num_test_registerer_calls = 0;
25
GetFakeRegistration()26 TfLiteRegistration* GetFakeRegistration() {
27 static TfLiteRegistration fake_op;
28 return &fake_op;
29 }
30
31 namespace double_op {
32
// Index of the op's only input tensor within the node.
constexpr int kInputTensor{0};
// Index of the op's only output tensor within the node.
constexpr int kOutputTensor{0};
35
Prepare(TfLiteContext * context,TfLiteNode * node)36 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
37 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
38 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
39
40 const TfLiteTensor* input;
41 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
42 TfLiteTensor* output;
43 TF_LITE_ENSURE_OK(context,
44 GetOutputSafe(context, node, kOutputTensor, &output));
45
46 TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
47 TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
48
49 return context->ResizeTensor(context, output, output_shape);
50 }
51
Eval(TfLiteContext * context,TfLiteNode * node)52 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
53 const TfLiteTensor* input;
54 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
55 TfLiteTensor* output;
56 TF_LITE_ENSURE_OK(context,
57 GetOutputSafe(context, node, kOutputTensor, &output));
58
59 TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
60
61 const size_t size = GetTensorShape(input).FlatSize();
62
63 if (input->type == kTfLiteFloat32) {
64 const float* input_ptr = input->data.f;
65 float* output_ptr = output->data.f;
66 for (int i = 0; i < size; ++i) {
67 output_ptr[i] = input_ptr[i] + input_ptr[i];
68 }
69 } else if (input->type == kTfLiteInt32) {
70 const int32_t* input_ptr = input->data.i32;
71 int32_t* output_ptr = output->data.i32;
72 for (int i = 0; i < size; ++i) {
73 output_ptr[i] = input_ptr[i] + input_ptr[i];
74 }
75 } else {
76 return kTfLiteError;
77 }
78 return kTfLiteOk;
79 }
80
81 } // namespace double_op
82
GetDoubleRegistration()83 TfLiteRegistration* GetDoubleRegistration() {
84 static TfLiteRegistration double_op = {nullptr, nullptr, double_op::Prepare,
85 double_op::Eval};
86 return &double_op;
87 }
88 } // namespace
89
90 // Dummy registerer function with the correct signature. Registers a fake custom
91 // op needed by test models. Increments the num_test_registerer_calls counter by
92 // one. The TF_ prefix is needed to get past the version script in the OSS
93 // build.
TF_TestRegisterer(tflite::MutableOpResolver * resolver)94 extern "C" void TF_TestRegisterer(tflite::MutableOpResolver *resolver) {
95 resolver->AddCustom("FakeOp", GetFakeRegistration());
96 resolver->AddCustom("Double", GetDoubleRegistration());
97 num_test_registerer_calls++;
98 }
99
100 // Returns the num_test_registerer_calls counter and re-sets it.
get_num_test_registerer_calls()101 int get_num_test_registerer_calls() {
102 const int result = num_test_registerer_calls;
103 num_test_registerer_calls = 0;
104 return result;
105 }
106
107 } // namespace tflite
108