• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/coreml/builders/reshape_op_builder.h"
16 
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/delegates/coreml/builders/op_builder.h"
20 #include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
21 #include "tensorflow/lite/delegates/coreml/builders/op_validator.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 
25 namespace tflite {
26 namespace delegates {
27 namespace coreml {
28 
DebugName()29 const std::string& ReshapeOpBuilder::DebugName() {
30   if (debug_name_.empty()) {
31     SetDebugName("ReshapeOpBuilder", node_id_);
32   }
33   return debug_name_;
34 }
35 
Build()36 CoreML::Specification::NeuralNetworkLayer* ReshapeOpBuilder::Build() {
37   if (layer_ == nullptr) {
38     layer_.reset(new CoreML::Specification::NeuralNetworkLayer);
39   }
40   layer_->set_name(DebugName());
41   for (int dim : shape_) {
42     layer_->mutable_reshape()->add_targetshape(dim);
43   }
44   if (need_transpose_)
45     layer_->mutable_reshape()->set_mode(
46         CoreML::Specification::ReshapeLayerParams::CHANNEL_LAST);
47   return layer_.release();
48 }
49 
SetShapeFromTensor(const TfLiteTensor * output_shape,const TfLiteIntArray * input_shape)50 void ReshapeOpBuilder::SetShapeFromTensor(const TfLiteTensor* output_shape,
51                                           const TfLiteIntArray* input_shape) {
52   TfLiteIntArray* shape = TfLiteIntArrayCreate(output_shape->dims->data[0]);
53   std::memcpy(shape->data, GetTensorData<int>(output_shape),
54               shape->size * sizeof(int));
55 
56   SetShapeFromIntArray(shape, input_shape);
57   TfLiteIntArrayFree(shape);
58 }
59 
SetShapeFromIntArray(const TfLiteIntArray * output_shape,const TfLiteIntArray * input_shape)60 void ReshapeOpBuilder::SetShapeFromIntArray(const TfLiteIntArray* output_shape,
61                                             const TfLiteIntArray* input_shape) {
62   // ignore first dimension (batch)
63   std::copy(output_shape->data + 1, output_shape->data + output_shape->size,
64             std::back_inserter(shape_));
65 
66   int64_t reshape_size = 1;
67   int negative_index = -1;
68   for (int i = 0; i < shape_.size(); ++i) {
69     if (shape_[i] == -1) {
70       negative_index = i;
71     } else {
72       reshape_size *= shape_[i];
73     }
74   }
75   if (negative_index >= 0) {
76     int64_t input_size = NumElements(input_shape);
77     shape_[negative_index] = input_size / reshape_size;
78   }
79 
80   if (shape_.size() == 2) {
81     shape_ = {shape_[1], 1, shape_[0]};
82   } else if (shape_.size() == 3) {
83     shape_ = {shape_[2], shape_[0], shape_[1]};
84   }
85   // When channel dimension is changed, reshape should be done with HWC layout.
86   if (shape_[0] != input_shape->data[input_shape->size - 1]) {
87     need_transpose_ = true;
88   }
89 }
90 
RegisterInputs(const TfLiteIntArray * inputs,TfLiteContext * context)91 TfLiteStatus ReshapeOpBuilder::RegisterInputs(const TfLiteIntArray* inputs,
92                                               TfLiteContext* context) {
93   AddInput(inputs->data[0]);
94 
95   if (inputs->size == 2) {
96     SetShapeFromTensor(&context->tensors[inputs->data[1]],
97                        context->tensors[inputs->data[0]].dims);
98   } else {
99     const auto* params = reinterpret_cast<TfLiteReshapeParams*>(builtin_data_);
100     TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions);
101     std::memcpy(output_shape->data, params->shape,
102                 params->num_dimensions * sizeof(int));
103 
104     SetShapeFromIntArray(output_shape, context->tensors[inputs->data[0]].dims);
105     TfLiteIntArrayFree(output_shape);
106   }
107   return kTfLiteOk;
108 }
109 
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)110 TfLiteStatus ReshapeOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
111                                                TfLiteContext* context) {
112   graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context));
113   return kTfLiteOk;
114 }
115 
IsReshapeOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context,int coreml_version)116 bool IsReshapeOpSupported(const TfLiteRegistration* registration,
117                           const TfLiteNode* node, TfLiteContext* context,
118                           int coreml_version) {
119   if (coreml_version >= 3) {
120     return false;
121   }
122   if (node->inputs->size == 1) {
123     const auto* params =
124         reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
125     return params->num_dimensions == 3 || params->num_dimensions == 4;
126   }
127 
128   const int kShapeTensor = 1;
129   const TfLiteTensor* shape;
130   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShapeTensor, &shape));
131   if (shape->allocation_type != kTfLiteMmapRo) {
132     TF_LITE_KERNEL_LOG(context, "Reshape has non-const shape.");
133     return false;
134   }
135   const bool is_shape_tensor =
136       shape->dims->size == 1 && shape->type == kTfLiteInt32;
137   return is_shape_tensor &&
138          (shape->dims->data[0] == 3 || shape->dims->data[0] == 4);
139 }
140 
CreateReshapeOpBuilder(GraphBuilder * graph_builder)141 OpBuilder* CreateReshapeOpBuilder(GraphBuilder* graph_builder) {
142   return new ReshapeOpBuilder(graph_builder);
143 }
144 
145 }  // namespace coreml
146 }  // namespace delegates
147 }  // namespace tflite
148