1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"
17
18 namespace tflite {
19 namespace gpu {
GetTotalElementsCountForLayout(const WeightsDescription & weight_desc,const OHWI & shape)20 uint GetTotalElementsCountForLayout(const WeightsDescription& weight_desc,
21 const OHWI& shape) {
22 if (weight_desc.layout == WeightsLayout::kOHWIOGroupI4O4 ||
23 weight_desc.layout == WeightsLayout::kOHWIOGroupO4I4 ||
24 weight_desc.layout == WeightsLayout::k2DX4I4YIsHWIAndXIsOOGroupO4 ||
25 weight_desc.layout == WeightsLayout::k2DX4O4YIsHWIAndXIsOOGroupI4) {
26 uint i_aligned = AlignByN(shape.i, 4);
27 uint o_aligned = AlignByN(shape.o, 4 * weight_desc.output_group_size);
28 return i_aligned * o_aligned * shape.h * shape.w;
29 } else if (weight_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
30 weight_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
31 uint i_aligned = AlignByN(shape.i, 4);
32 uint o_aligned = AlignByN(shape.o, 4);
33 return i_aligned * o_aligned * weight_desc.spatial_remap.size();
34 } else {
35 return -1;
36 }
37 }
38
RearrangeWeights(const tflite::gpu::Tensor<OHWI,DataType::FLOAT32> & weights,const WeightsDescription & dst_weight_desc,DataType dst_type,absl::Span<uint8_t> dst)39 void RearrangeWeights(
40 const tflite::gpu::Tensor<OHWI, DataType::FLOAT32>& weights,
41 const WeightsDescription& dst_weight_desc, DataType dst_type,
42 absl::Span<uint8_t> dst) {
43 const uint flt_count =
44 GetTotalElementsCountForLayout(dst_weight_desc, weights.shape);
45 if (dst_weight_desc.layout == WeightsLayout::kOHWIOGroupI4O4) {
46 if (dst_type == DataType::FLOAT32) {
47 float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
48 RearrangeWeightsToOHWIOGroupI4O4(weights,
49 dst_weight_desc.output_group_size,
50 absl::MakeSpan(f32_ptr, flt_count / 4));
51 } else if (dst_type == DataType::FLOAT16) {
52 half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
53 RearrangeWeightsToOHWIOGroupI4O4(weights,
54 dst_weight_desc.output_group_size,
55 absl::MakeSpan(f16_ptr, flt_count / 4));
56 }
57 return;
58 } else if (dst_weight_desc.layout == WeightsLayout::kOHWIOGroupO4I4) {
59 if (dst_type == DataType::FLOAT32) {
60 float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
61 RearrangeWeightsToOHWIOGroupO4I4(weights,
62 dst_weight_desc.output_group_size,
63 absl::MakeSpan(f32_ptr, flt_count / 4));
64 } else if (dst_type == DataType::FLOAT16) {
65 half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
66 RearrangeWeightsToOHWIOGroupO4I4(weights,
67 dst_weight_desc.output_group_size,
68 absl::MakeSpan(f16_ptr, flt_count / 4));
69 }
70 return;
71 } else if (dst_weight_desc.layout == WeightsLayout::kOICustomSpatialI4O4) {
72 if (dst_type == DataType::FLOAT32) {
73 float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
74 RearrangeWeightsToOICustomSpatialI4O4(
75 weights, dst_weight_desc.spatial_remap,
76 absl::MakeSpan(f32_ptr, flt_count / 4));
77 } else if (dst_type == DataType::FLOAT16) {
78 half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
79 RearrangeWeightsToOICustomSpatialI4O4(
80 weights, dst_weight_desc.spatial_remap,
81 absl::MakeSpan(f16_ptr, flt_count / 4));
82 }
83 return;
84 } else if (dst_weight_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
85 if (dst_type == DataType::FLOAT32) {
86 float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
87 RearrangeWeightsToOICustomSpatialO4I4(
88 weights, dst_weight_desc.spatial_remap,
89 absl::MakeSpan(f32_ptr, flt_count / 4));
90 } else if (dst_type == DataType::FLOAT16) {
91 half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
92 RearrangeWeightsToOICustomSpatialO4I4(
93 weights, dst_weight_desc.spatial_remap,
94 absl::MakeSpan(f16_ptr, flt_count / 4));
95 }
96 return;
97 } else if (dst_weight_desc.layout ==
98 WeightsLayout::k2DX4I4YIsHWIAndXIsOOGroupO4) {
99 if (dst_type == DataType::FLOAT32) {
100 float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
101 RearrangeWeightsToI4HWIOOGroupO4(weights,
102 dst_weight_desc.output_group_size,
103 absl::MakeSpan(f32_ptr, flt_count / 4));
104 } else if (dst_type == DataType::FLOAT16) {
105 half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
106 RearrangeWeightsToI4HWIOOGroupO4(weights,
107 dst_weight_desc.output_group_size,
108 absl::MakeSpan(f16_ptr, flt_count / 4));
109 }
110 return;
111 } else if (dst_weight_desc.layout ==
112 WeightsLayout::k2DX4O4YIsHWIAndXIsOOGroupI4) {
113 if (dst_type == DataType::FLOAT32) {
114 float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
115 RearrangeWeightsToO4HWIOOGroupI4(weights,
116 dst_weight_desc.output_group_size,
117 absl::MakeSpan(f32_ptr, flt_count / 4));
118 } else if (dst_type == DataType::FLOAT16) {
119 half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
120 RearrangeWeightsToO4HWIOOGroupI4(weights,
121 dst_weight_desc.output_group_size,
122 absl::MakeSpan(f16_ptr, flt_count / 4));
123 }
124 return;
125 }
126 }
127
128 } // namespace gpu
129 } // namespace tflite
130