• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/coreml/builders/pad_op_builder.h"
16 
17 #include "tensorflow/lite/builtin_ops.h"
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 
24 namespace tflite {
25 namespace delegates {
26 namespace coreml {
27 
DebugName()28 const std::string& PadOpBuilder::DebugName() {
29   if (!debug_name_.empty()) return debug_name_;
30   SetDebugName(padding_type_ == PadType::kPad ? "PadOpBuilder (PAD)"
31                                               : "PadOpBuilder (MIRROR_PAD)",
32                node_id_);
33   return debug_name_;
34 }
35 
Build()36 CoreML::Specification::NeuralNetworkLayer* PadOpBuilder::Build() {
37   layer_->set_name(DebugName());
38   if (padding_type_ == PadType::kPad) {
39     layer_->mutable_padding()->mutable_constant();
40   } else if (padding_type_ == PadType::kMirrorPad) {
41     layer_->mutable_padding()->mutable_reflection();
42   }
43   return layer_.release();
44 }
45 
46 // padding is d x 2 tensor, where d is the dimension of input.
47 // only paddings for width and height are considered.
SetPadding(const TfLiteTensor * padding)48 void PadOpBuilder::SetPadding(const TfLiteTensor* padding) {
49   const int32_t* padding_data = GetTensorData<int32_t>(padding);
50   for (int i = 1; i <= 2; ++i) {
51     auto* borderamount = layer_->mutable_padding()
52                              ->mutable_paddingamounts()
53                              ->add_borderamounts();
54     borderamount->set_startedgesize(padding_data[i * 2]);
55     borderamount->set_endedgesize(padding_data[i * 2 + 1]);
56   }
57 }
58 
SetConstantValue(const TfLiteTensor * constant_value)59 void PadOpBuilder::SetConstantValue(const TfLiteTensor* constant_value) {
60   layer_->mutable_padding()->mutable_constant()->set_value(
61       GetTensorData<float>(constant_value)[0]);
62 }
63 
RegisterInputs(const TfLiteIntArray * inputs,TfLiteContext * context)64 TfLiteStatus PadOpBuilder::RegisterInputs(const TfLiteIntArray* inputs,
65                                           TfLiteContext* context) {
66   if (!(inputs->size == 2 || inputs->size == 3)) {
67     TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Padding!.");
68     return kTfLiteError;
69   }
70   AddInput(inputs->data[0]);
71   SetPadding(GetInput(context, tflite_node_, 1));
72   if (inputs->size == 3) {
73     SetConstantValue(GetInput(context, tflite_node_, 2));
74   }
75 
76   return kTfLiteOk;
77 }
78 
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)79 TfLiteStatus PadOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
80                                            TfLiteContext* context) {
81   if (outputs->size != 1) {
82     TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to Padding!.");
83     return kTfLiteError;
84   }
85   graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context));
86   return kTfLiteOk;
87 }
88 
CreatePadOpBuilder(GraphBuilder * graph_builder)89 OpBuilder* CreatePadOpBuilder(GraphBuilder* graph_builder) {
90   return new PadOpBuilder(graph_builder, PadType::kPad);
91 }
92 
CreateMirrorPadOpBuilder(GraphBuilder * graph_builder)93 OpBuilder* CreateMirrorPadOpBuilder(GraphBuilder* graph_builder) {
94   return new PadOpBuilder(graph_builder, PadType::kMirrorPad);
95 }
96 
IsPadOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)97 bool IsPadOpSupported(const TfLiteRegistration* registration,
98                       const TfLiteNode* node, TfLiteContext* context) {
99   // padding is d x 2 tensor, where d is the dimension of input.
100   const TfLiteTensor* padding;
101   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &padding));
102   if (!IsConstantTensor(padding)) {
103     TF_LITE_KERNEL_LOG(context,
104                        "%s: Only constant padding is supported for PAD.",
105                        padding->name);
106     return false;
107   }
108   if (padding->dims->data[0] != 4 || padding->dims->data[1] != 2) {
109     TF_LITE_KERNEL_LOG(context, "%s: Only 4D inputs are supported for PAD.",
110                        padding->name);
111     return false;
112   }
113   const int32_t* padding_data = GetTensorData<int32_t>(padding);
114   if (!(padding_data[0] == 0 && padding_data[1] == 0)) {
115     TF_LITE_KERNEL_LOG(
116         context, "%s: Padding for batch dimension is not supported in PAD.",
117         padding->name);
118     return false;
119   }
120 
121   if (!(padding_data[6] == 0 && padding_data[7] == 0)) {
122     TF_LITE_KERNEL_LOG(
123         context, "%s: Padding for channel dimension is not supported in PAD.",
124         padding->name);
125     return false;
126   }
127   return true;
128 }
129 
IsMirrorPadOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)130 bool IsMirrorPadOpSupported(const TfLiteRegistration* registration,
131                             const TfLiteNode* node, TfLiteContext* context) {
132   auto* params =
133       reinterpret_cast<TfLiteMirrorPaddingParams*>(node->builtin_data);
134   if (params->mode != kTfLiteMirrorPaddingReflect) {
135     TF_LITE_KERNEL_LOG(context,
136                        "Only REFLECT mode is supported for MIRROR_PAD.");
137     return false;
138   }
139   return IsPadOpSupported(registration, node, context);
140 }
141 
142 }  // namespace coreml
143 }  // namespace delegates
144 }  // namespace tflite
145