• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/prelu.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <cstring>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "tensorflow/lite/delegates/gpu/common/convert.h"
26 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
27 #include "tensorflow/lite/delegates/gpu/common/shape.h"
28 #include "tensorflow/lite/delegates/gpu/common/status.h"
29 #include "tensorflow/lite/delegates/gpu/common/types.h"
30 
31 namespace tflite {
32 namespace gpu {
33 namespace gl {
34 namespace {
35 
36 class PReLULinearAlpha : public NodeShader {
37  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const38   absl::Status GenerateCode(const GenerationContext& ctx,
39                             GeneratedCode* generated_code) const final {
40     const auto& attr = absl::any_cast<const PReLUAttributes&>(ctx.op_attr);
41     auto alpha = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
42     if (!alpha) {
43       return absl::InvalidArgumentError("Alpha is missing");
44     }
45     if (alpha->shape.v != ctx.output_shapes[0][3]) {
46       return absl::InvalidArgumentError(
47           "Alpha shape does not match the number of channels.");
48     }
49 
50     *generated_code =
51         attr.clip
52             ? GeneratedCode{
53                   /*parameters=*/{{"clip", attr.clip}},
54                   /*objects=*/{{"alpha", MakeReadonlyObject(alpha->data)}},
55                   /*shared_variables=*/{},
56                   /*workload=*/uint3(),
57                   /*workgroup=*/uint3(),
58                   "value_0 = clamp(value_0, 0.0, $clip$) + $alpha[gid.z]$ * "
59                   "min(value_0, 0.0);",
60                   /*input=*/IOStructure::AUTO,
61                   /*output=*/IOStructure::AUTO,
62               }
63             : GeneratedCode{
64                   /*parameters=*/{},
65                   /*objects=*/{{"alpha", MakeReadonlyObject(alpha->data)}},
66                   /*shared_variables=*/{},
67                   // Declare workload explicitly because shader depends on
68                   // gid.z.
69                   /*workload=*/
70                   uint3(static_cast<int>(ctx.output_shapes[0][2]),
71                         static_cast<int>(ctx.output_shapes[0][1]),
72                         DivideRoundUp(static_cast<int>(ctx.output_shapes[0][3]),
73                                       4)),
74                   /*workgroup=*/uint3(),
75                   /*source_code=*/
76                   "value_0 = max(value_0, 0.0) + $alpha[gid.z]$ * min(value_0, "
77                   "0.0);",
78                   /*input=*/IOStructure::AUTO,
79                   /*output=*/IOStructure::AUTO,
80               };
81     return absl::OkStatus();
82   }
83 };
84 
85 class PReLUFull : public NodeShader {
86  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const87   absl::Status GenerateCode(const GenerationContext& ctx,
88                             GeneratedCode* generated_code) const final {
89     const auto& attr = absl::any_cast<const PReLUAttributes&>(ctx.op_attr);
90     auto alpha = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr.alpha);
91     if (!alpha) {
92       return absl::InvalidArgumentError("Alpha is missing");
93     }
94     if (alpha->shape.h != ctx.output_shapes[0][1] ||
95         alpha->shape.w != ctx.output_shapes[0][2] ||
96         alpha->shape.c != ctx.output_shapes[0][3]) {
97       return absl::InvalidArgumentError(
98           "Alpha shape does not match input shape.");
99     }
100 
101     ObjectSize obj_size =
102         uint3(static_cast<int>(ctx.output_shapes[0][2]),
103               static_cast<int>(ctx.output_shapes[0][1]),
104               DivideRoundUp(static_cast<int>(ctx.output_shapes[0][3]), 4));
105 
106     *generated_code =
107         attr.clip
108             ? GeneratedCode{
109                   /*parameters=*/{{"clip", attr.clip}},
110                   /*objects=*/
111                   {{"alpha",
112                     MakeReadonlyObject(obj_size, ConvertToPHWC4(*alpha))}},
113                   /*shared_variables=*/{},
114                   // Declare workload explicitly because shader
115                   // depends on gid.z.
116                   /*workload=*/
117                   uint3(static_cast<int>(ctx.output_shapes[0][2]),
118                         static_cast<int>(ctx.output_shapes[0][1]),
119                         DivideRoundUp(static_cast<int>(ctx.output_shapes[0][3]),
120                                       4)),
121                   /*workgroup=*/uint3(),
122                   /*source_code=*/
123                   "value_0 = clamp(value_0, 0.0, $clip$) + "
124                   "$alpha[gid.x, gid.y, gid.z]$ * min(value_0, 0.0);",
125                   /*input=*/IOStructure::AUTO,
126                   /*output=*/IOStructure::AUTO,
127               }
128             : GeneratedCode{
129                   /*parameters=*/{},
130                   /*objects=*/
131                   {{"alpha",
132                     MakeReadonlyObject(obj_size, ConvertToPHWC4(*alpha))}},
133                   /*shared_variables=*/{},
134                   // Declare workload explicitly because shader depends on
135                   // gid.z.
136                   /*workload=*/
137                   uint3(static_cast<int>(ctx.output_shapes[0][2]),
138                         static_cast<int>(ctx.output_shapes[0][1]),
139                         DivideRoundUp(static_cast<int>(ctx.output_shapes[0][3]),
140                                       4)),
141                   /*workgroup=*/uint3(),
142                   /*source_code=*/
143                   "value_0 = max(value_0, 0.0) + $alpha[gid.x, gid.y, gid.z]$ "
144                   "* min(value_0, 0.0);",
145                   /*input=*/IOStructure::AUTO,
146                   /*output=*/IOStructure::AUTO,
147               };
148     return absl::OkStatus();
149   }
150 };
151 
152 class PReLU : public NodeShader {
153  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const154   absl::Status GenerateCode(const GenerationContext& ctx,
155                             GeneratedCode* generated_code) const final {
156     const auto& attr = absl::any_cast<const PReLUAttributes&>(ctx.op_attr);
157     auto* alpha = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr.alpha);
158     return alpha ? full_.GenerateCode(ctx, generated_code)
159                  : linear_.GenerateCode(ctx, generated_code);
160   }
161 
162  private:
163   PReLULinearAlpha linear_;
164   PReLUFull full_;
165 };
166 
167 }  // namespace
168 
NewPReLUNodeShader()169 std::unique_ptr<NodeShader> NewPReLUNodeShader() {
170   return absl::make_unique<PReLU>();
171 }
172 
173 }  // namespace gl
174 }  // namespace gpu
175 }  // namespace tflite
176