/*
 * Copyright (c) 2018-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
23 */ 24 #ifndef ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE 25 #define ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE 26 27 #include "tests/Globals.h" 28 #include "tests/framework/Asserts.h" 29 #include "tests/framework/Fixture.h" 30 #include "tests/validation/reference/ActivationLayer.h" 31 #include "tests/validation/reference/ArithmeticOperations.h" 32 #include "tests/validation/reference/ConcatenateLayer.h" 33 #include "tests/validation/reference/FullyConnectedLayer.h" 34 #include "tests/validation/reference/GEMM.h" 35 #include "tests/validation/reference/MeanStdDevNormalizationLayer.h" 36 #include "tests/validation/reference/PixelWiseMultiplication.h" 37 #include "tests/validation/reference/Transpose.h" 38 39 namespace arm_compute 40 { 41 namespace test 42 { 43 namespace validation 44 { 45 template <typename TensorType, typename AccessorType, typename FunctionType, typename FunctionParams, typename T> 46 class LSTMLayerValidationFixture : public framework::Fixture 47 { 48 public: 49 template <typename...> setup(TensorShape input_shape,TensorShape input_weights_shape,TensorShape recurrent_weights_shape,TensorShape cell_bias_shape,TensorShape output_cell_shape,TensorShape output_shape,TensorShape scratch_shape,ActivationLayerInfo info,float cell_threshold,float projection_threshold,DataType data_type,bool projection_opt,bool peephole_opt,bool use_layer_norm)50 void setup(TensorShape input_shape, TensorShape input_weights_shape, TensorShape recurrent_weights_shape, TensorShape cell_bias_shape, TensorShape output_cell_shape, TensorShape output_shape, 51 TensorShape scratch_shape, ActivationLayerInfo info, float cell_threshold, float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, 52 bool use_layer_norm) 53 { 54 _target = compute_target(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold, 55 data_type, projection_opt, peephole_opt, 
use_layer_norm); 56 _reference = compute_reference(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold, 57 data_type, projection_opt, peephole_opt, use_layer_norm); 58 } 59 60 protected: 61 template <typename U> fill(U && tensor,int i)62 void fill(U &&tensor, int i) 63 { 64 static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported."); 65 using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type; 66 67 DistributionType distribution{ T(-1.0f), T(1.0f) }; 68 library->fill(tensor, distribution, i); 69 } 70 template <typename U> fill_custom_val(U && tensor,float num,int i)71 void fill_custom_val(U &&tensor, float num, int i) 72 { 73 static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported."); 74 using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type; 75 76 DistributionType distribution{ T(num), T(num) }; 77 library->fill(tensor, distribution, i); 78 } compute_target(const TensorShape & input_shape,const TensorShape & input_weights_shape,const TensorShape & recurrent_weights_shape,const TensorShape & cell_bias_shape,const TensorShape & output_cell_shape,const TensorShape & output_shape,const TensorShape & scratch_shape,ActivationLayerInfo info,float cell_threshold,float projection_threshold,DataType data_type,bool projection_opt,bool peephole_opt,bool use_layer_norm)79 TensorType compute_target(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape, 80 const TensorShape &output_cell_shape, const 
TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold, 81 float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm) 82 { 83 const unsigned int num_cells = input_weights_shape.y(); 84 const unsigned int num_outputs = recurrent_weights_shape.x(); 85 86 // Create tensors 87 TensorType input = create_tensor<TensorType>(input_shape, data_type); 88 TensorType input_to_forget_w = create_tensor<TensorType>(input_weights_shape, data_type); 89 TensorType input_to_cell_w = create_tensor<TensorType>(input_weights_shape, data_type); 90 TensorType input_to_output_w = create_tensor<TensorType>(input_weights_shape, data_type); 91 TensorType recurrent_to_forget_w = create_tensor<TensorType>(recurrent_weights_shape, data_type); 92 TensorType recurrent_to_cell_w = create_tensor<TensorType>(recurrent_weights_shape, data_type); 93 TensorType recurrent_to_output_w = create_tensor<TensorType>(recurrent_weights_shape, data_type); 94 TensorType forget_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type); 95 TensorType cell_bias = create_tensor<TensorType>(cell_bias_shape, data_type); 96 TensorType output_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type); 97 TensorType output_state_in = create_tensor<TensorType>(output_shape, data_type); 98 TensorType cell_state_in = create_tensor<TensorType>(output_cell_shape, data_type); 99 TensorType scratch = create_tensor<TensorType>(scratch_shape, data_type); 100 TensorType output_state_out = create_tensor<TensorType>(output_shape, data_type); 101 TensorType cell_state_out = create_tensor<TensorType>(output_cell_shape, data_type); 102 TensorType output = create_tensor<TensorType>(output_shape, data_type); 103 TensorType input_to_input_w; 104 TensorType recurrent_to_input_w; 105 TensorType cell_to_input_w; 106 TensorType cell_to_forget_w; 107 TensorType input_gate_bias; 108 TensorType cell_to_output_w; 109 TensorType 
projection_w; 110 TensorType projection_bias; 111 TensorType input_layer_norm_w; 112 TensorType forget_layer_norm_w; 113 TensorType cell_layer_norm_w; 114 TensorType output_layer_norm_w; 115 116 bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true; 117 118 FunctionParams lstm_params; 119 120 if(!cifg_opt) 121 { 122 input_to_input_w = create_tensor<TensorType>(input_weights_shape, data_type); 123 recurrent_to_input_w = create_tensor<TensorType>(recurrent_weights_shape, data_type); 124 if(peephole_opt) 125 { 126 cell_to_input_w = create_tensor<TensorType>(cell_bias_shape, data_type); 127 } 128 input_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type); 129 lstm_params.set_cifg_params(&input_to_input_w, &recurrent_to_input_w, &cell_to_input_w, &input_gate_bias); 130 } 131 132 if(peephole_opt) 133 { 134 cell_to_forget_w = create_tensor<TensorType>(cell_bias_shape, data_type); 135 cell_to_output_w = create_tensor<TensorType>(cell_bias_shape, data_type); 136 lstm_params.set_peephole_params(&cell_to_forget_w, &cell_to_output_w); 137 } 138 139 if(projection_opt) 140 { 141 projection_w = create_tensor<TensorType>(TensorShape(num_cells, num_outputs), data_type); 142 projection_bias = create_tensor<TensorType>(TensorShape(num_outputs), data_type); 143 lstm_params.set_projection_params(&projection_w, &projection_bias); 144 } 145 146 if(use_layer_norm) 147 { 148 forget_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type); 149 cell_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type); 150 output_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type); 151 if(!cifg_opt) 152 { 153 input_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type); 154 lstm_params.set_layer_normalization_params(&input_layer_norm_w, &forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w); 155 } 156 else 157 { 158 lstm_params.set_layer_normalization_params(nullptr, 
&forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w); 159 } 160 } 161 162 // Create and configure function 163 FunctionType lstm; 164 lstm.configure(&input, &input_to_forget_w, &input_to_cell_w, &input_to_output_w, &recurrent_to_forget_w, 165 &recurrent_to_cell_w, &recurrent_to_output_w, &forget_gate_bias, &cell_bias, &output_gate_bias, 166 &output_state_in, &cell_state_in, 167 &scratch, &output_state_out, &cell_state_out, &output, 168 lstm_params, info, cell_threshold, projection_threshold); 169 170 ARM_COMPUTE_ASSERT(input.info()->is_resizable()); 171 ARM_COMPUTE_ASSERT(input_to_forget_w.info()->is_resizable()); 172 ARM_COMPUTE_ASSERT(input_to_cell_w.info()->is_resizable()); 173 ARM_COMPUTE_ASSERT(input_to_output_w.info()->is_resizable()); 174 ARM_COMPUTE_ASSERT(recurrent_to_forget_w.info()->is_resizable()); 175 ARM_COMPUTE_ASSERT(recurrent_to_cell_w.info()->is_resizable()); 176 ARM_COMPUTE_ASSERT(recurrent_to_output_w.info()->is_resizable()); 177 ARM_COMPUTE_ASSERT(forget_gate_bias.info()->is_resizable()); 178 ARM_COMPUTE_ASSERT(cell_bias.info()->is_resizable()); 179 ARM_COMPUTE_ASSERT(output_gate_bias.info()->is_resizable()); 180 ARM_COMPUTE_ASSERT(output_state_in.info()->is_resizable()); 181 ARM_COMPUTE_ASSERT(cell_state_in.info()->is_resizable()); 182 ARM_COMPUTE_ASSERT(scratch.info()->is_resizable()); 183 ARM_COMPUTE_ASSERT(output_state_out.info()->is_resizable()); 184 ARM_COMPUTE_ASSERT(cell_state_out.info()->is_resizable()); 185 ARM_COMPUTE_ASSERT(output.info()->is_resizable()); 186 187 // Allocate tensors 188 input.allocator()->allocate(); 189 input_to_forget_w.allocator()->allocate(); 190 input_to_cell_w.allocator()->allocate(); 191 input_to_output_w.allocator()->allocate(); 192 recurrent_to_forget_w.allocator()->allocate(); 193 recurrent_to_cell_w.allocator()->allocate(); 194 recurrent_to_output_w.allocator()->allocate(); 195 forget_gate_bias.allocator()->allocate(); 196 cell_bias.allocator()->allocate(); 197 
output_gate_bias.allocator()->allocate(); 198 output_state_in.allocator()->allocate(); 199 cell_state_in.allocator()->allocate(); 200 scratch.allocator()->allocate(); 201 output_state_out.allocator()->allocate(); 202 cell_state_out.allocator()->allocate(); 203 output.allocator()->allocate(); 204 205 ARM_COMPUTE_ASSERT(!input.info()->is_resizable()); 206 ARM_COMPUTE_ASSERT(!input_to_forget_w.info()->is_resizable()); 207 ARM_COMPUTE_ASSERT(!input_to_cell_w.info()->is_resizable()); 208 ARM_COMPUTE_ASSERT(!input_to_output_w.info()->is_resizable()); 209 ARM_COMPUTE_ASSERT(!recurrent_to_forget_w.info()->is_resizable()); 210 ARM_COMPUTE_ASSERT(!recurrent_to_cell_w.info()->is_resizable()); 211 ARM_COMPUTE_ASSERT(!recurrent_to_output_w.info()->is_resizable()); 212 ARM_COMPUTE_ASSERT(!forget_gate_bias.info()->is_resizable()); 213 ARM_COMPUTE_ASSERT(!cell_bias.info()->is_resizable()); 214 ARM_COMPUTE_ASSERT(!output_gate_bias.info()->is_resizable()); 215 ARM_COMPUTE_ASSERT(!output_state_in.info()->is_resizable()); 216 ARM_COMPUTE_ASSERT(!cell_state_in.info()->is_resizable()); 217 ARM_COMPUTE_ASSERT(!scratch.info()->is_resizable()); 218 ARM_COMPUTE_ASSERT(!output_state_out.info()->is_resizable()); 219 ARM_COMPUTE_ASSERT(!cell_state_out.info()->is_resizable()); 220 ARM_COMPUTE_ASSERT(!output.info()->is_resizable()); 221 222 // Fill tensors 223 fill(AccessorType(input), 0); 224 fill(AccessorType(input_to_forget_w), 1); 225 fill(AccessorType(input_to_cell_w), 2); 226 fill(AccessorType(input_to_output_w), 3); 227 fill(AccessorType(recurrent_to_forget_w), 4); 228 fill(AccessorType(recurrent_to_cell_w), 5); 229 fill(AccessorType(recurrent_to_output_w), 6); 230 fill(AccessorType(forget_gate_bias), 7); 231 fill(AccessorType(cell_bias), 8); 232 fill(AccessorType(output_gate_bias), 9); 233 fill(AccessorType(output_state_in), 10); 234 fill(AccessorType(cell_state_in), 11); 235 fill(AccessorType(scratch), 12); 236 237 if(!cifg_opt) 238 { 239 
ARM_COMPUTE_ASSERT(input_to_input_w.info()->is_resizable()); 240 ARM_COMPUTE_ASSERT(recurrent_to_input_w.info()->is_resizable()); 241 ARM_COMPUTE_ASSERT(cell_to_input_w.info()->is_resizable()); 242 ARM_COMPUTE_ASSERT(input_gate_bias.info()->is_resizable()); 243 input_to_input_w.allocator()->allocate(); 244 recurrent_to_input_w.allocator()->allocate(); 245 cell_to_input_w.allocator()->allocate(); 246 input_gate_bias.allocator()->allocate(); 247 ARM_COMPUTE_ASSERT(!input_to_input_w.info()->is_resizable()); 248 ARM_COMPUTE_ASSERT(!recurrent_to_input_w.info()->is_resizable()); 249 ARM_COMPUTE_ASSERT(!cell_to_input_w.info()->is_resizable()); 250 ARM_COMPUTE_ASSERT(!input_gate_bias.info()->is_resizable()); 251 fill(AccessorType(input_to_input_w), 13); 252 fill(AccessorType(recurrent_to_input_w), 14); 253 if(peephole_opt) 254 { 255 fill(AccessorType(cell_to_input_w), 15); 256 } 257 fill(AccessorType(recurrent_to_input_w), 16); 258 fill(AccessorType(input_gate_bias), 17); 259 } 260 261 if(peephole_opt) 262 { 263 ARM_COMPUTE_ASSERT(cell_to_forget_w.info()->is_resizable()); 264 ARM_COMPUTE_ASSERT(cell_to_output_w.info()->is_resizable()); 265 cell_to_forget_w.allocator()->allocate(); 266 cell_to_output_w.allocator()->allocate(); 267 ARM_COMPUTE_ASSERT(!cell_to_forget_w.info()->is_resizable()); 268 ARM_COMPUTE_ASSERT(!cell_to_output_w.info()->is_resizable()); 269 fill(AccessorType(cell_to_forget_w), 18); 270 fill(AccessorType(cell_to_output_w), 19); 271 } 272 273 if(projection_opt) 274 { 275 ARM_COMPUTE_ASSERT(projection_w.info()->is_resizable()); 276 ARM_COMPUTE_ASSERT(projection_bias.info()->is_resizable()); 277 278 projection_w.allocator()->allocate(); 279 projection_bias.allocator()->allocate(); 280 281 ARM_COMPUTE_ASSERT(!projection_w.info()->is_resizable()); 282 ARM_COMPUTE_ASSERT(!projection_bias.info()->is_resizable()); 283 284 fill(AccessorType(projection_w), 20); 285 fill(AccessorType(projection_bias), 21); 286 } 287 288 if(use_layer_norm) 289 { 290 if(!cifg_opt) 291 
{ 292 ARM_COMPUTE_ASSERT(input_layer_norm_w.info()->is_resizable()); 293 294 input_layer_norm_w.allocator()->allocate(); 295 296 ARM_COMPUTE_ASSERT(!input_layer_norm_w.info()->is_resizable()); 297 298 fill(AccessorType(input_layer_norm_w), 22); 299 } 300 ARM_COMPUTE_ASSERT(forget_layer_norm_w.info()->is_resizable()); 301 ARM_COMPUTE_ASSERT(cell_layer_norm_w.info()->is_resizable()); 302 ARM_COMPUTE_ASSERT(output_layer_norm_w.info()->is_resizable()); 303 304 forget_layer_norm_w.allocator()->allocate(); 305 cell_layer_norm_w.allocator()->allocate(); 306 output_layer_norm_w.allocator()->allocate(); 307 308 ARM_COMPUTE_ASSERT(!forget_layer_norm_w.info()->is_resizable()); 309 ARM_COMPUTE_ASSERT(!cell_layer_norm_w.info()->is_resizable()); 310 ARM_COMPUTE_ASSERT(!output_layer_norm_w.info()->is_resizable()); 311 312 fill(AccessorType(forget_layer_norm_w), 23); 313 fill(AccessorType(cell_layer_norm_w), 24); 314 fill(AccessorType(output_layer_norm_w), 25); 315 } 316 317 // Compute function 318 lstm.run(); 319 320 _target_scratch = std::move(scratch); 321 return output; 322 } 323 compute_reference(const TensorShape & input_shape,const TensorShape & input_weights_shape,const TensorShape & recurrent_weights_shape,const TensorShape & cell_bias_shape,const TensorShape & output_cell_shape,const TensorShape & output_shape,const TensorShape & scratch_shape,ActivationLayerInfo info,float cell_threshold,float projection_threshold,DataType data_type,bool projection_opt,bool peephole_opt,bool use_layer_norm)324 SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape, 325 const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold, 326 float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm) 327 { 328 const unsigned int 
num_cells = input_weights_shape.y(); 329 const unsigned int num_outputs = recurrent_weights_shape.x(); 330 331 // Create projection weights shape 332 TensorShape projection_weights_shape(num_cells, num_outputs); 333 334 // Create projection bias shape 335 TensorShape projection_bias_shape(num_outputs); 336 337 TensorShape gemm_shape{ 1, output_shape.y() }; 338 SimpleTensor<T> gemm_out{ gemm_shape, data_type }; 339 340 // Create reference 341 SimpleTensor<T> input{ input_shape, data_type }; 342 SimpleTensor<T> input_to_input_w{ input_weights_shape, data_type }; 343 SimpleTensor<T> input_to_forget_w{ input_weights_shape, data_type }; 344 SimpleTensor<T> input_to_cell_w{ input_weights_shape, data_type }; 345 SimpleTensor<T> input_to_output_w{ input_weights_shape, data_type }; 346 SimpleTensor<T> recurrent_to_input_w{ recurrent_weights_shape, data_type }; 347 SimpleTensor<T> recurrent_to_forget_w{ recurrent_weights_shape, data_type }; 348 SimpleTensor<T> recurrent_to_cell_w{ recurrent_weights_shape, data_type }; 349 SimpleTensor<T> recurrent_to_output_w{ recurrent_weights_shape, data_type }; 350 SimpleTensor<T> cell_to_input_w{ cell_bias_shape, data_type }; 351 SimpleTensor<T> cell_to_forget_w{ cell_bias_shape, data_type }; 352 SimpleTensor<T> cell_to_output_w{ cell_bias_shape, data_type }; 353 SimpleTensor<T> input_gate_bias{ cell_bias_shape, data_type }; 354 SimpleTensor<T> forget_gate_bias{ cell_bias_shape, data_type }; 355 SimpleTensor<T> cell_bias{ cell_bias_shape, data_type }; 356 SimpleTensor<T> output_gate_bias{ cell_bias_shape, data_type }; 357 SimpleTensor<T> projection_w{ projection_weights_shape, data_type }; 358 SimpleTensor<T> projection_bias{ projection_bias_shape, data_type }; 359 SimpleTensor<T> output_state_in{ output_shape, data_type }; 360 SimpleTensor<T> cell_state_in{ output_cell_shape, data_type }; 361 SimpleTensor<T> scratch{ scratch_shape, data_type }; 362 SimpleTensor<T> output_state_out{ output_shape, data_type }; 363 SimpleTensor<T> 
cell_state_out{ output_cell_shape, data_type }; 364 SimpleTensor<T> output{ output_shape, data_type }; 365 366 bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true; 367 368 // Fill reference 369 fill(input, 0); 370 fill(input_to_forget_w, 1); 371 fill(input_to_cell_w, 2); 372 fill(input_to_output_w, 3); 373 fill(recurrent_to_forget_w, 4); 374 fill(recurrent_to_cell_w, 5); 375 fill(recurrent_to_output_w, 6); 376 if(use_layer_norm) 377 { 378 fill_custom_val(forget_gate_bias, 0.f, 7); 379 fill_custom_val(cell_bias, 0.f, 8); 380 fill_custom_val(output_gate_bias, 0.f, 9); 381 } 382 else 383 { 384 fill(forget_gate_bias, 7); 385 fill(cell_bias, 8); 386 fill(output_gate_bias, 9); 387 } 388 fill(output_state_in, 10); 389 fill(cell_state_in, 11); 390 fill(scratch, 12); 391 fill(input_to_input_w, 13); 392 fill(recurrent_to_input_w, 14); 393 fill(cell_to_input_w, 15); 394 fill(recurrent_to_input_w, 16); 395 if(!cifg_opt && use_layer_norm) 396 { 397 fill_custom_val(input_gate_bias, 0.f, 17); 398 } 399 else 400 { 401 fill(input_gate_bias, 17); 402 } 403 fill(cell_to_forget_w, 18); 404 fill(cell_to_output_w, 19); 405 fill(projection_w, 20); 406 fill(projection_bias, 21); 407 408 // Compute forget_gate 409 SimpleTensor<T> fully_connected_forget = reference::fully_connected_layer(input, input_to_forget_w, forget_gate_bias, output_cell_shape); 410 SimpleTensor<T> transposed_weights = reference::transpose(recurrent_to_forget_w); 411 SimpleTensor<T> gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f); 412 SimpleTensor<T> forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_forget, gemm, data_type, ConvertPolicy::SATURATE); 413 414 if(peephole_opt) 415 { 416 SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, data_type); 417 forget_gate = 
reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE); 418 } 419 420 if(use_layer_norm) 421 { 422 SimpleTensor<T> forget_layer_norm_w{ cell_bias_shape, data_type }; 423 fill(forget_layer_norm_w, 23); 424 forget_gate = reference::mean_std_normalization_layer(forget_gate); 425 forget_gate = reference::pixel_wise_multiplication<T, T, T>(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 426 fill(forget_gate_bias, 7); 427 forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, forget_gate_bias, data_type, ConvertPolicy::SATURATE); 428 } 429 forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); 430 431 // Compute input_gate 432 SimpleTensor<T> input_gate; 433 if(cifg_opt) 434 { 435 SimpleTensor<T> ones{ cell_bias_shape, data_type }; 436 fill_custom_val(ones, 1.f, 0); 437 input_gate = reference::arithmetic_operation<T>(reference::ArithmeticOperation::SUB, ones, forget_gate, data_type, ConvertPolicy::SATURATE); 438 } 439 else 440 { 441 SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape); 442 transposed_weights = reference::transpose(recurrent_to_input_w); 443 gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f); 444 input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE); 445 if(peephole_opt) 446 { 447 SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 448 input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, pixelwise_mul_input_gate, 
data_type, ConvertPolicy::SATURATE); 449 } 450 if(use_layer_norm) 451 { 452 SimpleTensor<T> input_layer_norm_w{ cell_bias_shape, data_type }; 453 fill(input_layer_norm_w, 22); 454 input_gate = reference::mean_std_normalization_layer(input_gate); 455 input_gate = reference::pixel_wise_multiplication<T, T, T>(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 456 fill(input_gate_bias, 17); 457 input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, input_gate_bias, data_type, ConvertPolicy::SATURATE); 458 } 459 input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); 460 } 461 // Compute cell_state 462 SimpleTensor<T> fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape); 463 transposed_weights = reference::transpose(recurrent_to_cell_w); 464 gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f); 465 SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 466 cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE); 467 if(use_layer_norm) 468 { 469 SimpleTensor<T> cell_layer_norm_w{ cell_bias_shape, data_type }; 470 fill(cell_layer_norm_w, 24); 471 cell_state_out = reference::mean_std_normalization_layer(cell_state_out); 472 cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 473 fill(cell_bias, 8); 474 cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE); 475 } 476 cell_state_out = 
reference::activation_layer(cell_state_out, info); 477 cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 478 cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE); 479 480 if(cell_threshold != 0.f) 481 { 482 cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold)); 483 } 484 485 // Compute output 486 SimpleTensor<T> fully_connected_output = reference::fully_connected_layer(input, input_to_output_w, output_gate_bias, output_cell_shape); 487 transposed_weights = reference::transpose(recurrent_to_output_w); 488 gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f); 489 output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE); 490 if(peephole_opt) 491 { 492 pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 493 output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, pixelwise_mul, data_type, ConvertPolicy::SATURATE); 494 } 495 if(use_layer_norm) 496 { 497 SimpleTensor<T> output_layer_norm_w{ cell_bias_shape, data_type }; 498 fill(output_layer_norm_w, 25); 499 output = reference::mean_std_normalization_layer(output); 500 output = reference::pixel_wise_multiplication<T, T, T>(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 501 fill(output_gate_bias, 9); 502 output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, output_gate_bias, data_type, ConvertPolicy::SATURATE); 503 } 504 output = 
reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); 505 506 // Compute output state 507 SimpleTensor<T> cell_state_activation = reference::activation_layer(cell_state_out, info); 508 output_state_out = reference::pixel_wise_multiplication<T, T, T>(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 509 510 if(projection_opt) 511 { 512 SimpleTensor<T> fully_connected_projection = reference::fully_connected_layer(output_state_out, projection_w, projection_bias, output_cell_shape); 513 if(projection_threshold != 0.f) 514 { 515 output_state_out = reference::activation_layer(fully_connected_projection, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)); 516 } 517 } 518 std::vector<SimpleTensor<T>> scratch_inputs; 519 if(!cifg_opt) 520 { 521 scratch_inputs.emplace_back(std::move(input_gate)); 522 } 523 scratch_inputs.emplace_back(std::move(cell_state_out)); 524 scratch_inputs.emplace_back(std::move(forget_gate)); 525 scratch_inputs.emplace_back(std::move(output)); 526 scratch = reference::concatenate_layer(scratch_inputs, scratch, Window::DimX); 527 _reference_scratch = std::move(scratch); 528 return output_state_out; 529 } 530 531 TensorType _target{}; 532 TensorType _target_scratch{}; 533 SimpleTensor<T> _reference{}; 534 SimpleTensor<T> _reference_scratch{}; 535 }; 536 } // namespace validation 537 } // namespace test 538 } // namespace arm_compute 539 #endif /* ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE */ 540