• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <vector>
16 
17 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
18 #include "tensorflow/lite/toco/model.h"
19 #include "tensorflow/lite/toco/tooling_util.h"
20 #include "tensorflow/core/platform/logging.h"
21 
22 namespace toco {
23 
24 namespace {
25 
26 template <ArrayDataType Type>
Slice(SliceOperator const & op,Array const & input_array,Array * output_array)27 bool Slice(SliceOperator const& op, Array const& input_array,
28            Array* output_array) {
29   // Implementation is taken from the tflite kernel.
30 
31   CHECK(input_array.data_type == Type);
32   CHECK(output_array->data_type == Type);
33   const auto& input_data = input_array.GetBuffer<Type>().data;
34 
35   // Create a buffer for the output array.
36   std::vector<DataType<Type>>& output_data =
37       output_array->GetMutableBuffer<Type>().data;
38   output_data.resize(RequiredBufferSizeForShape(output_array->shape()));
39 
40   std::vector<int> size = op.size;
41   if (size.size() != op.begin.size()) {
42     // Broadcast the end positions.
43     CHECK_EQ(op.size.size(), 1);
44     int broadcast_size = size[0];
45     while (size.size() < op.begin.size()) size.push_back(broadcast_size);
46   }
47 
48   // Calculate begin and end indices along each dimension.
49   CHECK_LE(op.begin.size(), 4);
50   CHECK_LE(size.size(), 4);
51   std::vector<int> begin = op.begin;
52   std::vector<int> end;
53   for (size_t i = 0; i < begin.size(); ++i) {
54     int dim_size = size[i];
55     if (dim_size == -1) {
56       // -1 means the rest of the dimension.
57       dim_size = input_array.shape().dims()[i] - begin[i];
58     }
59     CHECK_GE(dim_size, 1);
60     end.push_back(begin[i] + dim_size - 1);
61   }
62 
63   // Pad out so that we always have 4 dims, makes this loop easier.
64   while (begin.size() < 4) begin.insert(begin.begin(), 0);
65   while (end.size() < 4) end.insert(end.begin(), 0);
66   Shape padded_shape = input_array.shape();
67   while (padded_shape.dimensions_count() < 4) {
68     padded_shape.mutable_dims()->insert(padded_shape.mutable_dims()->begin(),
69                                         1);
70   }
71 
72   auto* out_ptr = output_data.data();
73   for (int in_b = begin[0]; in_b <= end[0]; ++in_b) {
74     for (int in_h = begin[1]; in_h <= end[1]; ++in_h) {
75       for (int in_w = begin[2]; in_w <= end[2]; ++in_w) {
76         for (int in_d = begin[3]; in_d <= end[3]; ++in_d) {
77           *out_ptr++ =
78               input_data[Offset(padded_shape, {in_b, in_h, in_w, in_d})];
79         }
80       }
81     }
82   }
83 
84   return true;
85 }
86 
87 }  // namespace
88 
Run(Model * model,std::size_t op_index,bool * modified)89 ::tensorflow::Status ResolveConstantSlice::Run(Model* model,
90                                                std::size_t op_index,
91                                                bool* modified) {
92   *modified = false;
93   const auto it = model->operators.begin() + op_index;
94   const auto* base_op = it->get();
95   if (base_op->type != OperatorType::kSlice) {
96     return ::tensorflow::OkStatus();
97   }
98 
99   const SliceOperator* op = static_cast<const SliceOperator*>(base_op);
100 
101   CHECK_EQ(op->outputs.size(), 1);
102   auto& output_array = model->GetArray(op->outputs[0]);
103   if (output_array.data_type == ArrayDataType::kNone) {
104     // Yield until the output type has been set by PropagateArrayDataTypes.
105     return ::tensorflow::OkStatus();
106   }
107 
108   if (!output_array.has_shape()) {
109     // Yield until the output shape has been set by PropagateFixedShapes.
110     return ::tensorflow::OkStatus();
111   }
112 
113   if (op->begin.empty() || op->size.empty()) {
114     // Attributes have not resolved yet.
115     return ::tensorflow::OkStatus();
116   }
117 
118   const auto& input_array = model->GetArray(op->inputs[0]);
119   if (!input_array.has_shape()) {
120     // Yield until the value shape has been resolved.
121     return ::tensorflow::OkStatus();
122   }
123   if (!IsConstantParameterArray(*model, op->inputs[0])) {
124     // Yield until the value is constant.
125     return ::tensorflow::OkStatus();
126   }
127 
128   CHECK(!output_array.buffer);
129   switch (output_array.data_type) {
130     case ArrayDataType::kFloat:
131       if (!Slice<ArrayDataType::kFloat>(*op, input_array, &output_array)) {
132         return ::tensorflow::OkStatus();
133       }
134       break;
135     case ArrayDataType::kUint8:
136       if (!Slice<ArrayDataType::kUint8>(*op, input_array, &output_array)) {
137         return ::tensorflow::OkStatus();
138       }
139       break;
140     case ArrayDataType::kInt32:
141       if (!Slice<ArrayDataType::kInt32>(*op, input_array, &output_array)) {
142         return ::tensorflow::OkStatus();
143       }
144       break;
145     case ArrayDataType::kInt64:
146       if (!Slice<ArrayDataType::kInt64>(*op, input_array, &output_array)) {
147         return ::tensorflow::OkStatus();
148       }
149       break;
150     case ArrayDataType::kComplex64:
151       if (!Slice<ArrayDataType::kComplex64>(*op, input_array, &output_array)) {
152         return ::tensorflow::OkStatus();
153       }
154       break;
155     default:
156       LOG(FATAL) << "Unsupported data type input to Slice op with output \""
157                  << op->outputs[0] << "\"";
158       break;
159   }
160 
161   DeleteOpAndArrays(model, op);
162   *modified = true;
163   return ::tensorflow::OkStatus();
164 }
165 
166 }  // namespace toco
167