• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <string>

#include "tensorflow/lite/delegates/gpu/common/types.h"

#import <XCTest/XCTest.h>

#import <Metal/Metal.h>
25
// Round-trip tests for MetalSpatialTensor: a CPU tensor is uploaded through a
// TensorDescriptor, materialized as a Metal tensor, downloaded again, and
// compared element-for-element with the original data.
@interface MetalSpatialTensorTest : XCTestCase
@end

@implementation MetalSpatialTensorTest
- (void)setUp {
  [super setUp];
}

using tflite::gpu::half;
using tflite::gpu::TensorDescriptor;
using tflite::gpu::TensorStorageType;
using tflite::gpu::DataType;
using tflite::gpu::BHWC;
using tflite::gpu::BHWDC;
using tflite::gpu::Layout;

namespace {

// Converts a generated test value to the CPU tensor's element type NativeT.
//
// The pattern generator deliberately produces values outside the range of
// most integer element types (up to about +/-2^40 for 32-bit types and
// negative values for unsigned types). Converting such a double directly to
// an integer is undefined behavior in C++, so the value is clamped to
// [lowest(), max()] of NativeT first. For float types the clamp is a no-op
// on the generated values; for bool it maps to [0, 1], so non-zero values
// still become true.
template <typename NativeT>
NativeT ToNativeValue(double value) {
  static_assert(std::numeric_limits<NativeT>::is_specialized,
                "ToNativeValue needs std::numeric_limits for the element type");
  // The clamp bounds must be exactly representable as double, otherwise the
  // final conversion can still be out of range (e.g. 64-bit integers).
  static_assert(!std::numeric_limits<NativeT>::is_integer ||
                    std::numeric_limits<NativeT>::digits <= std::numeric_limits<double>::digits,
                "Integer element types wider than double's mantissa are not supported");
  const double lowest = static_cast<double>(std::numeric_limits<NativeT>::lowest());
  const double highest = static_cast<double>(std::numeric_limits<NativeT>::max());
  return static_cast<NativeT>(std::min(std::max(value, lowest), highest));
}

// Fills every element of `data` with a deterministic pattern chosen for
// `data_type`, the data type the GPU-side tensor is stored with. The pattern
// is one sine period across the tensor, scaled up for wider integer types;
// BOOL uses `i % 7` instead.
template <typename Container>
void FillTestPattern(DataType data_type, Container* data) {
  using NativeT = typename Container::value_type;
  const size_t count = data->size();
  // Avoid 0 / 0 (NaN) for single-element tensors.
  const double denominator = count > 1 ? static_cast<double>(count - 1) : 1.0;
  for (size_t i = 0; i < count; ++i) {
    // val = [0, 1];
    const double val = static_cast<double>(i) / denominator;
    double transformed_val = std::sin(val * 2.0 * M_PI) * 256.0;
    if (data_type == DataType::INT16 || data_type == DataType::UINT16) {
      transformed_val *= 256.0;
    }
    if (data_type == DataType::INT32 || data_type == DataType::UINT32) {
      transformed_val *= 256.0 * 256.0 * 256.0 * 256.0;
    }
    if (data_type == DataType::FLOAT16) {
      // Pre-round through half, presumably so that the exact comparison
      // still holds after FLOAT16 storage.
      transformed_val = half(transformed_val);
    }
    if (data_type == DataType::BOOL) {
      transformed_val = i % 7;
    }
    (*data)[i] = ToNativeValue<NativeT>(transformed_val);
  }
}

// Returns OK if `actual` equals `expected` element for element (exact
// comparison); otherwise an InternalError naming the first mismatching index
// together with the GPU and CPU values.
template <typename Container>
absl::Status CompareData(const Container& expected, const Container& actual) {
  if (actual.size() != expected.size()) {
    return absl::InternalError("Size mismatch. GPU - " + std::to_string(actual.size()) +
                               ", CPU - " + std::to_string(expected.size()));
  }
  for (size_t i = 0; i < actual.size(); ++i) {
    if (actual[i] != expected[i]) {
      return absl::InternalError("Wrong value at index - " + std::to_string(i) + ". GPU - " +
                                 std::to_string(actual[i]) + ", CPU - " +
                                 std::to_string(expected[i]));
    }
  }
  return absl::OkStatus();
}

// Uploads a patterned CPU tensor of `shape` described by `descriptor` to a
// MetalSpatialTensor on `device`, downloads it back, and checks that the data
// survived unchanged. ShapeT is BHWC or BHWDC; T selects the CPU element type.
template <typename ShapeT, DataType T>
absl::Status TensorRoundTripTest(const ShapeT& shape, const TensorDescriptor& descriptor,
                                 id<MTLDevice> device) {
  tflite::gpu::Tensor<ShapeT, T> tensor_cpu;
  tensor_cpu.shape = shape;
  tensor_cpu.data.resize(shape.DimensionsProduct());
  FillTestPattern(descriptor.GetDataType(), &tensor_cpu.data);

  tflite::gpu::Tensor<ShapeT, T> tensor_gpu;
  tensor_gpu.shape = shape;
  // resize() value-initializes, so the destination starts zero-filled.
  tensor_gpu.data.resize(shape.DimensionsProduct());

  tflite::gpu::metal::MetalSpatialTensor tensor;
  TensorDescriptor descriptor_with_data = descriptor;
  descriptor_with_data.UploadData(tensor_cpu);
  RETURN_IF_ERROR(tensor.CreateFromDescriptor(descriptor_with_data, device));
  TensorDescriptor output_descriptor;
  RETURN_IF_ERROR(tensor.ToDescriptor(&output_descriptor, device));
  output_descriptor.DownloadData(&tensor_gpu);

  return CompareData(tensor_cpu.data, tensor_gpu.data);
}

// Round-trip test for a 4D (BHWC) tensor.
template <DataType T>
absl::Status TensorBHWCTest(const BHWC& shape, const TensorDescriptor& descriptor,
                            id<MTLDevice> device) {
  return TensorRoundTripTest<BHWC, T>(shape, descriptor, device);
}

// Round-trip test for a 5D (BHWDC) tensor.
template <DataType T>
absl::Status TensorBHWDCTest(const BHWDC& shape, const TensorDescriptor& descriptor,
                             id<MTLDevice> device) {
  return TensorRoundTripTest<BHWDC, T>(shape, descriptor, device);
}

// Runs the round-trip test over a fixed set of 4D and 5D shapes and layouts
// for one (GPU data type, storage type) pair. T is the CPU-side element type,
// which may differ from `data_type` (FLOAT16 storage is tested with FLOAT32
// CPU tensors). Stops at, and returns, the first failure.
template <DataType T>
absl::Status TensorTests(DataType data_type, TensorStorageType storage_type) {
  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
  if (device == nil) {
    return absl::UnavailableError("No Metal device is available.");
  }
  RETURN_IF_ERROR(
      TensorBHWCTest<T>(BHWC(1, 6, 7, 3), {data_type, storage_type, Layout::HWC}, device));
  RETURN_IF_ERROR(
      TensorBHWCTest<T>(BHWC(1, 1, 4, 12), {data_type, storage_type, Layout::HWC}, device));
  RETURN_IF_ERROR(
      TensorBHWCTest<T>(BHWC(1, 6, 1, 7), {data_type, storage_type, Layout::HWC}, device));

  // Batch tests
  RETURN_IF_ERROR(
      TensorBHWCTest<T>(BHWC(2, 6, 7, 3), {data_type, storage_type, Layout::BHWC}, device));
  RETURN_IF_ERROR(
      TensorBHWCTest<T>(BHWC(4, 1, 4, 12), {data_type, storage_type, Layout::BHWC}, device));
  RETURN_IF_ERROR(
      TensorBHWCTest<T>(BHWC(7, 6, 1, 7), {data_type, storage_type, Layout::BHWC}, device));
  RETURN_IF_ERROR(
      TensorBHWCTest<T>(BHWC(13, 7, 3, 3), {data_type, storage_type, Layout::BHWC}, device));

  // 5D tests with batch = 1
  RETURN_IF_ERROR(
      TensorBHWDCTest<T>(BHWDC(1, 6, 7, 4, 3), {data_type, storage_type, Layout::HWDC}, device));
  RETURN_IF_ERROR(
      TensorBHWDCTest<T>(BHWDC(1, 1, 4, 3, 12), {data_type, storage_type, Layout::HWDC}, device));
  RETURN_IF_ERROR(
      TensorBHWDCTest<T>(BHWDC(1, 6, 1, 7, 7), {data_type, storage_type, Layout::HWDC}, device));

  // 5D tests
  RETURN_IF_ERROR(
      TensorBHWDCTest<T>(BHWDC(2, 6, 7, 1, 3), {data_type, storage_type, Layout::BHWDC}, device));
  RETURN_IF_ERROR(
      TensorBHWDCTest<T>(BHWDC(4, 1, 4, 2, 12), {data_type, storage_type, Layout::BHWDC}, device));
  RETURN_IF_ERROR(
      TensorBHWDCTest<T>(BHWDC(7, 6, 1, 3, 7), {data_type, storage_type, Layout::BHWDC}, device));
  RETURN_IF_ERROR(
      TensorBHWDCTest<T>(BHWDC(13, 7, 3, 4, 3), {data_type, storage_type, Layout::BHWDC}, device));
  return absl::OkStatus();
}

}  // namespace

// In every *F16 test below the CPU-side tensor stays FLOAT32 while the GPU
// descriptor uses FLOAT16; FillTestPattern pre-rounds the values through half.

#pragma mark - BUFFER

- (void)testBufferF32 {
  auto status = TensorTests<DataType::FLOAT32>(DataType::FLOAT32, TensorStorageType::BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testBufferF16 {
  auto status = TensorTests<DataType::FLOAT32>(DataType::FLOAT16, TensorStorageType::BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testBufferInt32 {
  auto status = TensorTests<DataType::INT32>(DataType::INT32, TensorStorageType::BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testBufferInt16 {
  auto status = TensorTests<DataType::INT16>(DataType::INT16, TensorStorageType::BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testBufferInt8 {
  auto status = TensorTests<DataType::INT8>(DataType::INT8, TensorStorageType::BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testBufferUint32 {
  auto status = TensorTests<DataType::UINT32>(DataType::UINT32, TensorStorageType::BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testBufferUint16 {
  auto status = TensorTests<DataType::UINT16>(DataType::UINT16, TensorStorageType::BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testBufferUint8 {
  auto status = TensorTests<DataType::UINT8>(DataType::UINT8, TensorStorageType::BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testBufferBool {
  auto status = TensorTests<DataType::BOOL>(DataType::BOOL, TensorStorageType::BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

#pragma mark - TEXTURE_2D

- (void)testTexture2DF32 {
  auto status = TensorTests<DataType::FLOAT32>(DataType::FLOAT32, TensorStorageType::TEXTURE_2D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DF16 {
  auto status = TensorTests<DataType::FLOAT32>(DataType::FLOAT16, TensorStorageType::TEXTURE_2D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DInt32 {
  auto status = TensorTests<DataType::INT32>(DataType::INT32, TensorStorageType::TEXTURE_2D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DInt16 {
  auto status = TensorTests<DataType::INT16>(DataType::INT16, TensorStorageType::TEXTURE_2D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DInt8 {
  auto status = TensorTests<DataType::INT8>(DataType::INT8, TensorStorageType::TEXTURE_2D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DUint32 {
  auto status = TensorTests<DataType::UINT32>(DataType::UINT32, TensorStorageType::TEXTURE_2D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DUint16 {
  auto status = TensorTests<DataType::UINT16>(DataType::UINT16, TensorStorageType::TEXTURE_2D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DUint8 {
  auto status = TensorTests<DataType::UINT8>(DataType::UINT8, TensorStorageType::TEXTURE_2D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DBool {
  auto status = TensorTests<DataType::BOOL>(DataType::BOOL, TensorStorageType::TEXTURE_2D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

#pragma mark - TEXTURE_3D

- (void)testTexture3DF32 {
  auto status = TensorTests<DataType::FLOAT32>(DataType::FLOAT32, TensorStorageType::TEXTURE_3D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture3DF16 {
  auto status = TensorTests<DataType::FLOAT32>(DataType::FLOAT16, TensorStorageType::TEXTURE_3D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture3DInt32 {
  auto status = TensorTests<DataType::INT32>(DataType::INT32, TensorStorageType::TEXTURE_3D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture3DInt16 {
  auto status = TensorTests<DataType::INT16>(DataType::INT16, TensorStorageType::TEXTURE_3D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture3DInt8 {
  auto status = TensorTests<DataType::INT8>(DataType::INT8, TensorStorageType::TEXTURE_3D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture3DUint32 {
  auto status = TensorTests<DataType::UINT32>(DataType::UINT32, TensorStorageType::TEXTURE_3D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture3DUint16 {
  auto status = TensorTests<DataType::UINT16>(DataType::UINT16, TensorStorageType::TEXTURE_3D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture3DUint8 {
  auto status = TensorTests<DataType::UINT8>(DataType::UINT8, TensorStorageType::TEXTURE_3D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture3DBool {
  auto status = TensorTests<DataType::BOOL>(DataType::BOOL, TensorStorageType::TEXTURE_3D);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

#pragma mark - TEXTURE_ARRAY

- (void)testTexture2DArrayF32 {
  auto status = TensorTests<DataType::FLOAT32>(DataType::FLOAT32, TensorStorageType::TEXTURE_ARRAY);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DArrayF16 {
  auto status = TensorTests<DataType::FLOAT32>(DataType::FLOAT16, TensorStorageType::TEXTURE_ARRAY);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DArrayInt32 {
  auto status = TensorTests<DataType::INT32>(DataType::INT32, TensorStorageType::TEXTURE_ARRAY);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DArrayInt16 {
  auto status = TensorTests<DataType::INT16>(DataType::INT16, TensorStorageType::TEXTURE_ARRAY);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DArrayInt8 {
  auto status = TensorTests<DataType::INT8>(DataType::INT8, TensorStorageType::TEXTURE_ARRAY);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DArrayUint32 {
  auto status = TensorTests<DataType::UINT32>(DataType::UINT32, TensorStorageType::TEXTURE_ARRAY);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DArrayUint16 {
  auto status = TensorTests<DataType::UINT16>(DataType::UINT16, TensorStorageType::TEXTURE_ARRAY);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DArrayUint8 {
  auto status = TensorTests<DataType::UINT8>(DataType::UINT8, TensorStorageType::TEXTURE_ARRAY);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTexture2DArrayBool {
  auto status = TensorTests<DataType::BOOL>(DataType::BOOL, TensorStorageType::TEXTURE_ARRAY);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

#pragma mark - IMAGE_BUFFER

- (void)testTextureBufferF32 {
  auto status = TensorTests<DataType::FLOAT32>(DataType::FLOAT32, TensorStorageType::IMAGE_BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTextureBufferF16 {
  auto status = TensorTests<DataType::FLOAT32>(DataType::FLOAT16, TensorStorageType::IMAGE_BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTextureBufferInt32 {
  auto status = TensorTests<DataType::INT32>(DataType::INT32, TensorStorageType::IMAGE_BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTextureBufferInt16 {
  auto status = TensorTests<DataType::INT16>(DataType::INT16, TensorStorageType::IMAGE_BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTextureBufferInt8 {
  auto status = TensorTests<DataType::INT8>(DataType::INT8, TensorStorageType::IMAGE_BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTextureBufferUint32 {
  auto status = TensorTests<DataType::UINT32>(DataType::UINT32, TensorStorageType::IMAGE_BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTextureBufferUint16 {
  auto status = TensorTests<DataType::UINT16>(DataType::UINT16, TensorStorageType::IMAGE_BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTextureBufferUint8 {
  auto status = TensorTests<DataType::UINT8>(DataType::UINT8, TensorStorageType::IMAGE_BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

- (void)testTextureBufferBool {
  auto status = TensorTests<DataType::BOOL>(DataType::BOOL, TensorStorageType::IMAGE_BUFFER);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

@end
488