• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <cstring>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/lite/delegates/gpu/common/convert.h"
27 #include "tensorflow/lite/delegates/gpu/common/status.h"
28 #include "tensorflow/lite/delegates/gpu/common/types.h"
29 
30 namespace tflite {
31 namespace gpu {
32 namespace gl {
33 
34 namespace {
35 
IsApplyMaskSupported(const NodeShader::GenerationContext & ctx)36 bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) {
37   if (ctx.input_shapes.size() != 2) return false;
38 
39   // [H, W, C] x [H, W, 0][0]
40   if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] &&
41       ctx.input_shapes[0][2] == ctx.input_shapes[1][2] &&
42       ctx.input_shapes[1][3] == 1) {
43     return true;
44   }
45 
46   // [H, W, C] x [H, W, C]
47   if (ctx.input_shapes[0] == ctx.input_shapes[1]) return true;
48 
49   // [H, W, C] x [0, 0, C]
50   return ctx.input_shapes[1][1] == 1 && ctx.input_shapes[1][2] == 1 &&
51          ctx.input_shapes[0][3] == ctx.input_shapes[1][3];
52 }
53 
GenerateApplyMaskCode(const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)54 absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx,
55                                    GeneratedCode* generated_code) {
56   std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
57   if (ctx.input_shapes[1][3] == 1) {
58     // [H, W, C] x [H, W, 0][0]
59     absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
60   } else if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] &&
61              ctx.input_shapes[0][2] == ctx.input_shapes[1][2]) {
62     // [H, W, C] x [H, W, C]
63     absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
64   } else {
65     // [H, W, C] x [0, 0, C]
66     absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
67   }
68 
69   *generated_code = {
70       /*parameters=*/{},
71       /*objects=*/{},
72       /*shared_variables=*/{},
73       /*workload=*/uint3(),
74       /*workgroup=*/uint3(),
75       /*source_code=*/std::move(source),
76       /*input=*/IOStructure::ONLY_DEFINITIONS,
77       /*output=*/IOStructure::AUTO,
78   };
79   return absl::OkStatus();
80 }
81 
GenerateMultiplyScalarCode(const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)82 absl::Status GenerateMultiplyScalarCode(
83     const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) {
84   const auto& attr = absl::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
85 
86   if (absl::holds_alternative<float>(attr.param)) {
87     *generated_code = {
88         /*parameters=*/{{"scalar", absl::get<float>(attr.param)}},
89         /*objects=*/{},
90         /*shared_variables=*/{},
91         /*workload=*/uint3(),
92         /*workgroup=*/uint3(),
93         /*source_code=*/"value_0 *= $scalar$;",
94         /*input=*/IOStructure::AUTO,
95         /*output=*/IOStructure::AUTO,
96     };
97     return absl::OkStatus();
98   }
99 
100   if (absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(attr.param)) {
101     *generated_code = {
102         /*parameters=*/{},
103         /*objects=*/
104         {{"mul_buffer",
105           MakeReadonlyObject(
106               absl::get<Tensor<Linear, DataType::FLOAT32>>(attr.param).data)}},
107         /*shared_variables=*/{},
108         // Declare workload explicitly because shader depends on gid.z.
109         /*workload=*/
110         uint3(static_cast<int>(ctx.input_shapes[0][2]),
111               static_cast<int>(ctx.input_shapes[0][1]),
112               DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
113         /*workgroup=*/uint3(),
114         /*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
115         /*input=*/IOStructure::AUTO,
116         /*output=*/IOStructure::AUTO,
117     };
118     return absl::OkStatus();
119   }
120 
121   if (absl::holds_alternative<Tensor<HWC, DataType::FLOAT32>>(attr.param)) {
122     *generated_code = {
123         /*parameters=*/{},
124         /*objects=*/
125         {{"hwc_buffer",
126           MakeReadonlyObject(
127               uint3(static_cast<int>(ctx.input_shapes[0][2]),
128                     static_cast<int>(ctx.input_shapes[0][1]),
129                     DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
130               ConvertToPHWC4(
131                   absl::get<Tensor<HWC, DataType::FLOAT32>>(attr.param)))}},
132         /*shared_variables=*/{},
133         // Declare workload explicitly because shader depends on gid.z.
134         /*workload=*/
135         uint3(static_cast<int>(ctx.input_shapes[0][2]),
136               static_cast<int>(ctx.input_shapes[0][1]),
137               DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
138         /*workgroup=*/uint3(),
139         /*source_code=*/"value_0 *= $hwc_buffer[gid.x, gid.y, gid.z]$;",
140         /*input=*/IOStructure::AUTO,
141         /*output=*/IOStructure::AUTO,
142     };
143     return absl::OkStatus();
144   }
145 
146   return absl::InvalidArgumentError("Unsupported Multiplication case.");
147 }
148 
149 class Multiply : public NodeShader {
150  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const151   absl::Status GenerateCode(const GenerationContext& ctx,
152                             GeneratedCode* generated_code) const final {
153     if (IsApplyMaskSupported(ctx)) {
154       return GenerateApplyMaskCode(ctx, generated_code);
155     } else {
156       return GenerateMultiplyScalarCode(ctx, generated_code);
157     }
158   }
159 };
160 
161 }  // namespace
162 
NewMultiplyNodeShader()163 std::unique_ptr<NodeShader> NewMultiplyNodeShader() {
164   return absl::make_unique<Multiply>();
165 }
166 
167 }  // namespace gl
168 }  // namespace gpu
169 }  // namespace tflite
170