1 /* Copyright 2019 Google LLC. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstdint>
17 #include <iostream>
18
19 #include "ruy/ruy.h"
20
ExampleMulFloat(ruy::Context * context)21 void ExampleMulFloat(ruy::Context *context) {
22 const float lhs_data[] = {1, 2, 3, 4};
23 const float rhs_data[] = {1, 2, 3, 4};
24 float dst_data[4];
25
26 ruy::Matrix<float> lhs;
27 ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
28 lhs.set_data(lhs_data);
29 ruy::Matrix<float> rhs;
30 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
31 rhs.set_data(rhs_data);
32 ruy::Matrix<float> dst;
33 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
34 dst.set_data(dst_data);
35
36 ruy::MulParams<float, float> mul_params;
37 ruy::Mul(lhs, rhs, mul_params, context, &dst);
38
39 std::cout << "Example Mul, float:\n";
40 std::cout << "LHS:\n" << lhs;
41 std::cout << "RHS:\n" << rhs;
42 std::cout << "Result:\n" << dst << "\n";
43 }
44
ExampleMulFloatWithBiasAddAndClamp(ruy::Context * context)45 void ExampleMulFloatWithBiasAddAndClamp(ruy::Context *context) {
46 const float lhs_data[] = {1, 2, 3, 4};
47 const float rhs_data[] = {1, 2, 3, 4};
48 const float bias_data[] = {1, 0};
49 float dst_data[4];
50
51 ruy::Matrix<float> lhs;
52 ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
53 lhs.set_data(lhs_data);
54 ruy::Matrix<float> rhs;
55 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
56 rhs.set_data(rhs_data);
57 ruy::Matrix<float> dst;
58 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
59 dst.set_data(dst_data);
60
61 ruy::MulParams<float, float> mul_params;
62 mul_params.set_bias(bias_data);
63 mul_params.set_clamp_min(0);
64 mul_params.set_clamp_max(15);
65 ruy::Mul(lhs, rhs, mul_params, context, &dst);
66
67 std::cout << "Example Mul, float with bias addition and clamp:\n";
68 std::cout << "LHS:\n" << lhs;
69 std::cout << "RHS:\n" << rhs;
70 std::cout << "Result:\n" << dst << "\n";
71 }
72
ExampleMulUint8AsymmetricQuantized(ruy::Context * context)73 void ExampleMulUint8AsymmetricQuantized(ruy::Context *context) {
74 const std::uint8_t lhs_data[] = {124, 125, 126, 127};
75 const std::uint8_t rhs_data[] = {129, 130, 131, 132};
76 std::uint8_t dst_data[4];
77
78 ruy::Matrix<std::uint8_t> lhs;
79 ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
80 lhs.set_data(lhs_data);
81 lhs.set_zero_point(125);
82 ruy::Matrix<std::uint8_t> rhs;
83 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
84 rhs.set_data(rhs_data);
85 rhs.set_zero_point(132);
86 ruy::Matrix<std::uint8_t> dst;
87 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
88 dst.set_data(dst_data);
89 dst.set_zero_point(129);
90
91 ruy::MulParams<std::int32_t, std::uint8_t> mul_params;
92 mul_params.set_multiplier_fixedpoint(1 << 30);
93
94 mul_params.set_multiplier_exponent(0);
95 ruy::Mul(lhs, rhs, mul_params, context, &dst);
96
97 std::cout << "Example Mul, uint8 quantized with asymmetric zero points:\n";
98 std::cout << "LHS:\n" << lhs;
99 std::cout << "RHS:\n" << rhs;
100 std::cout << "Result:\n" << dst << "\n";
101 }
ExampleMulInt8PerChannelQuantized(ruy::Context * context)102 void ExampleMulInt8PerChannelQuantized(ruy::Context *context) {
103 const std::int8_t lhs_data[] = {1, 2, 3, 4};
104 const std::int8_t rhs_data[] = {1, 2, 3, 4};
105 const std::int32_t multiplier_data[] = {3 << 28, 5 << 28};
106 const int exponent_data[] = {1, -2};
107 std::int8_t dst_data[4];
108
109 ruy::Matrix<std::int8_t> lhs;
110 ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
111 lhs.set_data(lhs_data);
112 ruy::Matrix<std::int8_t> rhs;
113 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
114 rhs.set_data(rhs_data);
115 ruy::Matrix<std::int8_t> dst;
116 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
117 dst.set_data(dst_data);
118
119 ruy::MulParams<std::int32_t, std::int8_t> mul_params;
120 mul_params.set_multiplier_fixedpoint_perchannel(multiplier_data);
121 mul_params.set_multiplier_exponent_perchannel(exponent_data);
122 ruy::Mul(lhs, rhs, mul_params, context, &dst);
123
124 std::cout << "Example Mul, int8 quantized with per-channel multipliers\n";
125 std::cout << "LHS:\n" << lhs;
126 std::cout << "RHS:\n" << rhs;
127 std::cout << "Result:\n" << dst << "\n";
128 }
ExampleMulInt8GetRawAccumulators(ruy::Context * context)129 void ExampleMulInt8GetRawAccumulators(ruy::Context *context) {
130 const std::int8_t lhs_data[] = {1, 2, 3, 4};
131 const std::int8_t rhs_data[] = {1, 2, 3, 4};
132 std::int32_t dst_data[4];
133
134 ruy::Matrix<std::int8_t> lhs;
135 ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
136 lhs.set_data(lhs_data);
137 ruy::Matrix<std::int8_t> rhs;
138 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
139 rhs.set_data(rhs_data);
140 ruy::Matrix<std::int32_t> dst;
141 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
142 dst.set_data(dst_data);
143
144 // When Dst is int32, mul_params is unused.
145 ruy::MulParams<std::int32_t, std::int32_t> mul_params;
146 ruy::Mul(lhs, rhs, mul_params, context, &dst);
147
148 std::cout << "Example Mul, returning raw int32 accumulators:\n";
149 std::cout << "LHS:\n" << lhs;
150 std::cout << "RHS:\n" << rhs;
151 std::cout << "Result:\n" << dst << "\n";
152 }
153
main()154 int main() {
155 ruy::Context context;
156 ExampleMulFloat(&context);
157 ExampleMulFloatWithBiasAddAndClamp(&context);
158 ExampleMulUint8AsymmetricQuantized(&context);
159 ExampleMulInt8PerChannelQuantized(&context);
160 ExampleMulInt8GetRawAccumulators(&context);
161 }
162