/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

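// Computes the result of a Slice operator applied to a constant input array
// of the given data type, writing the sliced values into output_array's
// buffer. Returns true if the slice was computed.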
template <ArrayDataType Type>
bool Slice(SliceOperator const& op, Array const& input_array,
           Array* output_array) {
  // Implementation is taken from the tflite kernel.

  CHECK(input_array.data_type == Type);
  CHECK(output_array->data_type == Type);
  const auto& input_data = input_array.GetBuffer<Type>().data;

  // Create a buffer for the output array.
  std::vector<DataType<Type>>& output_data =
      output_array->GetMutableBuffer<Type>().data;
  output_data.resize(RequiredBufferSizeForShape(output_array->shape()));

  std::vector<int> size = op.size;
  if (size.size() != op.begin.size()) {
    // A single size value is broadcast across all dimensions.
    CHECK_EQ(op.size.size(), 1);
    int broadcast_size = size[0];
    while (size.size() < op.begin.size()) size.push_back(broadcast_size);
  }

  // Calculate begin and end indices along each dimension.
  CHECK_LE(op.begin.size(), 4);
  CHECK_LE(size.size(), 4);
  std::vector<int> begin = op.begin;
  std::vector<int> end;
  for (size_t i = 0; i < begin.size(); ++i) {
    int dim_size = size[i];
    if (dim_size == -1) {
      // -1 means the rest of the dimension.
      dim_size = input_array.shape().dims()[i] - begin[i];
    }
    CHECK_GE(dim_size, 1);
    // The computed end index is inclusive.
    end.push_back(begin[i] + dim_size - 1);
  }

  // Pad out so that we always have 4 dims; this makes the copy loop below
  // simpler.
  while (begin.size() < 4) begin.insert(begin.begin(), 0);
  while (end.size() < 4) end.insert(end.begin(), 0);
  Shape padded_shape = input_array.shape();
  while (padded_shape.dimensions_count() < 4) {
    padded_shape.mutable_dims()->insert(padded_shape.mutable_dims()->begin(),
                                        1);
  }

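  // Copy the selected region out of the padded input. The begin/end indices
  // are inclusive, so each loop uses <=. Values are written contiguously,
  // matching the row-major layout of the output shape.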
  auto* out_ptr = output_data.data();
  for (int in_b = begin[0]; in_b <= end[0]; ++in_b) {
    for (int in_h = begin[1]; in_h <= end[1]; ++in_h) {
      for (int in_w = begin[2]; in_w <= end[2]; ++in_w) {
        for (int in_d = begin[3]; in_d <= end[3]; ++in_d) {
          *out_ptr++ =
              input_data[Offset(padded_shape, {in_b, in_h, in_w, in_d})];
        }
      }
    }
  }

  return true;
}

}  // namespace

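// Replaces a Slice operator that has a constant input with the precomputed
// constant output, then removes the operator and any arrays that are no
// longer referenced.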
::tensorflow::Status ResolveConstantSlice::Run(Model* model,
                                               std::size_t op_index,
                                               bool* modified) {
  *modified = false;
  const auto it = model->operators.begin() + op_index;
  const auto* base_op = it->get();
  if (base_op->type != OperatorType::kSlice) {
    return ::tensorflow::OkStatus();
  }

  const SliceOperator* op = static_cast<const SliceOperator*>(base_op);

  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.data_type == ArrayDataType::kNone) {
    // Yield until the output type has been set by PropagateArrayDataTypes.
    return ::tensorflow::OkStatus();
  }

  if (!output_array.has_shape()) {
    // Yield until the output shape has been set by PropagateFixedShapes.
    return ::tensorflow::OkStatus();
  }

  if (op->begin.empty() || op->size.empty()) {
    // Attributes have not been resolved yet.
    return ::tensorflow::OkStatus();
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until the input shape has been resolved.
    return ::tensorflow::OkStatus();
  }
  if (!IsConstantParameterArray(*model, op->inputs[0])) {
    // Yield until the input value is constant.
    return ::tensorflow::OkStatus();
  }

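  // The output must not already hold a constant buffer. Dispatch on the
  // output data type to the matching typed Slice implementation.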
  CHECK(!output_array.buffer);
  switch (output_array.data_type) {
    case ArrayDataType::kFloat:
      if (!Slice<ArrayDataType::kFloat>(*op, input_array, &output_array)) {
        return ::tensorflow::OkStatus();
      }
      break;
    case ArrayDataType::kUint8:
      if (!Slice<ArrayDataType::kUint8>(*op, input_array, &output_array)) {
        return ::tensorflow::OkStatus();
      }
      break;
    case ArrayDataType::kInt32:
      if (!Slice<ArrayDataType::kInt32>(*op, input_array, &output_array)) {
        return ::tensorflow::OkStatus();
      }
      break;
    case ArrayDataType::kInt64:
      if (!Slice<ArrayDataType::kInt64>(*op, input_array, &output_array)) {
        return ::tensorflow::OkStatus();
      }
      break;
    case ArrayDataType::kComplex64:
      if (!Slice<ArrayDataType::kComplex64>(*op, input_array, &output_array)) {
        return ::tensorflow::OkStatus();
      }
      break;
    default:
      LOG(FATAL) << "Unsupported data type input to Slice op with output \""
                 << op->outputs[0] << "\"";
      break;
  }

  DeleteOpAndArrays(model, op);
  *modified = true;
  return ::tensorflow::OkStatus();
}

}  // namespace toco