• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
23 
24 namespace tflite {
25 namespace {
26 template <typename T>
27 tflite::TensorType GetTTEnum();
28 
29 template <>
GetTTEnum()30 tflite::TensorType GetTTEnum<float>() {
31   return tflite::TensorType_FLOAT32;
32 }
33 
34 template <>
GetTTEnum()35 tflite::TensorType GetTTEnum<double>() {
36   return tflite::TensorType_FLOAT64;
37 }
38 
39 template <>
GetTTEnum()40 tflite::TensorType GetTTEnum<int8_t>() {
41   return tflite::TensorType_INT8;
42 }
43 
44 template <>
GetTTEnum()45 tflite::TensorType GetTTEnum<int32_t>() {
46   return tflite::TensorType_INT32;
47 }
48 
49 template <>
GetTTEnum()50 tflite::TensorType GetTTEnum<int64_t>() {
51   return tflite::TensorType_INT64;
52 }
53 
54 class RandomUniformOpModel : public tflite::SingleOpModel {
55  public:
RandomUniformOpModel(const std::initializer_list<int> & input,tflite::TensorData output,bool dynamic_input)56   RandomUniformOpModel(const std::initializer_list<int>& input,
57                        tflite::TensorData output, bool dynamic_input) {
58     if (dynamic_input) {
59       input_ = AddInput({tflite::TensorType_INT32, {3}});
60     } else {
61       input_ = AddConstInput(tflite::TensorType_INT32, input,
62                              {static_cast<int>(input.size())});
63     }
64     output_ = AddOutput(output);
65     SetCustomOp("RandomUniform", {}, ops::custom::Register_RANDOM_UNIFORM);
66     BuildInterpreter({GetShape(input_)});
67     if (dynamic_input) {
68       PopulateTensor<int32_t>(input_, std::vector<int32_t>(input));
69     }
70   }
71 
72   int input_;
73   int output_;
74 
input()75   int input() { return input_; }
output()76   int output() { return output_; }
77 
78   template <typename T>
GetOutput()79   std::vector<T> GetOutput() {
80     return ExtractVector<T>(output_);
81   }
82 };
83 
84 class RandomUniformIntOpModel : public tflite::SingleOpModel {
85  public:
RandomUniformIntOpModel(const std::initializer_list<int> & input,tflite::TensorData output,int min_val,int max_val)86   RandomUniformIntOpModel(const std::initializer_list<int>& input,
87                           tflite::TensorData output, int min_val, int max_val) {
88     input_ = AddConstInput(tflite::TensorType_INT32, input,
89                            {static_cast<int>(input.size())});
90     input_minval_ = AddConstInput(tflite::TensorType_INT32, {min_val}, {1});
91     input_maxval_ = AddConstInput(tflite::TensorType_INT32, {max_val}, {1});
92     output_ = AddOutput(output);
93     SetCustomOp("RandomUniformInt", {},
94                 ops::custom::Register_RANDOM_UNIFORM_INT);
95     BuildInterpreter({GetShape(input_)});
96   }
97 
98   int input_;
99   int input_minval_;
100   int input_maxval_;
101 
102   int output_;
103 
input()104   int input() { return input_; }
output()105   int output() { return output_; }
106 
107   template <typename T>
GetOutput()108   std::vector<T> GetOutput() {
109     return ExtractVector<T>(output_);
110   }
111 };
112 
113 }  // namespace
114 }  // namespace tflite
115 
116 template <typename FloatType>
117 class RandomUniformTest : public ::testing::Test {
118  public:
119   using Float = FloatType;
120 };
121 
122 using TestTypes = ::testing::Types<float, double>;
123 
124 TYPED_TEST_SUITE(RandomUniformTest, TestTypes);
125 
// Draws 1000*50*5 samples from RandomUniform twice: once with a dynamic
// (runtime-populated) shape input and once with a constant one. Checks that
// every sample lies in [0, 1] and that the sample mean and variance are
// close to 1/2 and 1/12, the moments of a uniform distribution on [0, 1].
TYPED_TEST(RandomUniformTest, TestOutput) {
  using Float = typename TestFixture::Float;
  constexpr size_t kExpectedSize = 1000 * 50 * 5;
  for (const bool dynamic : {true, false}) {
    // Make failures identify which of the two iterations produced them.
    SCOPED_TRACE(dynamic ? "dynamic shape input" : "constant shape input");
    tflite::RandomUniformOpModel m({1000, 50, 5},
                                   {tflite::GetTTEnum<Float>(), {}}, dynamic);
    m.Invoke();
    const auto output = m.GetOutput<Float>();
    // The statistics below are meaningless for a wrongly sized output.
    ASSERT_EQ(output.size(), kExpectedSize);

    double sum = 0;
    for (const auto r : output) {
      // Mean/variance checks alone cannot catch stray out-of-range samples.
      ASSERT_GE(r, Float{0});
      ASSERT_LE(r, Float{1});
      sum += r;
    }
    const double avg = sum / output.size();
    // The average should be approximately 0.5.
    ASSERT_LT(std::abs(avg - 0.5), 0.05);

    double sum_squared = 0;
    for (const auto r : output) {
      const double diff = r - avg;
      sum_squared += diff * diff;
    }
    const double var = sum_squared / output.size();
    // The variance should be approximately 1/12.
    EXPECT_LT(std::abs(1. / 12 - var), 0.05);
  }
}
151 
152 template <typename IntType>
153 class RandomUniformIntTest : public ::testing::Test {
154  public:
155   using Int = IntType;
156 };
157 
158 using TestTypesInt = ::testing::Types<int8_t, int32_t, int64_t>;
159 
160 TYPED_TEST_SUITE(RandomUniformIntTest, TestTypesInt);
161 
// Draws 1000*50*5 integers from RandomUniformInt with bounds kMinVal and
// kMaxVal. Checks that every sample lies within the bounds, both inclusive,
// and that every value in that range occurs about as often as kMinVal.
TYPED_TEST(RandomUniformIntTest, TestOutput) {
  using Int = typename TestFixture::Int;
  constexpr int kMinVal = 0;
  constexpr int kMaxVal = 5;
  constexpr size_t kExpectedSize = 1000 * 50 * 5;
  tflite::RandomUniformIntOpModel m({1000, 50, 5},
                                    {tflite::GetTTEnum<Int>(), {}}, kMinVal,
                                    kMaxVal);
  m.Invoke();
  const auto output = m.GetOutput<Int>();
  ASSERT_EQ(output.size(), kExpectedSize);

  // One histogram bucket per value in [kMinVal, kMaxVal]; the size is derived
  // from the bounds so the two cannot drift apart.
  std::array<int, kMaxVal - kMinVal + 1> counters{};
  for (const auto r : output) {
    // Widen first so int8_t samples print as numbers, not chars, on failure.
    const int64_t value = static_cast<int64_t>(r);
    ASSERT_GE(value, kMinVal);
    ASSERT_LE(value, kMaxVal);
    ++counters[static_cast<size_t>(value - kMinVal)];
  }
  // Check that all values occur with nearly the same frequency.
  for (size_t i = 1; i < counters.size(); ++i) {
    EXPECT_LT(std::abs(counters[i] - counters[0]), 1000);
  }
}
181