/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <cstring>
21 #include <string>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/lite/delegates/gpu/common/convert.h"
27 #include "tensorflow/lite/delegates/gpu/common/status.h"
28 #include "tensorflow/lite/delegates/gpu/common/types.h"
29
30 namespace tflite {
31 namespace gpu {
32 namespace gl {
33
34 namespace {
35
IsApplyMaskSupported(const NodeShader::GenerationContext & ctx)36 bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) {
37 if (ctx.input_shapes.size() != 2) return false;
38
39 // [H, W, C] x [H, W, 0][0]
40 if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] &&
41 ctx.input_shapes[0][2] == ctx.input_shapes[1][2] &&
42 ctx.input_shapes[1][3] == 1) {
43 return true;
44 }
45
46 // [H, W, C] x [H, W, C]
47 if (ctx.input_shapes[0] == ctx.input_shapes[1]) return true;
48
49 // [H, W, C] x [0, 0, C]
50 return ctx.input_shapes[1][1] == 1 && ctx.input_shapes[1][2] == 1 &&
51 ctx.input_shapes[0][3] == ctx.input_shapes[1][3];
52 }
53
GenerateApplyMaskCode(const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)54 absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx,
55 GeneratedCode* generated_code) {
56 std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
57 if (ctx.input_shapes[1][3] == 1) {
58 // [H, W, C] x [H, W, 0][0]
59 absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
60 } else if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] &&
61 ctx.input_shapes[0][2] == ctx.input_shapes[1][2]) {
62 // [H, W, C] x [H, W, C]
63 absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
64 } else {
65 // [H, W, C] x [0, 0, C]
66 absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
67 }
68
69 *generated_code = {
70 /*parameters=*/{},
71 /*objects=*/{},
72 /*shared_variables=*/{},
73 /*workload=*/uint3(),
74 /*workgroup=*/uint3(),
75 /*source_code=*/std::move(source),
76 /*input=*/IOStructure::ONLY_DEFINITIONS,
77 /*output=*/IOStructure::AUTO,
78 };
79 return absl::OkStatus();
80 }
81
GenerateMultiplyScalarCode(const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)82 absl::Status GenerateMultiplyScalarCode(
83 const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) {
84 const auto& attr = absl::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
85
86 if (absl::holds_alternative<float>(attr.param)) {
87 *generated_code = {
88 /*parameters=*/{{"scalar", absl::get<float>(attr.param)}},
89 /*objects=*/{},
90 /*shared_variables=*/{},
91 /*workload=*/uint3(),
92 /*workgroup=*/uint3(),
93 /*source_code=*/"value_0 *= $scalar$;",
94 /*input=*/IOStructure::AUTO,
95 /*output=*/IOStructure::AUTO,
96 };
97 return absl::OkStatus();
98 }
99
100 if (absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(attr.param)) {
101 *generated_code = {
102 /*parameters=*/{},
103 /*objects=*/
104 {{"mul_buffer",
105 MakeReadonlyObject(
106 absl::get<Tensor<Linear, DataType::FLOAT32>>(attr.param).data)}},
107 /*shared_variables=*/{},
108 // Declare workload explicitly because shader depends on gid.z.
109 /*workload=*/
110 uint3(static_cast<int>(ctx.input_shapes[0][2]),
111 static_cast<int>(ctx.input_shapes[0][1]),
112 DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
113 /*workgroup=*/uint3(),
114 /*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
115 /*input=*/IOStructure::AUTO,
116 /*output=*/IOStructure::AUTO,
117 };
118 return absl::OkStatus();
119 }
120
121 if (absl::holds_alternative<Tensor<HWC, DataType::FLOAT32>>(attr.param)) {
122 *generated_code = {
123 /*parameters=*/{},
124 /*objects=*/
125 {{"hwc_buffer",
126 MakeReadonlyObject(
127 uint3(static_cast<int>(ctx.input_shapes[0][2]),
128 static_cast<int>(ctx.input_shapes[0][1]),
129 DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
130 ConvertToPHWC4(
131 absl::get<Tensor<HWC, DataType::FLOAT32>>(attr.param)))}},
132 /*shared_variables=*/{},
133 // Declare workload explicitly because shader depends on gid.z.
134 /*workload=*/
135 uint3(static_cast<int>(ctx.input_shapes[0][2]),
136 static_cast<int>(ctx.input_shapes[0][1]),
137 DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
138 /*workgroup=*/uint3(),
139 /*source_code=*/"value_0 *= $hwc_buffer[gid.x, gid.y, gid.z]$;",
140 /*input=*/IOStructure::AUTO,
141 /*output=*/IOStructure::AUTO,
142 };
143 return absl::OkStatus();
144 }
145
146 return absl::InvalidArgumentError("Unsupported Multiplication case.");
147 }
148
149 class Multiply : public NodeShader {
150 public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const151 absl::Status GenerateCode(const GenerationContext& ctx,
152 GeneratedCode* generated_code) const final {
153 if (IsApplyMaskSupported(ctx)) {
154 return GenerateApplyMaskCode(ctx, generated_code);
155 } else {
156 return GenerateMultiplyScalarCode(ctx, generated_code);
157 }
158 }
159 };
160
161 } // namespace
162
NewMultiplyNodeShader()163 std::unique_ptr<NodeShader> NewMultiplyNodeShader() {
164 return absl::make_unique<Multiply>();
165 }
166
167 } // namespace gl
168 } // namespace gpu
169 } // namespace tflite
170