1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h"
17
18 #include <algorithm>
19 #include <any>
20 #include <cstdint>
21 #include <cstring>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <variant>
26 #include <vector>
27
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "tensorflow/lite/delegates/gpu/common/convert.h"
31 #include "tensorflow/lite/delegates/gpu/common/status.h"
32 #include "tensorflow/lite/delegates/gpu/common/types.h"
33
34 namespace tflite {
35 namespace gpu {
36 namespace gl {
37
38 namespace {
39
IsApplyMaskSupported(const NodeShader::GenerationContext & ctx)40 bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) {
41 if (ctx.input_shapes.size() != 2) return false;
42
43 // [H, W, C] x [H, W, 0][0]
44 if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] &&
45 ctx.input_shapes[0][2] == ctx.input_shapes[1][2] &&
46 ctx.input_shapes[1][3] == 1) {
47 return true;
48 }
49
50 // [H, W, C] x [H, W, C]
51 if (ctx.input_shapes[0] == ctx.input_shapes[1]) return true;
52
53 // [H, W, C] x [0, 0, C]
54 return ctx.input_shapes[1][1] == 1 && ctx.input_shapes[1][2] == 1 &&
55 ctx.input_shapes[0][3] == ctx.input_shapes[1][3];
56 }
57
GenerateApplyMaskCode(const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)58 absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx,
59 GeneratedCode* generated_code) {
60 std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
61 if (ctx.input_shapes[1][3] == 1) {
62 // [H, W, C] x [H, W, 0][0]
63 absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
64 } else if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] &&
65 ctx.input_shapes[0][2] == ctx.input_shapes[1][2]) {
66 // [H, W, C] x [H, W, C]
67 absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
68 } else {
69 // [H, W, C] x [0, 0, C]
70 absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
71 }
72
73 *generated_code = {
74 /*parameters=*/{},
75 /*objects=*/{},
76 /*shared_variables=*/{},
77 /*workload=*/uint3(),
78 /*workgroup=*/uint3(),
79 /*source_code=*/std::move(source),
80 /*input=*/IOStructure::ONLY_DEFINITIONS,
81 /*output=*/IOStructure::AUTO,
82 };
83 return absl::OkStatus();
84 }
85
GenerateMultiplyScalarCode(const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)86 absl::Status GenerateMultiplyScalarCode(
87 const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) {
88 const auto& attr = std::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
89
90 if (std::holds_alternative<float>(attr.param)) {
91 *generated_code = {
92 /*parameters=*/{{"scalar", std::get<float>(attr.param)}},
93 /*objects=*/{},
94 /*shared_variables=*/{},
95 /*workload=*/uint3(),
96 /*workgroup=*/uint3(),
97 /*source_code=*/"value_0 *= $scalar$;",
98 /*input=*/IOStructure::AUTO,
99 /*output=*/IOStructure::AUTO,
100 };
101 return absl::OkStatus();
102 }
103
104 if (std::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(attr.param)) {
105 *generated_code = {
106 /*parameters=*/{},
107 /*objects=*/
108 {{"mul_buffer",
109 MakeReadonlyObject(
110 std::get<Tensor<Linear, DataType::FLOAT32>>(attr.param).data)}},
111 /*shared_variables=*/{},
112 // Declare workload explicitly because shader depends on gid.z.
113 /*workload=*/
114 uint3(static_cast<int>(ctx.input_shapes[0][2]),
115 static_cast<int>(ctx.input_shapes[0][1]),
116 DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
117 /*workgroup=*/uint3(),
118 /*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
119 /*input=*/IOStructure::AUTO,
120 /*output=*/IOStructure::AUTO,
121 };
122 return absl::OkStatus();
123 }
124
125 if (std::holds_alternative<Tensor<HWC, DataType::FLOAT32>>(attr.param)) {
126 *generated_code = {
127 /*parameters=*/{},
128 /*objects=*/
129 {{"hwc_buffer",
130 MakeReadonlyObject(
131 uint3(static_cast<int>(ctx.input_shapes[0][2]),
132 static_cast<int>(ctx.input_shapes[0][1]),
133 DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
134 ConvertToPHWC4(
135 std::get<Tensor<HWC, DataType::FLOAT32>>(attr.param)))}},
136 /*shared_variables=*/{},
137 // Declare workload explicitly because shader depends on gid.z.
138 /*workload=*/
139 uint3(static_cast<int>(ctx.input_shapes[0][2]),
140 static_cast<int>(ctx.input_shapes[0][1]),
141 DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
142 /*workgroup=*/uint3(),
143 /*source_code=*/"value_0 *= $hwc_buffer[gid.x, gid.y, gid.z]$;",
144 /*input=*/IOStructure::AUTO,
145 /*output=*/IOStructure::AUTO,
146 };
147 return absl::OkStatus();
148 }
149
150 return absl::InvalidArgumentError("Unsupported Multiplication case.");
151 }
152
153 class Multiply : public NodeShader {
154 public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const155 absl::Status GenerateCode(const GenerationContext& ctx,
156 GeneratedCode* generated_code) const final {
157 if (IsApplyMaskSupported(ctx)) {
158 return GenerateApplyMaskCode(ctx, generated_code);
159 } else {
160 return GenerateMultiplyScalarCode(ctx, generated_code);
161 }
162 }
163 };
164
165 } // namespace
166
NewMultiplyNodeShader()167 std::unique_ptr<NodeShader> NewMultiplyNodeShader() {
168 return std::make_unique<Multiply>();
169 }
170
171 } // namespace gl
172 } // namespace gpu
173 } // namespace tflite
174