• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h"
17 
18 #include <algorithm>
19 #include <any>
20 #include <cstdint>
21 #include <cstring>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <variant>
26 #include <vector>
27 
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "tensorflow/lite/delegates/gpu/common/convert.h"
31 #include "tensorflow/lite/delegates/gpu/common/status.h"
32 #include "tensorflow/lite/delegates/gpu/common/types.h"
33 
34 namespace tflite {
35 namespace gpu {
36 namespace gl {
37 
38 namespace {
39 
IsApplyMaskSupported(const NodeShader::GenerationContext & ctx)40 bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) {
41   if (ctx.input_shapes.size() != 2) return false;
42 
43   // [H, W, C] x [H, W, 0][0]
44   if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] &&
45       ctx.input_shapes[0][2] == ctx.input_shapes[1][2] &&
46       ctx.input_shapes[1][3] == 1) {
47     return true;
48   }
49 
50   // [H, W, C] x [H, W, C]
51   if (ctx.input_shapes[0] == ctx.input_shapes[1]) return true;
52 
53   // [H, W, C] x [0, 0, C]
54   return ctx.input_shapes[1][1] == 1 && ctx.input_shapes[1][2] == 1 &&
55          ctx.input_shapes[0][3] == ctx.input_shapes[1][3];
56 }
57 
GenerateApplyMaskCode(const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)58 absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx,
59                                    GeneratedCode* generated_code) {
60   std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
61   if (ctx.input_shapes[1][3] == 1) {
62     // [H, W, C] x [H, W, 0][0]
63     absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
64   } else if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] &&
65              ctx.input_shapes[0][2] == ctx.input_shapes[1][2]) {
66     // [H, W, C] x [H, W, C]
67     absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
68   } else {
69     // [H, W, C] x [0, 0, C]
70     absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
71   }
72 
73   *generated_code = {
74       /*parameters=*/{},
75       /*objects=*/{},
76       /*shared_variables=*/{},
77       /*workload=*/uint3(),
78       /*workgroup=*/uint3(),
79       /*source_code=*/std::move(source),
80       /*input=*/IOStructure::ONLY_DEFINITIONS,
81       /*output=*/IOStructure::AUTO,
82   };
83   return absl::OkStatus();
84 }
85 
GenerateMultiplyScalarCode(const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)86 absl::Status GenerateMultiplyScalarCode(
87     const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) {
88   const auto& attr = std::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
89 
90   if (std::holds_alternative<float>(attr.param)) {
91     *generated_code = {
92         /*parameters=*/{{"scalar", std::get<float>(attr.param)}},
93         /*objects=*/{},
94         /*shared_variables=*/{},
95         /*workload=*/uint3(),
96         /*workgroup=*/uint3(),
97         /*source_code=*/"value_0 *= $scalar$;",
98         /*input=*/IOStructure::AUTO,
99         /*output=*/IOStructure::AUTO,
100     };
101     return absl::OkStatus();
102   }
103 
104   if (std::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(attr.param)) {
105     *generated_code = {
106         /*parameters=*/{},
107         /*objects=*/
108         {{"mul_buffer",
109           MakeReadonlyObject(
110               std::get<Tensor<Linear, DataType::FLOAT32>>(attr.param).data)}},
111         /*shared_variables=*/{},
112         // Declare workload explicitly because shader depends on gid.z.
113         /*workload=*/
114         uint3(static_cast<int>(ctx.input_shapes[0][2]),
115               static_cast<int>(ctx.input_shapes[0][1]),
116               DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
117         /*workgroup=*/uint3(),
118         /*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
119         /*input=*/IOStructure::AUTO,
120         /*output=*/IOStructure::AUTO,
121     };
122     return absl::OkStatus();
123   }
124 
125   if (std::holds_alternative<Tensor<HWC, DataType::FLOAT32>>(attr.param)) {
126     *generated_code = {
127         /*parameters=*/{},
128         /*objects=*/
129         {{"hwc_buffer",
130           MakeReadonlyObject(
131               uint3(static_cast<int>(ctx.input_shapes[0][2]),
132                     static_cast<int>(ctx.input_shapes[0][1]),
133                     DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
134               ConvertToPHWC4(
135                   std::get<Tensor<HWC, DataType::FLOAT32>>(attr.param)))}},
136         /*shared_variables=*/{},
137         // Declare workload explicitly because shader depends on gid.z.
138         /*workload=*/
139         uint3(static_cast<int>(ctx.input_shapes[0][2]),
140               static_cast<int>(ctx.input_shapes[0][1]),
141               DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
142         /*workgroup=*/uint3(),
143         /*source_code=*/"value_0 *= $hwc_buffer[gid.x, gid.y, gid.z]$;",
144         /*input=*/IOStructure::AUTO,
145         /*output=*/IOStructure::AUTO,
146     };
147     return absl::OkStatus();
148   }
149 
150   return absl::InvalidArgumentError("Unsupported Multiplication case.");
151 }
152 
153 class Multiply : public NodeShader {
154  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const155   absl::Status GenerateCode(const GenerationContext& ctx,
156                             GeneratedCode* generated_code) const final {
157     if (IsApplyMaskSupported(ctx)) {
158       return GenerateApplyMaskCode(ctx, generated_code);
159     } else {
160       return GenerateMultiplyScalarCode(ctx, generated_code);
161     }
162   }
163 };
164 
165 }  // namespace
166 
NewMultiplyNodeShader()167 std::unique_ptr<NodeShader> NewMultiplyNodeShader() {
168   return std::make_unique<Multiply>();
169 }
170 
171 }  // namespace gl
172 }  // namespace gpu
173 }  // namespace tflite
174