1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <initializer_list>
#include <numeric>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
23
24 namespace tflite {
25 namespace {
26 template <typename T>
27 tflite::TensorType GetTTEnum();
28
29 template <>
GetTTEnum()30 tflite::TensorType GetTTEnum<float>() {
31 return tflite::TensorType_FLOAT32;
32 }
33
34 template <>
GetTTEnum()35 tflite::TensorType GetTTEnum<double>() {
36 return tflite::TensorType_FLOAT64;
37 }
38
39 template <>
GetTTEnum()40 tflite::TensorType GetTTEnum<int8_t>() {
41 return tflite::TensorType_INT8;
42 }
43
44 template <>
GetTTEnum()45 tflite::TensorType GetTTEnum<int32_t>() {
46 return tflite::TensorType_INT32;
47 }
48
49 template <>
GetTTEnum()50 tflite::TensorType GetTTEnum<int64_t>() {
51 return tflite::TensorType_INT64;
52 }
53
54 class RandomUniformOpModel : public tflite::SingleOpModel {
55 public:
RandomUniformOpModel(const std::initializer_list<int> & input,tflite::TensorData output,bool dynamic_input)56 RandomUniformOpModel(const std::initializer_list<int>& input,
57 tflite::TensorData output, bool dynamic_input) {
58 if (dynamic_input) {
59 input_ = AddInput({tflite::TensorType_INT32, {3}});
60 } else {
61 input_ = AddConstInput(tflite::TensorType_INT32, input,
62 {static_cast<int>(input.size())});
63 }
64 output_ = AddOutput(output);
65 SetCustomOp("RandomUniform", {}, ops::custom::Register_RANDOM_UNIFORM);
66 BuildInterpreter({GetShape(input_)});
67 if (dynamic_input) {
68 PopulateTensor<int32_t>(input_, std::vector<int32_t>(input));
69 }
70 }
71
72 int input_;
73 int output_;
74
input()75 int input() { return input_; }
output()76 int output() { return output_; }
77
78 template <typename T>
GetOutput()79 std::vector<T> GetOutput() {
80 return ExtractVector<T>(output_);
81 }
82 };
83
84 class RandomUniformIntOpModel : public tflite::SingleOpModel {
85 public:
RandomUniformIntOpModel(const std::initializer_list<int> & input,tflite::TensorData output,int min_val,int max_val)86 RandomUniformIntOpModel(const std::initializer_list<int>& input,
87 tflite::TensorData output, int min_val, int max_val) {
88 input_ = AddConstInput(tflite::TensorType_INT32, input,
89 {static_cast<int>(input.size())});
90 input_minval_ = AddConstInput(tflite::TensorType_INT32, {min_val}, {1});
91 input_maxval_ = AddConstInput(tflite::TensorType_INT32, {max_val}, {1});
92 output_ = AddOutput(output);
93 SetCustomOp("RandomUniformInt", {},
94 ops::custom::Register_RANDOM_UNIFORM_INT);
95 BuildInterpreter({GetShape(input_)});
96 }
97
98 int input_;
99 int input_minval_;
100 int input_maxval_;
101
102 int output_;
103
input()104 int input() { return input_; }
output()105 int output() { return output_; }
106
107 template <typename T>
GetOutput()108 std::vector<T> GetOutput() {
109 return ExtractVector<T>(output_);
110 }
111 };
112
113 } // namespace
114 } // namespace tflite
115
116 template <typename FloatType>
117 class RandomUniformTest : public ::testing::Test {
118 public:
119 using Float = FloatType;
120 };
121
122 using TestTypes = ::testing::Types<float, double>;
123
124 TYPED_TEST_SUITE(RandomUniformTest, TestTypes);
125
// Checks that RandomUniform returns the requested number of samples and that
// they look like draws from U(0, 1): every sample lies in [0, 1], the sample
// mean is close to 0.5, and the sample variance is close to 1/12. Both ways
// of supplying the input tensor (dynamic and constant) are exercised.
TYPED_TEST(RandomUniformTest, TestOutput) {
  using Float = typename TestFixture::Float;
  for (const bool dynamic : {true, false}) {
    // Tag any failure below with the input mode that produced it.
    SCOPED_TRACE(dynamic ? "dynamic shape input" : "constant shape input");
    tflite::RandomUniformOpModel m({1000, 50, 5},
                                   {tflite::GetTTEnum<Float>(), {}}, dynamic);
    m.Invoke();
    const auto output = m.GetOutput<Float>();
    // Compare as size_t to avoid a signed/unsigned comparison.
    EXPECT_EQ(output.size(), std::size_t{1000 * 50 * 5});

    // The mean/variance checks alone would accept values outside [0, 1].
    // The upper bound is inclusive so the test does not depend on whether
    // the generator can round up to exactly 1.0.
    for (const auto r : output) {
      ASSERT_GE(r, Float{0});
      ASSERT_LE(r, Float{1});
    }

    // Accumulate in double so float samples do not lose precision.
    const double avg =
        std::accumulate(output.begin(), output.end(), 0.0) / output.size();
    ASSERT_LT(std::abs(avg - 0.5), 0.05);  // Mean should be about 0.5.

    double sum_squared = 0;
    for (const auto r : output) {
      const double diff = r - avg;
      sum_squared += diff * diff;
    }
    const double var = sum_squared / output.size();
    EXPECT_LT(std::abs(1. / 12 - var),
              0.05);  // Variance should be about 1/12.
  }
}
151
152 template <typename IntType>
153 class RandomUniformIntTest : public ::testing::Test {
154 public:
155 using Int = IntType;
156 };
157
158 using TestTypesInt = ::testing::Types<int8_t, int32_t, int64_t>;
159
160 TYPED_TEST_SUITE(RandomUniformIntTest, TestTypesInt);
161
// Checks that RandomUniformInt returns the requested number of samples, that
// every sample lies in [kMinVal, kMaxVal] (both ends inclusive), and that each
// value in that range occurs with roughly the same frequency.
TYPED_TEST(RandomUniformIntTest, TestOutput) {
  using Int = typename TestFixture::Int;
  // Bounds passed to the op. The checks below treat both as inclusive, giving
  // kNumBins distinct possible values.
  constexpr int kMinVal = 0;
  constexpr int kMaxVal = 5;
  constexpr int kNumBins = kMaxVal - kMinVal + 1;

  tflite::RandomUniformIntOpModel m({1000, 50, 5},
                                    {tflite::GetTTEnum<Int>(), {}}, kMinVal,
                                    kMaxVal);
  m.Invoke();
  const auto output = m.GetOutput<Int>();
  // Compare as size_t to avoid a signed/unsigned comparison.
  EXPECT_EQ(output.size(), std::size_t{1000 * 50 * 5});

  std::array<int, kNumBins> counters{};
  for (const auto r : output) {
    ASSERT_GE(r, kMinVal);
    ASSERT_LE(r, kMaxVal);
    // Safe to index: the assertions above guarantee 0 <= r - kMinVal < kNumBins.
    ++counters[static_cast<std::size_t>(r - kMinVal)];
  }
  // Every value in the range should occur with about the same frequency, so
  // compare each bin against the first one.
  for (int i = 1; i < kNumBins; ++i) {
    EXPECT_LT(std::abs(counters[i] - counters[0]), 1000);
  }
}
181