1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
10
11 namespace vkcompute {
12
add_storage_type_suffix(std::string & kernel_name,const utils::StorageType storage_type)13 void add_storage_type_suffix(
14 std::string& kernel_name,
15 const utils::StorageType storage_type) {
16 switch (storage_type) {
17 case utils::kBuffer:
18 kernel_name += "_buffer";
19 break;
20 case utils::kTexture3D:
21 kernel_name += "_texture3d";
22 break;
23 case utils::kTexture2D:
24 kernel_name += "_texture2d";
25 break;
26 }
27 }
28
add_storage_type_suffix(std::string & kernel_name,const api::vTensor & tensor)29 void add_storage_type_suffix(
30 std::string& kernel_name,
31 const api::vTensor& tensor) {
32 return add_storage_type_suffix(kernel_name, tensor.storage_type());
33 }
34
add_dtype_suffix(std::string & kernel_name,const vkapi::ScalarType dtype)35 void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) {
36 switch (dtype) {
37 case vkapi::kFloat:
38 kernel_name += "_float";
39 break;
40 case vkapi::kHalf:
41 kernel_name += "_half";
42 break;
43 case vkapi::kInt:
44 kernel_name += "_int";
45 break;
46 case vkapi::kChar:
47 case vkapi::kQInt8:
48 kernel_name += "_int8";
49 break;
50 case vkapi::kByte:
51 case vkapi::kQUInt8:
52 kernel_name += "_uint8";
53 break;
54 default:
55 break;
56 }
57 }
58
add_dtype_suffix(std::string & kernel_name,const api::vTensor & tensor)59 void add_dtype_suffix(std::string& kernel_name, const api::vTensor& tensor) {
60 return add_dtype_suffix(kernel_name, tensor.dtype());
61 }
62
add_ndim_suffix(std::string & kernel_name,const api::vTensor & tensor)63 void add_ndim_suffix(std::string& kernel_name, const api::vTensor& tensor) {
64 switch (tensor.storage_type()) {
65 case utils::kTexture3D:
66 kernel_name += "_3d";
67 break;
68 case utils::kTexture2D:
69 kernel_name += "_2d";
70 break;
71 default:
72 break;
73 }
74 }
75
add_packed_dim_suffix(std::string & kernel_name,const int32_t packed_dim)76 void add_packed_dim_suffix(std::string& kernel_name, const int32_t packed_dim) {
77 switch (packed_dim) {
78 case WHCN::kWidthDim:
79 kernel_name += "_W_packed";
80 break;
81 case WHCN::kHeightDim:
82 kernel_name += "_H_packed";
83 break;
84 case WHCN::kChannelsDim:
85 kernel_name += "_C_packed";
86 break;
87 default:
88 VK_THROW("Invalid packed dim!");
89 }
90 }
91
add_packed_dim_suffix(std::string & kernel_name,const api::vTensor & tensor)92 void add_packed_dim_suffix(
93 std::string& kernel_name,
94 const api::vTensor& tensor) {
95 return add_packed_dim_suffix(kernel_name, tensor.packed_dim());
96 }
97
98 } // namespace vkcompute
99