• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include "tensorflow/lite/c/c_api_types.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20 #include "tensorflow/lite/kernels/internal/tensor.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/internal/types.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace where {
29 
// Tensor slot indices. Prepare enforces exactly one input (the condition
// tensor) and exactly one output (the coordinates of its true elements).
constexpr int kInputConditionTensor{0};
constexpr int kOutputTensor{0};
32 
33 template <typename T>
ResizeOutputTensor(TfLiteContext * context,const TfLiteTensor * cond_tensor,TfLiteTensor * output_tensor)34 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
35                                 const TfLiteTensor* cond_tensor,
36                                 TfLiteTensor* output_tensor) {
37   // Output tensor should have shape:
38   // (num_true, cond_rank), where num_true denotes the number of true values
39   // in condition.
40   const RuntimeShape& cond_shape = GetTensorShape(cond_tensor);
41   const int size = cond_shape.FlatSize();
42   const int cond_rank = cond_shape.DimensionsCount();
43   const T* cond_data = GetTensorData<T>(cond_tensor);
44 
45   int true_count = 0;
46   for (int i = 0; i < size; ++i) {
47     if (cond_data[i] != T(0)) {
48       true_count++;
49     }
50   }
51   TfLiteIntArray* output_dims = TfLiteIntArrayCreate(2);
52   output_dims->data[0] = true_count;
53   output_dims->data[1] = cond_rank;
54   return context->ResizeTensor(context, output_tensor, output_dims);
55 }
56 
57 template <typename T>
PrepareOutput(TfLiteContext * context,const TfLiteTensor * cond_tensor,TfLiteTensor * output)58 TfLiteStatus PrepareOutput(TfLiteContext* context,
59                            const TfLiteTensor* cond_tensor,
60                            TfLiteTensor* output) {
61   // As output will be a 2D tensor of indices, use int64 to be consistent with
62   // tensorflow.
63   output->type = kTfLiteInt64;
64 
65   // Exit early if cond is a non-const tensor. Set output tensor to dynamic so
66   // output size can be determined in Eval.
67   if (!IsConstantTensor(cond_tensor)) {
68     SetTensorToDynamic(output);
69     return kTfLiteOk;
70   }
71   return ResizeOutputTensor<T>(context, cond_tensor, output);
72 }
73 
Prepare(TfLiteContext * context,TfLiteNode * node)74 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
75   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
76   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
77 
78   const TfLiteTensor* cond_tensor;
79   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputConditionTensor,
80                                           &cond_tensor));
81   TfLiteTensor* output;
82   TF_LITE_ENSURE_OK(context,
83                     GetOutputSafe(context, node, kOutputTensor, &output));
84 
85   switch (cond_tensor->type) {
86     case kTfLiteBool:
87       return PrepareOutput<bool>(context, cond_tensor, output);
88     case kTfLiteFloat32:
89       return PrepareOutput<float>(context, cond_tensor, output);
90     case kTfLiteInt64:
91       return PrepareOutput<int64_t>(context, cond_tensor, output);
92     case kTfLiteInt32:
93       return PrepareOutput<int32_t>(context, cond_tensor, output);
94     case kTfLiteInt8:
95       return PrepareOutput<int8_t>(context, cond_tensor, output);
96     case kTfLiteUInt8:
97       return PrepareOutput<uint8_t>(context, cond_tensor, output);
98     case kTfLiteUInt32:
99       return PrepareOutput<uint32_t>(context, cond_tensor, output);
100     default:
101       TF_LITE_KERNEL_LOG(context,
102                          "Condition tensor has unsupported type: '%s'.",
103                          TfLiteTypeGetName(cond_tensor->type));
104   }
105   return kTfLiteOk;
106 }
107 
Eval(TfLiteContext * context,TfLiteNode * node)108 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
109   const TfLiteTensor* cond_tensor;
110   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputConditionTensor,
111                                           &cond_tensor));
112   TfLiteTensor* output;
113   TF_LITE_ENSURE_OK(context,
114                     GetOutputSafe(context, node, kOutputTensor, &output));
115 
116   if (IsDynamicTensor(output)) {
117     switch (cond_tensor->type) {
118       case kTfLiteBool:
119         TF_LITE_ENSURE_OK(
120             context, ResizeOutputTensor<bool>(context, cond_tensor, output));
121         break;
122       case kTfLiteFloat32:
123         TF_LITE_ENSURE_OK(
124             context, ResizeOutputTensor<float>(context, cond_tensor, output));
125         break;
126       case kTfLiteInt64:
127         TF_LITE_ENSURE_OK(
128             context, ResizeOutputTensor<int64_t>(context, cond_tensor, output));
129         break;
130       case kTfLiteInt32:
131         TF_LITE_ENSURE_OK(
132             context, ResizeOutputTensor<int32_t>(context, cond_tensor, output));
133         break;
134       case kTfLiteInt8:
135         TF_LITE_ENSURE_OK(
136             context, ResizeOutputTensor<int8_t>(context, cond_tensor, output));
137         break;
138       case kTfLiteUInt8:
139         TF_LITE_ENSURE_OK(
140             context, ResizeOutputTensor<uint8_t>(context, cond_tensor, output));
141         break;
142       case kTfLiteUInt32:
143         TF_LITE_ENSURE_OK(context, ResizeOutputTensor<uint32_t>(
144                                        context, cond_tensor, output));
145         break;
146       default:
147         TF_LITE_KERNEL_LOG(context,
148                            "Condition tensor has unsupported type: '%s'.",
149                            TfLiteTypeGetName(cond_tensor->type));
150     }
151   }
152 
153   TfLiteIntArray* dims = cond_tensor->dims;
154   if (dims->size == 0) {
155     // Scalar tensors are not supported.
156     TF_LITE_KERNEL_LOG(context, "Where op requires condition w/ rank > 0");
157     return kTfLiteError;
158   }
159 
160   switch (cond_tensor->type) {
161     case kTfLiteBool:
162       reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
163                                       GetTensorData<bool>(cond_tensor),
164                                       GetTensorData<int64_t>(output));
165       break;
166     case kTfLiteFloat32:
167       reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
168                                       GetTensorData<float>(cond_tensor),
169                                       GetTensorData<int64_t>(output));
170       break;
171     case kTfLiteInt64:
172       reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
173                                       GetTensorData<int64_t>(cond_tensor),
174                                       GetTensorData<int64_t>(output));
175       break;
176     case kTfLiteInt32:
177       reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
178                                       GetTensorData<int32_t>(cond_tensor),
179                                       GetTensorData<int64_t>(output));
180       break;
181     case kTfLiteInt8:
182       reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
183                                       GetTensorData<int8_t>(cond_tensor),
184                                       GetTensorData<int64_t>(output));
185       break;
186     case kTfLiteUInt8:
187       reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
188                                       GetTensorData<uint8_t>(cond_tensor),
189                                       GetTensorData<int64_t>(output));
190       break;
191     case kTfLiteUInt32:
192       reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
193                                       GetTensorData<uint32_t>(cond_tensor),
194                                       GetTensorData<int64_t>(output));
195       break;
196     default:
197       TF_LITE_KERNEL_LOG(context,
198                          "Condition tensor has unsupported type: '%s'.",
199                          TfLiteTypeGetName(cond_tensor->type));
200   }
201   return kTfLiteOk;
202 }
203 }  // namespace where
204 
Register_WHERE()205 TfLiteRegistration* Register_WHERE() {
206   static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
207                                  where::Prepare, where::Eval};
208   return &r;
209 }
210 
211 }  // namespace builtin
212 }  // namespace ops
213 }  // namespace tflite
214