1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
17 
18 #include <cstdint>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/substitute.h"
24 #include "tensorflow/lite/delegates/gpu/common/shape.h"
25 #include "tensorflow/lite/delegates/gpu/common/util.h"
26 
27 namespace tflite {
28 namespace gpu {
29 namespace {
GetReadImageFromDataType(DataType data_type)30 std::string GetReadImageFromDataType(DataType data_type) {
31   if (data_type == DataType::FLOAT32) {
32     return "read_imagef";
33   } else if (data_type == DataType::FLOAT16) {
34     return "read_imageh";
35   } else {
36     return "error";
37   }
38 }
39 
GetWriteImageFromDataType(DataType data_type)40 std::string GetWriteImageFromDataType(DataType data_type) {
41   if (data_type == DataType::FLOAT32) {
42     return "write_imagef";
43   } else if (data_type == DataType::FLOAT16) {
44     return "write_imageh";
45   } else {
46     return "error";
47   }
48 }
49 
AddressModeToCLSampler(AddressMode address_mode)50 std::string AddressModeToCLSampler(AddressMode address_mode) {
51   switch (address_mode) {
52     case AddressMode::kDontCare:
53       return "smp_none";
54     case AddressMode::kZero:
55       return "smp_zero";
56   }
57 }
58 
59 }  // namespace
60 
ToString(TensorStorageType type)61 std::string ToString(TensorStorageType type) {
62   switch (type) {
63     case TensorStorageType::UNKNOWN:
64       return "TensorStorageType::UNKNOWN";
65     case TensorStorageType::BUFFER:
66       return "TensorStorageType::BUFFER";
67     case TensorStorageType::TEXTURE_ARRAY:
68       return "TensorStorageType::TEXTURE_ARRAY";
69     case TensorStorageType::TEXTURE_2D:
70       return "TensorStorageType::TEXTURE_2D";
71     case TensorStorageType::TEXTURE_3D:
72       return "TensorStorageType::TEXTURE_3D";
73     case TensorStorageType::SINGLE_TEXTURE_2D:
74       return "TensorStorageType::SINGLE_TEXTURE_2D";
75     case TensorStorageType::IMAGE_BUFFER:
76       return "TensorStorageType::IMAGE_BUFFER";
77   }
78 }
79 
// Move constructor.
// The GPUObjectDescriptor base is constructed from std::move(desc). The
// TensorDescriptor-specific fields are then read from `desc`: data type,
// storage type, layout and shape are copied, and the `data` payload is moved.
// NOTE(review): reading `desc` after handing std::move(desc) to the base
// relies on the base move constructor only touching base-class members —
// confirm against GPUObjectDescriptor.
TensorDescriptor::TensorDescriptor(TensorDescriptor&& desc)
    : GPUObjectDescriptor(std::move(desc)),
      data_type(desc.data_type),
      storage_type(desc.storage_type),
      layout(desc.layout),
      shape(desc.shape),
      data(std::move(desc.data)) {}
operator =(TensorDescriptor && desc)87 TensorDescriptor& TensorDescriptor::operator=(TensorDescriptor&& desc) {
88   if (this != &desc) {
89     std::swap(data_type, desc.data_type);
90     std::swap(storage_type, desc.storage_type);
91     std::swap(layout, desc.layout);
92     std::swap(shape, desc.shape);
93     data = std::move(desc.data);
94     GPUObjectDescriptor::operator=(std::move(desc));
95   }
96   return *this;
97 }
98 
GetGPUResources(const GpuInfo & gpu_info) const99 GPUResources TensorDescriptor::GetGPUResources(const GpuInfo& gpu_info) const {
100   GPUResources resources;
101   resources.ints.push_back("slice_stride");
102   if (HasAxis(Axis::WIDTH)) {
103     resources.ints.push_back("width");
104   }
105   if (HasAxis(Axis::HEIGHT)) {
106     resources.ints.push_back("height");
107   }
108   if (HasAxis(Axis::CHANNELS)) {
109     resources.ints.push_back("slices");
110     resources.ints.push_back("channels");
111   }
112   if (HasAxis(Axis::BATCH)) {
113     resources.ints.push_back("batch");
114   }
115   if (HasAxis(Axis::DEPTH)) {
116     resources.ints.push_back("depth");
117   }
118   if (storage_type == TensorStorageType::BUFFER) {
119     GPUBufferDescriptor desc;
120     desc.data_type = data_type;
121     desc.access_type = access_type_;
122     desc.element_size = 4;
123     auto it1 = state_vars_.find("ElementsX2");
124     if (it1 != state_vars_.end() && it1->second == "true") {
125       desc.element_size = 8;
126     }
127     auto it2 = state_vars_.find("ElementsX4");
128     if (it2 != state_vars_.end() && it2->second == "true") {
129       desc.element_size = 16;
130     }
131     resources.buffers.push_back({"buffer", desc});
132   } else if (storage_type == TensorStorageType::SINGLE_TEXTURE_2D ||
133              storage_type == TensorStorageType::TEXTURE_2D) {
134     GPUImage2DDescriptor desc;
135     desc.data_type = data_type;
136     desc.normalized = false;
137     desc.access_type = access_type_;
138     resources.images2d.push_back({"image2d", desc});
139   } else if (storage_type == TensorStorageType::TEXTURE_ARRAY) {
140     GPUImage2DArrayDescriptor desc;
141     desc.data_type = data_type;
142     desc.access_type = access_type_;
143     resources.image2d_arrays.push_back({"image2d_array", desc});
144   } else if (storage_type == TensorStorageType::TEXTURE_3D) {
145     GPUImage3DDescriptor desc;
146     desc.data_type = data_type;
147     desc.access_type = access_type_;
148     resources.images3d.push_back({"image3d", desc});
149   } else if (storage_type == TensorStorageType::IMAGE_BUFFER) {
150     if (access_type_ == AccessType::READ) {
151       GPUImageBufferDescriptor desc;
152       desc.data_type = data_type;
153       desc.access_type = access_type_;
154       resources.image_buffers.push_back({"image_buffer", desc});
155     } else {
156       GPUBufferDescriptor desc;
157       desc.data_type = data_type;
158       desc.access_type = access_type_;
159       desc.element_size = 4;
160       resources.buffers.push_back({"buffer", desc});
161     }
162   }
163   return resources;
164 }
165 
PerformSelector(const GpuInfo & gpu_info,const std::string & selector,const std::vector<std::string> & args,const std::vector<std::string> & template_args,std::string * result) const166 absl::Status TensorDescriptor::PerformSelector(
167     const GpuInfo& gpu_info, const std::string& selector,
168     const std::vector<std::string>& args,
169     const std::vector<std::string>& template_args, std::string* result) const {
170   if (selector == "Width") {
171     *result = "width";
172     return absl::OkStatus();
173   } else if (selector == "Height") {
174     *result = "height";
175     return absl::OkStatus();
176   } else if (selector == "Slices") {
177     *result = "slices";
178     return absl::OkStatus();
179   } else if (selector == "SliceStride") {
180     *result = "slice_stride";
181     return absl::OkStatus();
182   } else if (selector == "Channels") {
183     *result = "channels";
184     return absl::OkStatus();
185   } else if (selector == "Batch") {
186     if (HasAxis(Axis::BATCH)) {
187       *result = "batch";
188     } else {
189       *result = "1";
190     }
191     return absl::OkStatus();
192   } else if (selector == "Depth") {
193     *result = "depth";
194     return absl::OkStatus();
195   } else if (selector == "SetBatchRef") {
196     if (args.size() != 1) {
197       return absl::InvalidArgumentError(
198           "Unsupported arguments in SetBatchRef selector");
199     }
200     state_vars_["batch_id"] = args[0];
201     *result = "";
202     return absl::OkStatus();
203   } else if (selector == "Read") {
204     return PerformReadSelector(gpu_info, args, template_args, result);
205   } else if (selector == "Write") {
206     return PerformWriteSelector(gpu_info, args, result);
207   } else if (selector == "WriteLinear") {
208     return PerformWriteLinearSelector(gpu_info, args, result);
209   } else if (selector == "Write2D") {
210     return PerformWrite2DSelector(gpu_info, args, result);
211   } else if (selector == "GetAddress") {
212     return PerformGetAddressSelector(args, result);
213   } else if (selector == "GetPtrWithSliceOffset") {
214     return PerformGetPtrWithSliceOffsetSelector(args, result);
215   } else if (selector == "GetWHOffset") {
216     return PerformGetWHOffsetSelector(args, result);
217   } else if (selector == "GetHandle") {
218     return PerformGetHandleSelector(args, result);
219   } else {
220     return absl::NotFoundError(absl::StrCat(
221         "TensorDescriptor don't have selector with name - ", selector));
222   }
223 }
224 
PerformReadSelector(const GpuInfo & gpu_info,const std::vector<std::string> & args,const std::vector<std::string> & template_args,std::string * result) const225 absl::Status TensorDescriptor::PerformReadSelector(
226     const GpuInfo& gpu_info, const std::vector<std::string>& args,
227     const std::vector<std::string>& template_args, std::string* result) const {
228   DataType read_as_type = data_type;
229   if (!template_args.empty()) {
230     if (template_args.size() != 1) {
231       return absl::NotFoundError(
232           "Unrecognized Read selector template arguments.");
233     } else {
234       RETURN_IF_ERROR(
235           GetDataTypeFromTemplateArgs(template_args[0], &read_as_type));
236     }
237   }
238   if (args.size() == 1) {  // function overload for 1D linear types.
239     if (storage_type == TensorStorageType::BUFFER ||
240         storage_type == TensorStorageType::IMAGE_BUFFER) {
241       *result = Read(gpu_info, read_as_type, {args[0]});
242       return absl::OkStatus();
243     } else {
244       return absl::InvalidArgumentError(
245           "Read selector with single argument can be used only with linear "
246           "storage types(BUFFER or IMAGE_BUFFER)");
247     }
248   }
249   std::string xc;
250   std::string yc;
251   std::string zc;
252   std::string sc;
253   std::string bc;
254   bool parsed = ParseCoordsFromArgs(args, 0, &xc, &yc, &zc, &sc, &bc);
255   if (args.size() < 2 || !parsed) {
256     return absl::NotFoundError("Unrecognized Read selector");
257   }
258 
259   *result = Read(gpu_info, read_as_type, GetPhysicalCoords(xc, yc, zc, sc, bc));
260   return absl::OkStatus();
261 }
262 
GetLinkingContextFromWriteSelector(const std::vector<std::string> & args,std::string * value_name,std::string * x_coord,std::string * y_coord,std::string * s_coord) const263 absl::Status TensorDescriptor::GetLinkingContextFromWriteSelector(
264     const std::vector<std::string>& args, std::string* value_name,
265     std::string* x_coord, std::string* y_coord, std::string* s_coord) const {
266   std::string xc;
267   std::string yc;
268   std::string zc;
269   std::string sc;
270   std::string bc;
271   bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
272   if (args.size() < 2 || !parsed) {
273     return absl::NotFoundError("Unrecognized Write selector");
274   }
275   *value_name = args[0];
276   if (HasAxis(Axis::BATCH) && !IsBatchedWidth()) {
277     *x_coord = absl::StrCat("((", xc, ") * batch + (", bc, "))");
278   } else {
279     *x_coord = absl::StrCat("(", xc, ")");
280   }
281   *y_coord = absl::StrCat("(", yc, ")");
282   *s_coord = absl::StrCat("(", sc, ")");
283   return absl::OkStatus();
284 }
285 
PerformWriteSelector(const GpuInfo & gpu_info,const std::vector<std::string> & args,std::string * result) const286 absl::Status TensorDescriptor::PerformWriteSelector(
287     const GpuInfo& gpu_info, const std::vector<std::string>& args,
288     std::string* result) const {
289   std::string xc;
290   std::string yc;
291   std::string zc;
292   std::string sc;
293   std::string bc;
294   bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
295   if (args.size() < 2 || !parsed) {
296     return absl::NotFoundError("Unrecognized Write selector");
297   }
298   *result = Write(gpu_info, args[0], GetPhysicalCoords(xc, yc, zc, sc, bc));
299   return absl::OkStatus();
300 }
301 
PerformWriteLinearSelector(const GpuInfo & gpu_info,const std::vector<std::string> & args,std::string * result) const302 absl::Status TensorDescriptor::PerformWriteLinearSelector(
303     const GpuInfo& gpu_info, const std::vector<std::string>& args,
304     std::string* result) const {
305   if (storage_type != TensorStorageType::BUFFER &&
306       storage_type != TensorStorageType::IMAGE_BUFFER) {
307     return absl::InvalidArgumentError(
308         "WriteLinear selector can be used only with linear "
309         "storages(BUFFER/IMAGE_BUFFER)");
310   }
311   if (args.size() != 2) {
312     return absl::NotFoundError("Unrecognized WriteLinear selector");
313   }
314   *result = Write(gpu_info, args[0], {args[1]});
315   return absl::OkStatus();
316 }
317 
PerformWrite2DSelector(const GpuInfo & gpu_info,const std::vector<std::string> & args,std::string * result) const318 absl::Status TensorDescriptor::PerformWrite2DSelector(
319     const GpuInfo& gpu_info, const std::vector<std::string>& args,
320     std::string* result) const {
321   if (storage_type != TensorStorageType::TEXTURE_2D) {
322     return absl::InvalidArgumentError(
323         "Write2D selector can be used only with 2d "
324         "storages(TEXTURE_2D)");
325   }
326   if (args.size() != 3) {
327     return absl::NotFoundError("Unrecognized Write2D selector");
328   }
329   *result = Write(gpu_info, args[0], {args[1], args[2]});
330   return absl::OkStatus();
331 }
332 
Read(const GpuInfo & gpu_info,DataType read_as_type,const std::vector<std::string> & coords) const333 std::string TensorDescriptor::Read(
334     const GpuInfo& gpu_info, DataType read_as_type,
335     const std::vector<std::string>& coords) const {
336   const std::string read_as =
337       read_as_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
338   const bool need_conversion = read_as_type != data_type;
339   const std::string metal_type =
340       read_as_type == DataType::FLOAT32 ? "float4" : "half4";
341   switch (storage_type) {
342     case TensorStorageType::BUFFER:
343       if (gpu_info.IsGlsl()) {
344         if (data_type == DataType::FLOAT16) {
345           return absl::StrCat("vec4(unpackHalf2x16(buffer[", coords[0],
346                               "].x), unpackHalf2x16(buffer[", coords[0],
347                               "].y))");
348         } else {
349           return absl::StrCat("buffer[", coords[0], "]");
350         }
351       }
352       if (read_as_type == data_type) {
353         return absl::StrCat("buffer[", coords[0], "]");
354       } else {
355         std::string conversion;
356         if (gpu_info.IsApiMetal()) {
357           conversion = metal_type;
358         } else if (gpu_info.IsApiOpenCl()) {
359           if (read_as_type == DataType::FLOAT16) {
360             conversion = "convert_half4";
361           } else if (read_as_type == DataType::FLOAT32) {
362             conversion = "convert_float4";
363           }
364         }
365         return absl::StrCat(conversion, "(buffer[", coords[0], "])");
366       }
367     case TensorStorageType::TEXTURE_2D:
368     case TensorStorageType::SINGLE_TEXTURE_2D:
369       if (gpu_info.IsApiOpenCl()) {
370         return absl::Substitute("$0(image2d, $1, (int2)($2, $3))", read_as,
371                                 AddressModeToCLSampler(AddressModeFromState()),
372                                 coords[0], coords[1]);
373       } else if (gpu_info.IsApiMetal()) {
374         std::string result = absl::Substitute("image2d.read(ushort2($0, $1))",
375                                               coords[0], coords[1]);
376         if (need_conversion) {
377           result = metal_type + "(" + result + ")";
378         }
379         return result;
380       } else if (gpu_info.IsGlsl()) {
381         return "texelFetch(image2d, ivec2(" + coords[0] + ", " + coords[1] +
382                "), 0)";
383       } else {
384         return "";
385       }
386     case TensorStorageType::TEXTURE_3D:
387       if (gpu_info.IsApiOpenCl()) {
388         return absl::Substitute("$0(image3d, $1, (int4)($2, $3, $4, 0))",
389                                 read_as,
390                                 AddressModeToCLSampler(AddressModeFromState()),
391                                 coords[0], coords[1], coords[2]);
392       } else if (gpu_info.IsApiMetal()) {
393         std::string result =
394             absl::Substitute("image3d.read(ushort3($0, $1, $2))", coords[0],
395                              coords[1], coords[2]);
396         if (need_conversion) {
397           result = metal_type + "(" + result + ")";
398         }
399         return result;
400       } else if (gpu_info.IsGlsl()) {
401         return "texelFetch(image3d, ivec3(" + coords[0] + ", " + coords[1] +
402                ", " + coords[2] + "), 0)";
403       } else {
404         return "";
405       }
406     case TensorStorageType::TEXTURE_ARRAY:
407       if (gpu_info.IsApiOpenCl()) {
408         return absl::Substitute("$0(image2d_array, $1, (int4)($2, $3, $4, 0))",
409                                 read_as,
410                                 AddressModeToCLSampler(AddressModeFromState()),
411                                 coords[0], coords[1], coords[2]);
412       } else if (gpu_info.IsApiMetal()) {
413         std::string result =
414             absl::Substitute("image2d_array.read(ushort2($0, $1), $2)",
415                              coords[0], coords[1], coords[2]);
416         if (need_conversion) {
417           result = metal_type + "(" + result + ")";
418         }
419         return result;
420       } else if (gpu_info.IsGlsl()) {
421         return "texelFetch(image2d_array, ivec3(" + coords[0] + ", " +
422                coords[1] + ", " + coords[2] + "), 0)";
423       } else {
424         return "";
425       }
426     case TensorStorageType::IMAGE_BUFFER:
427       if (gpu_info.IsApiOpenCl()) {
428         return absl::StrCat(read_as, "(image_buffer, ", coords[0], ")");
429       } else if (gpu_info.IsApiMetal()) {
430         std::string result =
431             absl::Substitute("image_buffer.read(uint($0))", coords[0]);
432         if (need_conversion) {
433           result = metal_type + "(" + result + ")";
434         }
435         return result;
436       } else if (gpu_info.IsGlsl()) {
437         return "texelFetch(image_buffer, " + coords[0] + ")";
438       } else {
439         return "";
440       }
441     case TensorStorageType::UNKNOWN:
442       return "";
443   }
444 }
445 
Write(const GpuInfo & gpu_info,const std::string & var_name,const std::vector<std::string> & coords) const446 std::string TensorDescriptor::Write(
447     const GpuInfo& gpu_info, const std::string& var_name,
448     const std::vector<std::string>& coords) const {
449   switch (storage_type) {
450     case TensorStorageType::BUFFER:
451     case TensorStorageType::IMAGE_BUFFER:
452       if (gpu_info.IsGlsl()) {
453         if (data_type == DataType::FLOAT16) {
454           return absl::StrCat("buffer[", coords[0], "] = uvec2(packHalf2x16(",
455                               var_name, ".xy), packHalf2x16(", var_name,
456                               ".zw))");
457         } else {
458           return absl::StrCat("buffer[", coords[0], "] = ", var_name);
459         }
460       }
461       return absl::StrCat("buffer[", coords[0], "] = ", var_name);
462     case TensorStorageType::SINGLE_TEXTURE_2D:
463     case TensorStorageType::TEXTURE_2D:
464       if (gpu_info.IsApiOpenCl()) {
465         return absl::Substitute("$0(image2d, (int2)($1, $2), $3)",
466                                 GetWriteImageFromDataType(data_type), coords[0],
467                                 coords[1], var_name);
468       } else if (gpu_info.IsApiMetal()) {
469         return absl::Substitute("image2d.write($0, ushort2($1, $2))", var_name,
470                                 coords[0], coords[1]);
471       } else if (gpu_info.IsGlsl()) {
472         return absl::Substitute("imageStore(image2d, ivec2($0, $1), $2)",
473                                 coords[0], coords[1], var_name);
474       } else {
475         return "";
476       }
477     case TensorStorageType::TEXTURE_3D:
478       if (gpu_info.IsApiOpenCl()) {
479         return absl::Substitute("$0(image3d, (int4)($1, $2, $3, 0), $4)",
480                                 GetWriteImageFromDataType(data_type), coords[0],
481                                 coords[1], coords[2], var_name);
482       } else if (gpu_info.IsApiMetal()) {
483         return absl::Substitute("image3d.write($0, ushort3($1, $2, $3))",
484                                 var_name, coords[0], coords[1], coords[2]);
485       } else if (gpu_info.IsGlsl()) {
486         return absl::Substitute("imageStore(image3d, ivec3($0, $1, $2), $3)",
487                                 coords[0], coords[1], coords[2], var_name);
488       } else {
489         return "";
490       }
491     case TensorStorageType::TEXTURE_ARRAY:
492       if (gpu_info.IsApiOpenCl()) {
493         return absl::Substitute("$0(image2d_array, (int4)($1, $2, $3, 0), $4)",
494                                 GetWriteImageFromDataType(data_type), coords[0],
495                                 coords[1], coords[2], var_name);
496       } else if (gpu_info.IsApiMetal()) {
497         return absl::Substitute("image2d_array.write($0, ushort2($1, $2), $3)",
498                                 var_name, coords[0], coords[1], coords[2]);
499       } else if (gpu_info.IsGlsl()) {
500         return absl::Substitute(
501             "imageStore(image2d_array, ivec3($0, $1, $2), $3)", coords[0],
502             coords[1], coords[2], var_name);
503       } else {
504         return "";
505       }
506     case TensorStorageType::UNKNOWN:
507       return "";
508   }
509 }
510 
PerformGetAddressSelector(const std::vector<std::string> & args,std::string * result) const511 absl::Status TensorDescriptor::PerformGetAddressSelector(
512     const std::vector<std::string>& args, std::string* result) const {
513   std::string xc;
514   std::string yc;
515   std::string zc;
516   std::string sc;
517   std::string bc;
518   bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
519   if (args.size() < 3 || !parsed) {
520     return absl::NotFoundError("Unrecognized GetAddress selector");
521   }
522 
523   *result = DeclareAddress(args[0],
524                            GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
525   return absl::OkStatus();
526 }
527 
PerformGetPtrWithSliceOffsetSelector(const std::vector<std::string> & args,std::string * result) const528 absl::Status TensorDescriptor::PerformGetPtrWithSliceOffsetSelector(
529     const std::vector<std::string>& args, std::string* result) const {
530   if (storage_type != TensorStorageType::BUFFER) {
531     return absl::InvalidArgumentError(
532         "GetPtrWithSliceOffset selector can be used only with BUFFER");
533   }
534   if (args.size() != 1) {
535     return absl::NotFoundError(absl::StrCat(
536         "GetPtrWithSliceOffset require one argument(slice coordinate), but ",
537         args.size(), " was passed"));
538   }
539   *result = absl::StrCat("buffer + ", args[0], " * slice_stride");
540   return absl::OkStatus();
541 }
542 
PerformGetWHOffsetSelector(const std::vector<std::string> & args,std::string * result) const543 absl::Status TensorDescriptor::PerformGetWHOffsetSelector(
544     const std::vector<std::string>& args, std::string* result) const {
545   if (storage_type != TensorStorageType::BUFFER &&
546       storage_type != TensorStorageType::IMAGE_BUFFER) {
547     return absl::InvalidArgumentError(
548         "GetWHOffset selector can be used only with BUFFER/IMAGE_BUFFER");
549   }
550   if (args.size() != 2) {
551     return absl::NotFoundError(absl::StrCat(
552         "GetWHOffset require two arguments(X and Y coordinates), but ",
553         args.size(), " was passed"));
554   }
555   if (HasAxis(Axis::BATCH) && !IsBatchedWidth()) {
556     auto it = state_vars_.find("batch_id");
557     std::string batch_id;
558     if (it == state_vars_.end()) {
559       return absl::NotFoundError(
560           "Not found batch_id. Should be setted up by SetBatchRef(). method");
561     } else {
562       batch_id = it->second;
563     }
564     *result = absl::StrCat("((", args[1], ") * width + (", args[0],
565                            ")) * batch + (", batch_id, ")");
566   } else {
567     *result = absl::StrCat("(", args[1], ") * width + (", args[0], ")");
568   }
569   return absl::OkStatus();
570 }
571 
PerformGetHandleSelector(const std::vector<std::string> & args,std::string * result) const572 absl::Status TensorDescriptor::PerformGetHandleSelector(
573     const std::vector<std::string>& args, std::string* result) const {
574   if (!args.empty()) {
575     return absl::NotFoundError(
576         absl::StrCat("GetHandle does not require arguments, but ", args.size(),
577                      " was passed"));
578   }
579   switch (storage_type) {
580     case TensorStorageType::BUFFER:
581       *result = "buffer";
582       return absl::OkStatus();
583     case TensorStorageType::IMAGE_BUFFER:
584       if (access_type_ == AccessType::READ) {
585         *result = "image_buffer";
586       } else {
587         *result = "buffer";
588       }
589       return absl::OkStatus();
590     case TensorStorageType::TEXTURE_2D:
591     case TensorStorageType::SINGLE_TEXTURE_2D:
592       *result = "image2d";
593       return absl::OkStatus();
594     case TensorStorageType::TEXTURE_ARRAY:
595       *result = "image2d_array";
596       return absl::OkStatus();
597     case TensorStorageType::TEXTURE_3D:
598       *result = "image3d";
599       return absl::OkStatus();
600     case TensorStorageType::UNKNOWN:
601       return absl::UnavailableError("Unknown type");
602   }
603 }
604 
DeclareAddress(const std::string & var_name,const std::string & address) const605 std::string TensorDescriptor::DeclareAddress(const std::string& var_name,
606                                              const std::string& address) const {
607   return absl::StrCat(StorageTypeToAddressType(), " ", var_name, " = ", address,
608                       ";");
609 }
610 
StorageTypeToAddressType() const611 std::string TensorDescriptor::StorageTypeToAddressType() const {
612   switch (storage_type) {
613     case TensorStorageType::BUFFER:
614     case TensorStorageType::IMAGE_BUFFER:
615       return "int";
616     case TensorStorageType::TEXTURE_2D:
617     case TensorStorageType::SINGLE_TEXTURE_2D:
618       return "int2";
619     case TensorStorageType::TEXTURE_ARRAY:
620     case TensorStorageType::TEXTURE_3D:
621       return "int4";
622     case TensorStorageType::UNKNOWN:
623       return "";
624   }
625 }
626 
GetPhysicalCoordsWHS(const std::string & x,const std::string & y,const std::string & s) const627 std::vector<std::string> TensorDescriptor::GetPhysicalCoordsWHS(
628     const std::string& x, const std::string& y, const std::string& s) const {
629   switch (storage_type) {
630     case TensorStorageType::BUFFER:
631     case TensorStorageType::IMAGE_BUFFER:
632       return {
633           absl::Substitute("((($2) * height + ($1)) * width + ($0))", x, y, s)};
634     case TensorStorageType::TEXTURE_2D:
635       return {absl::Substitute("($0)", x),
636               absl::Substitute("(($0) * slices + ($1))", y, s)};
637     case TensorStorageType::SINGLE_TEXTURE_2D:
638       return {absl::Substitute("($0)", x), absl::Substitute("($0)", y)};
639     case TensorStorageType::TEXTURE_ARRAY:
640     case TensorStorageType::TEXTURE_3D:
641       return {absl::Substitute("($0)", x), absl::Substitute("($0)", y),
642               absl::Substitute("($0)", s)};
643     case TensorStorageType::UNKNOWN:
644       return {""};
645     default:
646       return {""};
647   }
648 }
649 
GetPhysicalCoordsWHSB(const std::string & x,const std::string & y,const std::string & s,const std::string & b) const650 std::vector<std::string> TensorDescriptor::GetPhysicalCoordsWHSB(
651     const std::string& x, const std::string& y, const std::string& s,
652     const std::string& b) const {
653   switch (storage_type) {
654     case TensorStorageType::BUFFER:
655     case TensorStorageType::IMAGE_BUFFER:
656       return {absl::Substitute(
657           "(((($3) * height + $2) * width + ($1)) * batch + ($0))", b, x, y,
658           s)};
659     case TensorStorageType::TEXTURE_2D:
660       return {absl::Substitute("(($0) * batch + ($1))", x, b),
661               absl::Substitute("(($0) * slices + ($1))", y, s)};
662     case TensorStorageType::SINGLE_TEXTURE_2D:
663       return {absl::Substitute("(($0) * batch + ($1))", x, b),
664               absl::Substitute("($0)", y)};
665     case TensorStorageType::TEXTURE_ARRAY:
666     case TensorStorageType::TEXTURE_3D:
667       return {absl::Substitute("(($0) * batch + ($1))", x, b),
668               absl::Substitute("($0)", y), absl::Substitute("($0)", s)};
669     case TensorStorageType::UNKNOWN:
670       return {""};
671     default:
672       return {""};
673   }
674 }
675 
GetPhysicalCoordsWHDS(const std::string & x,const std::string & y,const std::string & z,const std::string & s) const676 std::vector<std::string> TensorDescriptor::GetPhysicalCoordsWHDS(
677     const std::string& x, const std::string& y, const std::string& z,
678     const std::string& s) const {
679   switch (storage_type) {
680     case TensorStorageType::BUFFER:
681     case TensorStorageType::IMAGE_BUFFER:
682       return {absl::Substitute(
683           "(((($3) * slices + ($2)) * height + ($1)) * width + ($0))", x, y, s,
684           z)};
685     case TensorStorageType::TEXTURE_2D:
686       return {absl::Substitute("(($0) * depth + ($1))", x, z),
687               absl::Substitute("(($0) * slices + ($1))", y, s)};
688     case TensorStorageType::SINGLE_TEXTURE_2D:
689       return {absl::Substitute("(($0) * depth + ($1))", x, z),
690               absl::Substitute("($0)", y)};
691     case TensorStorageType::TEXTURE_ARRAY:
692     case TensorStorageType::TEXTURE_3D:
693       return {absl::Substitute("($0)", x), absl::Substitute("($0)", y),
694               absl::Substitute("(($0) * slices + ($1))", z, s)};
695     case TensorStorageType::UNKNOWN:
696       return {""};
697     default:
698       return {""};
699   }
700 }
701 
GetPhysicalCoordsWHDSB(const std::string & x,const std::string & y,const std::string & z,const std::string & s,const std::string & b) const702 std::vector<std::string> TensorDescriptor::GetPhysicalCoordsWHDSB(
703     const std::string& x, const std::string& y, const std::string& z,
704     const std::string& s, const std::string& b) const {
705   switch (storage_type) {
706     case TensorStorageType::BUFFER:
707     case TensorStorageType::IMAGE_BUFFER:
708       return {absl::Substitute(
709           "((((($4) * slices + ($3)) * height + $2) * width + ($1)) * batch + "
710           "($0))",
711           b, x, y, s, z)};
712     case TensorStorageType::TEXTURE_2D:
713       return {absl::Substitute("((($0)*batch + ($1))*depth + ($2))", x, b, z),
714               absl::Substitute("(($0) * slices + ($1))", y, s)};
715     case TensorStorageType::SINGLE_TEXTURE_2D:
716       return {absl::Substitute("((($0)*batch + ($1))*depth + ($2))", x, b, z),
717               absl::Substitute("($0)", y)};
718     case TensorStorageType::TEXTURE_ARRAY:
719     case TensorStorageType::TEXTURE_3D:
720       return {absl::Substitute("(($0) * batch + ($1))", x, b),
721               absl::Substitute("($0)", y),
722               absl::Substitute("(($0) * slices + ($1))", z, s)};
723     case TensorStorageType::UNKNOWN:
724       return {""};
725     default:
726       return {""};
727   }
728 }
729 
GetGlobalAddressNoDeclaration(const std::string & xc,const std::string & yc,const std::string & zc,const std::string & sc,const std::string & bc) const730 std::string TensorDescriptor::GetGlobalAddressNoDeclaration(
731     const std::string& xc, const std::string& yc, const std::string& zc,
732     const std::string& sc, const std::string& bc) const {
733   auto coords = GetPhysicalCoords(xc, yc, zc, sc, bc);
734   switch (storage_type) {
735     case TensorStorageType::BUFFER:
736     case TensorStorageType::IMAGE_BUFFER: {
737       return coords[0];
738     }
739     case TensorStorageType::TEXTURE_2D:
740     case TensorStorageType::SINGLE_TEXTURE_2D:
741       return absl::Substitute("(int2)($0, $1)", coords[0], coords[1]);
742     case TensorStorageType::TEXTURE_ARRAY:
743     case TensorStorageType::TEXTURE_3D:
744       return absl::Substitute("(int4)($0, $1, $2, 0)", coords[0], coords[1],
745                               coords[2]);
746     case TensorStorageType::UNKNOWN:
747       return "error";
748   }
749 }
750 
GetPhysicalCoords(const std::string & xc,const std::string & yc,const std::string & zc,const std::string & sc,const std::string & bc) const751 std::vector<std::string> TensorDescriptor::GetPhysicalCoords(
752     const std::string& xc, const std::string& yc, const std::string& zc,
753     const std::string& sc, const std::string& bc) const {
754   if (layout == Layout::HWC || (IsBatchedWidth() && layout == Layout::BHWC)) {
755     return GetPhysicalCoordsWHS(xc, yc, sc);
756   } else if (layout == Layout::BHWC) {
757     return GetPhysicalCoordsWHSB(xc, yc, sc, bc);
758   } else if (layout == Layout::HWDC ||
759              (IsBatchedWidth() && layout == Layout::BHWDC)) {
760     return GetPhysicalCoordsWHDS(xc, yc, zc, sc);
761   } else if (layout == Layout::BHWDC) {
762     return GetPhysicalCoordsWHDSB(xc, yc, zc, sc, bc);
763   } else {
764     return {""};
765   }
766 }
767 
GetDataTypeFromTemplateArgs(const std::string & template_arg,DataType * result) const768 absl::Status TensorDescriptor::GetDataTypeFromTemplateArgs(
769     const std::string& template_arg, DataType* result) const {
770   std::string read_type = template_arg;
771   if (read_type == "FLT" || read_type == "ACCUM_FLT") {
772     auto it = state_vars_.find(read_type);
773     if (it == state_vars_.end()) {
774       return absl::UnavailableError(absl::StrCat(
775           "Read selector template argument ", read_type, " uninitialized."));
776     } else {
777       read_type = it->second;
778     }
779   }
780 
781   if (read_type == "half") {
782     *result = DataType::FLOAT16;
783   } else if (read_type == "float") {
784     *result = DataType::FLOAT32;
785   } else {
786     return absl::NotFoundError(absl::StrCat(
787         "Unrecognized Read selector template argument - ", read_type));
788   }
789   return absl::OkStatus();
790 }
791 
HasAxis(Axis axis) const792 bool TensorDescriptor::HasAxis(Axis axis) const {
793   if (axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::CHANNELS) {
794     return true;
795   }
796   if (axis == Axis::BATCH &&
797       (layout == Layout::BHWC || layout == Layout::BHWDC)) {
798     return true;
799   }
800   if (axis == Axis::DEPTH &&
801       (layout == Layout::HWDC || layout == Layout::BHWDC)) {
802     return true;
803   }
804   return false;
805 }
806 
GetWidthSize(BHWDC shape) const807 int TensorDescriptor::GetWidthSize(BHWDC shape) const {
808   int width = shape.w;
809   auto it = state_vars_.find("BatchedWidth");
810   if (it != state_vars_.end() && it->second == "true") {
811     width *= shape.b;
812   }
813   auto it1 = state_vars_.find("ElementsX2");
814   if (it1 != state_vars_.end() && it1->second == "true") {
815     width /= 2;
816   }
817   auto it2 = state_vars_.find("ElementsX4");
818   if (it2 != state_vars_.end() && it2->second == "true") {
819     width /= 4;
820   }
821   return width;
822 }
823 
GetSliceStrideSize(BHWDC shape) const824 int TensorDescriptor::GetSliceStrideSize(BHWDC shape) const {
825   if (IsBatchedWidth()) {
826     return GetWidthSize(shape) * shape.h;
827   } else {
828     if (HasAxis(Axis::BATCH)) {
829       return GetWidthSize(shape) * shape.h * shape.b;
830     } else {
831       return GetWidthSize(shape) * shape.h;
832     }
833   }
834 }
835 
SetAddressMode(AddressMode mode)836 void TensorDescriptor::SetAddressMode(AddressMode mode) {
837   if (mode == AddressMode::kZero) {
838     state_vars_["TextureMode"] = "ZERO";
839   } else {
840     state_vars_["TextureMode"] = "DONT_CARE";
841   }
842 }
843 
ParseCoordsFromArgs(const std::vector<std::string> & args,int offset,std::string * xc,std::string * yc,std::string * zc,std::string * sc,std::string * bc) const844 bool TensorDescriptor::ParseCoordsFromArgs(const std::vector<std::string>& args,
845                                            int offset, std::string* xc,
846                                            std::string* yc, std::string* zc,
847                                            std::string* sc,
848                                            std::string* bc) const {
849   if (HasAxis(Axis::WIDTH)) {
850     if (offset >= args.size()) return false;
851     *xc = args[offset++];
852   }
853   if (HasAxis(Axis::HEIGHT)) {
854     if (offset >= args.size()) return false;
855     *yc = args[offset++];
856   }
857   if (HasAxis(Axis::DEPTH)) {
858     if (offset >= args.size()) return false;
859     *zc = args[offset++];
860   }
861   if (HasAxis(Axis::CHANNELS)) {
862     if (offset >= args.size()) {
863       auto it = state_vars_.find("slice_id");
864       if (it == state_vars_.end()) {
865         return false;
866       } else {
867         *sc = it->second;
868       }
869     } else {
870       *sc = args[offset++];
871     }
872   }
873   if (HasAxis(Axis::BATCH) && !IsBatchedWidth()) {
874     if (offset >= args.size()) {
875       auto it = state_vars_.find("batch_id");
876       if (it == state_vars_.end()) {
877         return false;
878       } else {
879         *bc = it->second;
880       }
881     } else {
882       *bc = args[offset++];
883     }
884   }
885   return true;
886 }
887 
IsBatchedWidth() const888 bool TensorDescriptor::IsBatchedWidth() const {
889   auto it = state_vars_.find("BatchedWidth");
890   return it != state_vars_.end() && it->second == "true";
891 }
892 
AddressModeFromState() const893 AddressMode TensorDescriptor::AddressModeFromState() const {
894   auto it = state_vars_.find("TextureMode");
895   if (it != state_vars_.end()) {
896     if (it->second == "ZERO") {
897       return AddressMode::kZero;
898     } else {
899       return AddressMode::kDontCare;
900     }
901   } else {
902     return AddressMode::kDontCare;
903   }
904 }
905 
UploadData(const tflite::gpu::Tensor<BHWC,DataType::FLOAT32> & src)906 void TensorDescriptor::UploadData(
907     const tflite::gpu::Tensor<BHWC, DataType::FLOAT32>& src) {
908   shape = BHWDC(src.shape.b, src.shape.h, src.shape.w, 1, src.shape.c);
909   UploadData(src.data.data());
910 }
911 
UploadData(const tflite::gpu::Tensor<HWC,DataType::FLOAT32> & src)912 void TensorDescriptor::UploadData(
913     const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src) {
914   shape = BHWDC(1, src.shape.h, src.shape.w, 1, src.shape.c);
915   UploadData(src.data.data());
916 }
917 
UploadData(const tflite::gpu::Tensor<Linear,DataType::FLOAT32> & src)918 void TensorDescriptor::UploadData(
919     const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src) {
920   shape = BHWDC(1, 1, 1, 1, src.shape.v);
921   UploadData(src.data.data());
922 }
923 
UploadData(const float * src)924 void TensorDescriptor::UploadData(const float* src) {
925   int aligned_channels = storage_type == TensorStorageType::SINGLE_TEXTURE_2D
926                              ? shape.c
927                              : AlignByN(shape.c, 4);
928   int elements_count = shape.b * shape.w * shape.h * shape.d * aligned_channels;
929   data.resize(elements_count * SizeOf(data_type));
930   if (data_type == DataType::FLOAT32) {
931     float* gpu_data = reinterpret_cast<float*>(data.data());
932     DataFromBHWDC(src, shape, *this, gpu_data);
933   } else {
934     half* gpu_data = reinterpret_cast<half*>(data.data());
935     DataFromBHWDC(src, shape, *this, gpu_data);
936   }
937 }
938 
SupportsZeroClamp(const Axis & axis) const939 bool TensorDescriptor::SupportsZeroClamp(const Axis& axis) const {
940   switch (storage_type) {
941     case TensorStorageType::UNKNOWN:
942       return false;
943     case TensorStorageType::BUFFER:
944     case TensorStorageType::IMAGE_BUFFER:
945       return false;
946     case TensorStorageType::TEXTURE_ARRAY:
947     case TensorStorageType::TEXTURE_2D:
948     case TensorStorageType::SINGLE_TEXTURE_2D:
949       return axis == Axis::WIDTH || axis == Axis::HEIGHT;
950     case TensorStorageType::TEXTURE_3D:
951       return axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::DEPTH;
952   }
953 }
954 
CanReadOutOfBorder(const Axis & axis) const955 bool TensorDescriptor::CanReadOutOfBorder(const Axis& axis) const {
956   switch (storage_type) {
957     case TensorStorageType::UNKNOWN:
958       return false;
959     case TensorStorageType::BUFFER:
960       return false;
961     case TensorStorageType::IMAGE_BUFFER:
962     case TensorStorageType::TEXTURE_2D:
963     case TensorStorageType::TEXTURE_3D:
964     case TensorStorageType::SINGLE_TEXTURE_2D:
965     case TensorStorageType::TEXTURE_ARRAY:
966       return true;
967   }
968 }
969 
IsLinear() const970 bool TensorDescriptor::IsLinear() const {
971   return storage_type == TensorStorageType::BUFFER ||
972          storage_type == TensorStorageType::IMAGE_BUFFER;
973 }
974 
ReturnsZeroForNegOneRead() const975 bool TensorDescriptor::ReturnsZeroForNegOneRead() const {
976   return storage_type == TensorStorageType::IMAGE_BUFFER;
977 }
978 
979 namespace {
GetLinearIndex(const TensorDescriptor & desc,const BHWDC & shape,int b,int x,int y,int d,int s,int sub_c)980 int GetLinearIndex(const TensorDescriptor& desc, const BHWDC& shape, int b,
981                    int x, int y, int d, int s, int sub_c) {
982   const int slices = DivideRoundUp(shape.c, 4);
983   switch (desc.storage_type) {
984     case TensorStorageType::BUFFER:
985     case TensorStorageType::IMAGE_BUFFER:
986     case TensorStorageType::TEXTURE_ARRAY:
987     case TensorStorageType::TEXTURE_3D:
988       return ((((d * slices + s) * shape.h + y) * shape.w + x) * shape.b + b) *
989                  4 +
990              sub_c;  // DSHWBC4
991     case TensorStorageType::TEXTURE_2D:
992       return ((((y * slices + s) * shape.w + x) * shape.b + b) * shape.d + d) *
993                  4 +
994              sub_c;  // HSWBDC4
995     case TensorStorageType::SINGLE_TEXTURE_2D:
996       return (((y * shape.w + x) * shape.b + b) * shape.d + d) * shape.c +
997              sub_c;  // HWBDC
998     case TensorStorageType::UNKNOWN:
999       return -1;
1000   }
1001 }
1002 
GetChannelsAlignment(const TensorDescriptor & desc,const BHWDC & shape)1003 int GetChannelsAlignment(const TensorDescriptor& desc, const BHWDC& shape) {
1004   return desc.storage_type == TensorStorageType::SINGLE_TEXTURE_2D ? shape.c
1005                                                                    : 4;
1006 }
1007 }  // namespace
1008 
1009 template <typename FromType, typename ToType>
DataFromBHWDC(const FromType * src,const BHWDC & shape,const TensorDescriptor & desc,ToType * dst)1010 void DataFromBHWDC(const FromType* src, const BHWDC& shape,
1011                    const TensorDescriptor& desc, ToType* dst) {
1012   const int channels_alignment = GetChannelsAlignment(desc, shape);
1013   const int slices = DivideRoundUp(shape.c, 4);
1014   for (int b = 0; b < shape.b; ++b) {
1015     for (int s = 0; s < slices; ++s) {
1016       for (int y = 0; y < shape.h; ++y) {
1017         for (int x = 0; x < shape.w; ++x) {
1018           for (int d = 0; d < shape.d; ++d) {
1019             for (int c = 0; c < channels_alignment; ++c) {
1020               FromType value;
1021               if (s * 4 + c < shape.c) {
1022                 const int cpu_index =
1023                     shape.LinearIndex({b, y, x, d, s * 4 + c});
1024                 value = src[cpu_index];
1025               } else {
1026                 value = 0;
1027               }
1028               int gpu_index = GetLinearIndex(desc, shape, b, x, y, d, s, c);
1029               dst[gpu_index] = value;
1030             }
1031           }
1032         }
1033       }
1034     }
1035   }
1036 }
1037 
1038 template void DataFromBHWDC<float, float>(const float* src, const BHWDC& shape,
1039                                           const TensorDescriptor& desc,
1040                                           float* dst);
1041 template void DataFromBHWDC<float, half>(const float* src, const BHWDC& shape,
1042                                          const TensorDescriptor& desc,
1043                                          half* dst);
1044 template void DataFromBHWDC<int32_t, int32_t>(const int32_t* src,
1045                                               const BHWDC& shape,
1046                                               const TensorDescriptor& desc,
1047                                               int32_t* dst);
1048 template void DataFromBHWDC<int16_t, int16_t>(const int16_t* src,
1049                                               const BHWDC& shape,
1050                                               const TensorDescriptor& desc,
1051                                               int16_t* dst);
1052 template void DataFromBHWDC<int8_t, int8_t>(const int8_t* src,
1053                                             const BHWDC& shape,
1054                                             const TensorDescriptor& desc,
1055                                             int8_t* dst);
1056 template void DataFromBHWDC<uint32_t, uint32_t>(const uint32_t* src,
1057                                                 const BHWDC& shape,
1058                                                 const TensorDescriptor& desc,
1059                                                 uint32_t* dst);
1060 template void DataFromBHWDC<uint16_t, uint16_t>(const uint16_t* src,
1061                                                 const BHWDC& shape,
1062                                                 const TensorDescriptor& desc,
1063                                                 uint16_t* dst);
1064 template void DataFromBHWDC<uint8_t, uint8_t>(const uint8_t* src,
1065                                               const BHWDC& shape,
1066                                               const TensorDescriptor& desc,
1067                                               uint8_t* dst);
1068 
1069 template <typename FromType, typename ToType>
DataToBHWDC(const FromType * src,const BHWDC & shape,const TensorDescriptor & desc,ToType * dst)1070 void DataToBHWDC(const FromType* src, const BHWDC& shape,
1071                  const TensorDescriptor& desc, ToType* dst) {
1072   const int channels_alignment = GetChannelsAlignment(desc, shape);
1073   const int slices = DivideRoundUp(shape.c, 4);
1074   for (int b = 0; b < shape.b; ++b) {
1075     for (int s = 0; s < slices; ++s) {
1076       for (int y = 0; y < shape.h; ++y) {
1077         for (int x = 0; x < shape.w; ++x) {
1078           for (int d = 0; d < shape.d; ++d) {
1079             for (int c = 0; c < channels_alignment; ++c) {
1080               if (s * 4 + c >= shape.c) {
1081                 continue;
1082               }
1083               int cpu_index = shape.LinearIndex({b, y, x, d, s * 4 + c});
1084               int gpu_index = GetLinearIndex(desc, shape, b, x, y, d, s, c);
1085               dst[cpu_index] = src[gpu_index];
1086             }
1087           }
1088         }
1089       }
1090     }
1091   }
1092 }
1093 
1094 template void DataToBHWDC<float, float>(const float* src, const BHWDC& shape,
1095                                         const TensorDescriptor& desc,
1096                                         float* dst);
1097 template void DataToBHWDC<half, float>(const half* src, const BHWDC& shape,
1098                                        const TensorDescriptor& desc,
1099                                        float* dst);
1100 template void DataToBHWDC<int32_t, int32_t>(const int32_t* src,
1101                                             const BHWDC& shape,
1102                                             const TensorDescriptor& desc,
1103                                             int32_t* dst);
1104 template void DataToBHWDC<int16_t, int16_t>(const int16_t* src,
1105                                             const BHWDC& shape,
1106                                             const TensorDescriptor& desc,
1107                                             int16_t* dst);
1108 template void DataToBHWDC<int8_t, int8_t>(const int8_t* src, const BHWDC& shape,
1109                                           const TensorDescriptor& desc,
1110                                           int8_t* dst);
1111 template void DataToBHWDC<uint32_t, uint32_t>(const uint32_t* src,
1112                                               const BHWDC& shape,
1113                                               const TensorDescriptor& desc,
1114                                               uint32_t* dst);
1115 template void DataToBHWDC<uint16_t, uint16_t>(const uint16_t* src,
1116                                               const BHWDC& shape,
1117                                               const TensorDescriptor& desc,
1118                                               uint16_t* dst);
1119 template void DataToBHWDC<uint8_t, uint8_t>(const uint8_t* src,
1120                                             const BHWDC& shape,
1121                                             const TensorDescriptor& desc,
1122                                             uint8_t* dst);
1123 
1124 }  // namespace gpu
1125 }  // namespace tflite
1126