• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 #include <cstdint>
18 #include <limits>
19 #include <random>
20 
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 
25 namespace tflite {
26 namespace ops {
27 namespace custom {
28 namespace random_standard_normal {
29 
// Per-node kernel state: created by Init(), destroyed by Free(), and read by
// Eval() through node->user_data.
//
// The engine is default-constructed, so it starts from the engine's default
// seed. Its state advances with every sample drawn and persists across Eval()
// calls for the same node. No seed input is honoured yet (see the
// TODO(b/169611265) notes in Prepare/Eval).
struct OpData {
  std::default_random_engine rng{};
};
33 
34 // Draws a sample from standard normal distribution.
35 template <typename Float>
RandomStandardNormalSample(std::default_random_engine & rng,Float * output,size_t output_size)36 TfLiteStatus RandomStandardNormalSample(std::default_random_engine& rng,
37                                         Float* output, size_t output_size) {
38   std::normal_distribution<Float> dist;
39   for (Float* it = output; it != output + output_size; ++it) {
40     *it = dist(rng);
41   }
42   return kTfLiteOk;
43 }
44 
RandomStandardNormalSample(TfLiteContext * context,std::default_random_engine & rng,TfLiteTensor * output,size_t output_size)45 TfLiteStatus RandomStandardNormalSample(TfLiteContext* context,
46                                         std::default_random_engine& rng,
47                                         TfLiteTensor* output,
48                                         size_t output_size) {
49   switch (output->type) {
50     case kTfLiteFloat32:
51       TF_LITE_ENSURE_OK(context,
52                         RandomStandardNormalSample<float>(
53                             rng, GetTensorData<float>(output), output_size));
54       break;
55     case kTfLiteFloat64:
56       TF_LITE_ENSURE_OK(context,
57                         RandomStandardNormalSample<double>(
58                             rng, GetTensorData<double>(output), output_size));
59       break;
60     default:
61       TF_LITE_KERNEL_LOG(
62           context, "Unsupported output datatype for RandomStandardNormal: %s",
63           TfLiteTypeGetName(output->type));
64       return kTfLiteError;
65   }
66   return kTfLiteOk;
67 }
68 
Init(TfLiteContext * context,const char * buffer,size_t length)69 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
70   return new OpData();
71 }
72 
Free(TfLiteContext * context,void * buffer)73 void Free(TfLiteContext* context, void* buffer) {
74   delete reinterpret_cast<OpData*>(buffer);
75 }
76 
Prepare(TfLiteContext * context,TfLiteNode * node)77 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
78   // TODO(b/169611265): Handle optional seed input.
79   TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 1);
80   TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
81 
82   // Input is a shape tensor.
83   const TfLiteTensor* input = tflite::GetInput(context, node, 0);
84   TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 1);
85   // TODO(b/169611265): Support dynamic output tensors.
86   TF_LITE_ENSURE(context, IsConstantTensor(input));
87 
88   // TODO(b/169611265): Handle other input data types.
89   TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt32);
90 
91   int output_dims = tflite::SizeOfDimension(input, 0);
92   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_dims);
93   for (int i = 0; i < output_dims; i++) {
94     output_shape->data[i] = input->data.i32[i];
95   }
96 
97   TfLiteTensor* output = tflite::GetOutput(context, node, 0);
98   // ResizeTensor takes ownership of output_shape.
99   return context->ResizeTensor(context, output, output_shape);
100 }
101 
Eval(TfLiteContext * context,TfLiteNode * node)102 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
103   // TODO(b/169611265): Handle optional seed input.
104   OpData* params = reinterpret_cast<OpData*>(node->user_data);
105   TF_LITE_ENSURE(context, params != nullptr);
106 
107   TfLiteTensor* output = tflite::GetOutput(context, node, 0);
108   size_t output_size = tflite::NumElements(output);
109 
110   TF_LITE_ENSURE_OK(context, RandomStandardNormalSample(context, params->rng,
111                                                         output, output_size));
112 
113   return kTfLiteOk;
114 }
115 
116 }  // namespace random_standard_normal
117 
Register_RANDOM_STANDARD_NORMAL()118 TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL() {
119   static TfLiteRegistration r = {
120       random_standard_normal::Init, random_standard_normal::Free,
121       random_standard_normal::Prepare, random_standard_normal::Eval};
122   return &r;
123 }
124 
125 }  // namespace custom
126 }  // namespace ops
127 }  // namespace tflite
128