1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/coreml/builders/reshape_op_builder.h"
16
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/delegates/coreml/builders/op_builder.h"
20 #include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
21 #include "tensorflow/lite/delegates/coreml/builders/op_validator.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24
25 namespace tflite {
26 namespace delegates {
27 namespace coreml {
28
DebugName()29 const std::string& ReshapeOpBuilder::DebugName() {
30 if (debug_name_.empty()) {
31 SetDebugName("ReshapeOpBuilder", node_id_);
32 }
33 return debug_name_;
34 }
35
Build()36 CoreML::Specification::NeuralNetworkLayer* ReshapeOpBuilder::Build() {
37 if (layer_ == nullptr) {
38 layer_.reset(new CoreML::Specification::NeuralNetworkLayer);
39 }
40 layer_->set_name(DebugName());
41 for (int dim : shape_) {
42 layer_->mutable_reshape()->add_targetshape(dim);
43 }
44 if (need_transpose_)
45 layer_->mutable_reshape()->set_mode(
46 CoreML::Specification::ReshapeLayerParams::CHANNEL_LAST);
47 return layer_.release();
48 }
49
SetShapeFromTensor(const TfLiteTensor * output_shape,const TfLiteIntArray * input_shape)50 void ReshapeOpBuilder::SetShapeFromTensor(const TfLiteTensor* output_shape,
51 const TfLiteIntArray* input_shape) {
52 TfLiteIntArray* shape = TfLiteIntArrayCreate(output_shape->dims->data[0]);
53 std::memcpy(shape->data, GetTensorData<int>(output_shape),
54 shape->size * sizeof(int));
55
56 SetShapeFromIntArray(shape, input_shape);
57 TfLiteIntArrayFree(shape);
58 }
59
SetShapeFromIntArray(const TfLiteIntArray * output_shape,const TfLiteIntArray * input_shape)60 void ReshapeOpBuilder::SetShapeFromIntArray(const TfLiteIntArray* output_shape,
61 const TfLiteIntArray* input_shape) {
62 // ignore first dimension (batch)
63 std::copy(output_shape->data + 1, output_shape->data + output_shape->size,
64 std::back_inserter(shape_));
65
66 int64_t reshape_size = 1;
67 int negative_index = -1;
68 for (int i = 0; i < shape_.size(); ++i) {
69 if (shape_[i] == -1) {
70 negative_index = i;
71 } else {
72 reshape_size *= shape_[i];
73 }
74 }
75 if (negative_index >= 0) {
76 int64_t input_size = NumElements(input_shape);
77 shape_[negative_index] = input_size / reshape_size;
78 }
79
80 if (shape_.size() == 2) {
81 shape_ = {shape_[1], 1, shape_[0]};
82 } else if (shape_.size() == 3) {
83 shape_ = {shape_[2], shape_[0], shape_[1]};
84 }
85 // When channel dimension is changed, reshape should be done with HWC layout.
86 if (shape_[0] != input_shape->data[input_shape->size - 1]) {
87 need_transpose_ = true;
88 }
89 }
90
RegisterInputs(const TfLiteIntArray * inputs,TfLiteContext * context)91 TfLiteStatus ReshapeOpBuilder::RegisterInputs(const TfLiteIntArray* inputs,
92 TfLiteContext* context) {
93 AddInput(inputs->data[0]);
94
95 if (inputs->size == 2) {
96 SetShapeFromTensor(&context->tensors[inputs->data[1]],
97 context->tensors[inputs->data[0]].dims);
98 } else {
99 const auto* params = reinterpret_cast<TfLiteReshapeParams*>(builtin_data_);
100 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions);
101 std::memcpy(output_shape->data, params->shape,
102 params->num_dimensions * sizeof(int));
103
104 SetShapeFromIntArray(output_shape, context->tensors[inputs->data[0]].dims);
105 TfLiteIntArrayFree(output_shape);
106 }
107 return kTfLiteOk;
108 }
109
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)110 TfLiteStatus ReshapeOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
111 TfLiteContext* context) {
112 graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context));
113 return kTfLiteOk;
114 }
115
IsReshapeOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context,int coreml_version)116 bool IsReshapeOpSupported(const TfLiteRegistration* registration,
117 const TfLiteNode* node, TfLiteContext* context,
118 int coreml_version) {
119 if (coreml_version >= 3) {
120 return false;
121 }
122 if (node->inputs->size == 1) {
123 const auto* params =
124 reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
125 return params->num_dimensions == 3 || params->num_dimensions == 4;
126 }
127
128 const int kShapeTensor = 1;
129 const TfLiteTensor* shape;
130 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShapeTensor, &shape));
131 if (shape->allocation_type != kTfLiteMmapRo) {
132 TF_LITE_KERNEL_LOG(context, "Reshape has non-const shape.");
133 return false;
134 }
135 const bool is_shape_tensor =
136 shape->dims->size == 1 && shape->type == kTfLiteInt32;
137 return is_shape_tensor &&
138 (shape->dims->data[0] == 3 || shape->dims->data[0] == 4);
139 }
140
CreateReshapeOpBuilder(GraphBuilder * graph_builder)141 OpBuilder* CreateReshapeOpBuilder(GraphBuilder* graph_builder) {
142 return new ReshapeOpBuilder(graph_builder);
143 }
144
145 } // namespace coreml
146 } // namespace delegates
147 } // namespace tflite
148