1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cmath>
17 #include <cstdint>
18 #include <limits>
19 #include <random>
20
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24
25 namespace tflite {
26 namespace ops {
27 namespace custom {
28 namespace random_standard_normal {
29
// Per-node state for the RandomStandardNormal kernel.
struct OpData {
  // Random engine shared by every Eval() call on this node. It is
  // constructed with its default seed; no seed input is consumed yet (see the
  // TODO in Prepare/Eval), so each fresh OpData yields the same sequence.
  std::default_random_engine rng{};
};
33
34 // Draws a sample from standard normal distribution.
35 template <typename Float>
RandomStandardNormalSample(std::default_random_engine & rng,Float * output,size_t output_size)36 TfLiteStatus RandomStandardNormalSample(std::default_random_engine& rng,
37 Float* output, size_t output_size) {
38 std::normal_distribution<Float> dist;
39 for (Float* it = output; it != output + output_size; ++it) {
40 *it = dist(rng);
41 }
42 return kTfLiteOk;
43 }
44
RandomStandardNormalSample(TfLiteContext * context,std::default_random_engine & rng,TfLiteTensor * output,size_t output_size)45 TfLiteStatus RandomStandardNormalSample(TfLiteContext* context,
46 std::default_random_engine& rng,
47 TfLiteTensor* output,
48 size_t output_size) {
49 switch (output->type) {
50 case kTfLiteFloat32:
51 TF_LITE_ENSURE_OK(context,
52 RandomStandardNormalSample<float>(
53 rng, GetTensorData<float>(output), output_size));
54 break;
55 case kTfLiteFloat64:
56 TF_LITE_ENSURE_OK(context,
57 RandomStandardNormalSample<double>(
58 rng, GetTensorData<double>(output), output_size));
59 break;
60 default:
61 TF_LITE_KERNEL_LOG(
62 context, "Unsupported output datatype for RandomStandardNormal: %s",
63 TfLiteTypeGetName(output->type));
64 return kTfLiteError;
65 }
66 return kTfLiteOk;
67 }
68
Init(TfLiteContext * context,const char * buffer,size_t length)69 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
70 return new OpData();
71 }
72
Free(TfLiteContext * context,void * buffer)73 void Free(TfLiteContext* context, void* buffer) {
74 delete reinterpret_cast<OpData*>(buffer);
75 }
76
Prepare(TfLiteContext * context,TfLiteNode * node)77 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
78 // TODO(b/169611265): Handle optional seed input.
79 TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 1);
80 TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
81
82 // Input is a shape tensor.
83 const TfLiteTensor* input = tflite::GetInput(context, node, 0);
84 TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 1);
85 // TODO(b/169611265): Support dynamic output tensors.
86 TF_LITE_ENSURE(context, IsConstantTensor(input));
87
88 // TODO(b/169611265): Handle other input data types.
89 TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt32);
90
91 int output_dims = tflite::SizeOfDimension(input, 0);
92 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_dims);
93 for (int i = 0; i < output_dims; i++) {
94 output_shape->data[i] = input->data.i32[i];
95 }
96
97 TfLiteTensor* output = tflite::GetOutput(context, node, 0);
98 // ResizeTensor takes ownership of output_shape.
99 return context->ResizeTensor(context, output, output_shape);
100 }
101
Eval(TfLiteContext * context,TfLiteNode * node)102 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
103 // TODO(b/169611265): Handle optional seed input.
104 OpData* params = reinterpret_cast<OpData*>(node->user_data);
105 TF_LITE_ENSURE(context, params != nullptr);
106
107 TfLiteTensor* output = tflite::GetOutput(context, node, 0);
108 size_t output_size = tflite::NumElements(output);
109
110 TF_LITE_ENSURE_OK(context, RandomStandardNormalSample(context, params->rng,
111 output, output_size));
112
113 return kTfLiteOk;
114 }
115
116 } // namespace random_standard_normal
117
Register_RANDOM_STANDARD_NORMAL()118 TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL() {
119 static TfLiteRegistration r = {
120 random_standard_normal::Init, random_standard_normal::Free,
121 random_standard_normal::Prepare, random_standard_normal::Eval};
122 return &r;
123 }
124
125 } // namespace custom
126 } // namespace ops
127 } // namespace tflite
128