/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>

#include "tensorflow/lite/delegates/hexagon/builders/op_builder.h"

namespace tflite {
namespace delegates {
namespace hexagon {

// Adds Rsqrt op to the Hexagon graph by constructing
// 1/Sqrt(input).
class RsqrtOpBuilder : public OpBuilder {
 public:
  explicit RsqrtOpBuilder(GraphBuilder* graph_builder, int op_type)
      : OpBuilder(graph_builder, op_type) {}
  TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
                                const TfLiteIntArray* outputs,
                                TfLiteContext* context) override;

  TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
                               TfLiteContext* context) override;

 private:
  void AddNumerator();

  TensorID node_output_;
  TensorID numerator_out_;
  TensorID numerator_min_;
  TensorID numerator_max_;
  // Total number of elements in the input tensor.
  int num_elements_;
};

void RsqrtOpBuilder::AddNumerator() {
  // The numerator is a constant with value 1, added as quantized uint8 data
  // together with float min/max scalars.
  std::vector<uint8_t> numerator;
  // The Hexagon NN Div implementation assumes the output has the same shape
  // as its first input, so it does not broadcast. Therefore the constant
  // numerator is created with the same flattened shape as the denominator.
  numerator.resize(num_elements_);
  int flat_shape[] = {1, 1, 1, num_elements_};
  std::fill(numerator.begin(), numerator.end(), 0);
  float kNumeratorMin = 1.0, kNumeratorMax = 1.0;
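  // Note: with kNumeratorMin == kNumeratorMax == 1.0 the quantization range
  // collapses to the single value 1.0, so the zero-filled uint8 buffer above
  // dequantizes to a tensor of ones.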
  auto* const_node = graph_builder_->AddConstNodeWithData(
      flat_shape, reinterpret_cast<char*>(numerator.data()),
      sizeof(numerator[0]) * numerator.size());
  auto* numerator_min_const = graph_builder_->AddConstNodeWithData(
      kScalarShape, reinterpret_cast<char*>(&kNumeratorMin),
      sizeof(kNumeratorMin));
  auto* numerator_max_const = graph_builder_->AddConstNodeWithData(
      kScalarShape, reinterpret_cast<char*>(&kNumeratorMax),
      sizeof(kNumeratorMax));
  numerator_out_ = TensorID(const_node->GetID(), 0);
  numerator_min_ = TensorID(numerator_min_const->GetID(), 0);
  numerator_max_ = TensorID(numerator_max_const->GetID(), 0);
}

TfLiteStatus RsqrtOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
                                             TfLiteContext* context) {
  graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first,
                                  node_output_.second);
  return kTfLiteOk;
}

TfLiteStatus RsqrtOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
                                              const TfLiteIntArray* outputs,
                                              TfLiteContext* context) {
  const int tensor_id = inputs->data[0];
  const auto& tensor = context->tensors[tensor_id];
  float min_value = 0;
  float max_value = 0;
  int batch_size, height_size, width_size, depth_size;
  GetDims(&batch_size, &height_size, &width_size, &depth_size, tensor.dims);
  TF_LITE_ENSURE_STATUS(
      ComputeMinAndMaxQuantValues(tensor, &min_value, &max_value));
  num_elements_ = batch_size * height_size * width_size * depth_size;
  int flat_shape[] = {1, 1, 1, num_elements_};

  auto* min_const = graph_builder_->AddConstNodeWithData(
      kScalarShape, reinterpret_cast<char*>(&min_value), sizeof(min_value));
  auto* max_const = graph_builder_->AddConstNodeWithData(
      kScalarShape, reinterpret_cast<char*>(&max_value), sizeof(max_value));
  // Create SQRT op as denominator.
  AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
  AddInput(TensorID(min_const->GetID(), 0));
  AddInput(TensorID(max_const->GetID(), 0));
  auto sqrt_output = AddOutput(
      sizeof(uint8_t), 4, {batch_size, height_size, width_size, depth_size});
  auto sqrt_output_min = AddOutput(sizeof(float), 4, kScalarShape);
  auto sqrt_output_max = AddOutput(sizeof(float), 4, kScalarShape);
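  // The Sqrt node's min/max outputs describe the quantization range of the
  // denominator; they are wired into the Div node below.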

  // Reshape the result of Sqrt to [1, 1, 1, NumElements] since the Hexagon
  // Div op has limitations on the shapes of its inputs.
  const int reshape_shape[] = {1, 1, 1, 4};
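  // reshape_shape describes the shape tensor itself: a [1, 1, 1, 4] constant
  // holding the four values of flat_shape.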
  auto* target_shape_node = graph_builder_->AddConstNodeWithData(
      reshape_shape, reinterpret_cast<char*>(flat_shape),
      sizeof(flat_shape[0]) * 4);
  auto* reshape_op = graph_builder_->AddNode(GetTFLiteNodeID());
  reshape_op->SetOpType(OP_Reshape);
  reshape_op->AddInput(sqrt_output);
  reshape_op->AddInput(TensorID(target_shape_node->GetID(), 0));
  auto reshape_out = reshape_op->AddOutput(sizeof(uint8_t), 4, flat_shape);

  // Create the numerator and add it to the graph.
  AddNumerator();

  // Fetch output details.
  float output_min = -1, output_max = 1;
  TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
      context->tensors[outputs->data[0]], &output_min, &output_max));
  auto* output_min_const = graph_builder_->AddConstNodeWithData(
      kScalarShape, reinterpret_cast<char*>(&output_min), sizeof(output_min));
  auto* output_max_const = graph_builder_->AddConstNodeWithData(
      kScalarShape, reinterpret_cast<char*>(&output_max), sizeof(output_max));
  int output_batch_size, output_height_size, output_width_size,
      output_depth_size;
  GetDims(&output_batch_size, &output_height_size, &output_width_size,
          &output_depth_size, context->tensors[outputs->data[0]].dims);

  // Add the Div op to compute 1/Sqrt.
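  // Inputs below follow the order: numerator data, denominator data,
  // numerator min/max, denominator min/max, and the requested output min/max.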
  auto* div_op = graph_builder_->AddNode(GetTFLiteNodeID());
  div_op->SetOpType(OP_QuantizedDiv_8);
  div_op->AddInput(numerator_out_);
  div_op->AddInput(reshape_out);
  div_op->AddInput(numerator_min_);
  div_op->AddInput(numerator_max_);
  div_op->AddInput(sqrt_output_min);
  div_op->AddInput(sqrt_output_max);
  div_op->AddInput(TensorID(output_min_const->GetID(), 0));
  div_op->AddInput(TensorID(output_max_const->GetID(), 0));

  auto div_output = div_op->AddOutput(sizeof(uint8_t), 4, flat_shape);
  div_op->AddOutput(sizeof(float), 4, kScalarShape);
  div_op->AddOutput(sizeof(float), 4, kScalarShape);
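  // The two scalar outputs above are the Div node's output min/max; they are
  // not consumed further here. Only the quantized data is reshaped back below.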

  // Reshape output back to the expected shape.
  int output_shape[] = {output_batch_size, output_height_size,
                        output_width_size, output_depth_size};
  target_shape_node = graph_builder_->AddConstNodeWithData(
      reshape_shape, reinterpret_cast<char*>(output_shape),
      sizeof(output_shape[0]) * 4);

  reshape_op = graph_builder_->AddNode(GetTFLiteNodeID());
  reshape_op->SetOpType(OP_Reshape);
  reshape_op->AddInput(div_output);
  reshape_op->AddInput(TensorID(target_shape_node->GetID(), 0));
  node_output_ = reshape_op->AddOutput(sizeof(uint8_t), 4, output_shape);
  return kTfLiteOk;
}

OpBuilder* CreateRSqrtOpBuilder(GraphBuilder* graph_builder, int op_type) {
  return new RsqrtOpBuilder(graph_builder, op_type);
}

}  // namespace hexagon
}  // namespace delegates
}  // namespace tflite