• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/transpose_test_util.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "tensorflow/lite/delegates/gpu/common/operations.h"
22 #include "tensorflow/lite/delegates/gpu/common/status.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
24 #include "tensorflow/lite/delegates/gpu/common/tasks/transpose.h"
25 
26 namespace tflite {
27 namespace gpu {
28 namespace {
29 template <DataType T>
TransposeIntTest(TestExecutionEnvironment * env)30 absl::Status TransposeIntTest(TestExecutionEnvironment* env) {
31   tflite::gpu::Tensor<BHWC, T> src;
32   src.shape = BHWC(1, 1, 2, 3);
33   src.data = {1, 2, -3, -4, 3, 6};
34 
35   TransposeAttributes attr;
36   attr.perm = BHWC(0, 1, 3, 2);
37 
38   tflite::gpu::Tensor<BHWC, T> ref_tensor;
39   ref_tensor.shape = BHWC(1, 1, 3, 2);
40   ref_tensor.data = {1, -4, 2, 3, -3, 6};
41 
42   for (auto storage : env->GetSupportedStorages(T)) {
43     OperationDef op_def;
44     op_def.precision = CalculationsPrecision::F32;
45     op_def.src_tensors.push_back({T, storage, Layout::HWC});
46     op_def.dst_tensors.push_back({T, storage, Layout::HWC});
47     TensorDescriptor src_0, dst;
48     src_0 = op_def.src_tensors[0];
49     src_0.UploadData(src);
50     dst.SetBHWCShape(BHWC(1, 1, 3, 2));
51     GPUOperation operation = CreateTranspose(op_def, attr);
52     RETURN_IF_ERROR(env->ExecuteGPUOperation(
53         {&src_0}, {&dst},
54         std::make_unique<GPUOperation>(std::move(operation))));
55     tflite::gpu::Tensor<BHWC, T> dst_tensor;
56     dst.DownloadData(&dst_tensor);
57     if (dst_tensor.data != ref_tensor.data) {
58       return absl::InternalError("not equal");
59     }
60   }
61   return absl::OkStatus();
62 }
63 
64 template absl::Status TransposeIntTest<DataType::INT32>(
65     TestExecutionEnvironment* env);
66 template absl::Status TransposeIntTest<DataType::INT16>(
67     TestExecutionEnvironment* env);
68 template absl::Status TransposeIntTest<DataType::INT8>(
69     TestExecutionEnvironment* env);
70 
71 template <DataType T>
TransposeUintTest(TestExecutionEnvironment * env)72 absl::Status TransposeUintTest(TestExecutionEnvironment* env) {
73   tflite::gpu::Tensor<BHWC, T> src;
74   src.shape = BHWC(1, 1, 2, 3);
75   src.data = {1, 2, 3, 4, 5, 6};
76 
77   TransposeAttributes attr;
78   attr.perm = BHWC(0, 1, 3, 2);
79 
80   tflite::gpu::Tensor<BHWC, T> ref_tensor;
81   ref_tensor.shape = BHWC(1, 1, 3, 2);
82   ref_tensor.data = {1, 4, 2, 5, 3, 6};
83 
84   for (auto storage : env->GetSupportedStorages(T)) {
85     OperationDef op_def;
86     op_def.precision = CalculationsPrecision::F32;
87     op_def.src_tensors.push_back({T, storage, Layout::HWC});
88     op_def.dst_tensors.push_back({T, storage, Layout::HWC});
89     TensorDescriptor src_0, dst;
90     src_0 = op_def.src_tensors[0];
91     src_0.UploadData(src);
92     dst.SetBHWCShape(BHWC(1, 1, 3, 2));
93     GPUOperation operation = CreateTranspose(op_def, attr);
94     RETURN_IF_ERROR(env->ExecuteGPUOperation(
95         {&src_0}, {&dst},
96         std::make_unique<GPUOperation>(std::move(operation))));
97     tflite::gpu::Tensor<BHWC, T> dst_tensor;
98     dst.DownloadData(&dst_tensor);
99     if (dst_tensor.data != ref_tensor.data) {
100       return absl::InternalError("not equal");
101     }
102   }
103   return absl::OkStatus();
104 }
105 
106 template absl::Status TransposeUintTest<DataType::UINT32>(
107     TestExecutionEnvironment* env);
108 template absl::Status TransposeUintTest<DataType::UINT16>(
109     TestExecutionEnvironment* env);
110 template absl::Status TransposeUintTest<DataType::UINT8>(
111     TestExecutionEnvironment* env);
112 
113 }  // namespace
114 
TransposeTest(TestExecutionEnvironment * env)115 absl::Status TransposeTest(TestExecutionEnvironment* env) {
116   TensorFloat32 src_tensor;
117   src_tensor.shape = BHWC(1, 1, 2, 3);
118   src_tensor.data = {half(1.0f), half(2.0f), half(3.0f),
119                      half(4.0f), half(5.0f), half(6.0f)};
120 
121   TransposeAttributes attr;
122   attr.perm = BHWC(0, 1, 3, 2);
123 
124   for (auto precision : env->GetSupportedPrecisions()) {
125     auto data_type = DeduceDataTypeFromPrecision(precision);
126     for (auto storage : env->GetSupportedStorages(data_type)) {
127       OperationDef op_def;
128       op_def.precision = precision;
129       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
130       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
131       TensorFloat32 dst_tensor;
132       GPUOperation operation = CreateTranspose(op_def, attr);
133       RETURN_IF_ERROR(env->ExecuteGPUOperation(
134           src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
135           BHWC(1, 1, 3, 2), &dst_tensor));
136       RETURN_IF_ERROR(PointWiseNear({half(1.0f), half(4.0f), half(2.0f),
137                                      half(5.0f), half(3.0f), half(6.0f)},
138                                     dst_tensor.data, 0.0f));
139     }
140   }
141 
142   RETURN_IF_ERROR(TransposeIntTest<DataType::INT32>(env));
143   RETURN_IF_ERROR(TransposeIntTest<DataType::INT16>(env));
144   RETURN_IF_ERROR(TransposeIntTest<DataType::INT8>(env));
145   RETURN_IF_ERROR(TransposeUintTest<DataType::UINT32>(env));
146   RETURN_IF_ERROR(TransposeUintTest<DataType::UINT16>(env));
147   RETURN_IF_ERROR(TransposeUintTest<DataType::UINT8>(env));
148   return absl::OkStatus();
149 }
150 
151 }  // namespace gpu
152 }  // namespace tflite
153