• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_DESCRIPTOR_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_DESCRIPTOR_H_
18 
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
25 
26 #include "tensorflow/lite/delegates/gpu/common/model.h"
27 #include "tensorflow/lite/delegates/gpu/common/shape.h"
28 #include "tensorflow/lite/delegates/gpu/common/types.h"
29 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
30 
31 namespace tflite {
32 namespace gpu {
33 namespace metal {
34 
35 using OutputDimensions =
36     std::function<BHWC(const std::map<ValueId, BHWC>& buffers)>;
37 using UniformsFunction =
38     std::function<std::vector<uint8_t>(const std::map<ValueId, BHWC>& buffers)>;
39 using DispatchParamsFunction = std::function<std::pair<uint3, uint3>(
40     const std::map<ValueId, BHWC>& buffers)>;
41 
42 // Compute task descriptor contains a linkable shader code or a code for
43 // complete shader to which other linkable can be attached or not. An operation
44 // can produce one or more descriptors and graph compiler uses descriptors as
45 // building blocks. All required data like immutable operation parameters
46 // (weights etc.) is attached to the descriptor.
47 struct ComputeTaskDescriptor {
48   struct InputBufferDescriptor {
49     ValueId id;
50     // The declaration is inserted into the compute function arguments list.
51     // Example for non-linkable task: "device FLT4* const input_buffer"
52     // Example for linkable: "device FLT4* const"
53     std::string declaration;
54   };
55   struct OutputBufferDescriptor {
56     ValueId id;
57     // The declaration is inserted into the compute function arguments list.
58     // Example for non-linkable task: "device FLT4* output_buffer"
59     // Example for linkable: "device FLT4*"
60     std::string declaration;
61     // Multiple outputs are allowed from a linkable operation so after fusion
62     // each buffer's dimensions are calculated separately from different
63     // operations.
64     OutputDimensions dimensions_function;
65     // Fusion absorbs intermediate tensors. Keep this ids to properly store
66     // output dimensions.
67     std::vector<ValueId> alias;
68   };
69   struct ImmutableBufferDescriptor {
70     std::string declaration;
71     std::vector<uint8_t> data;
72   };
73   // Uniforms are recalculated at any setInputDimensions call.
74   struct UniformBufferDescriptor {
75     // The declaration is inserted into the compute function arguments list.
76     // Example: "constant uint4& some_uniforms"
77     std::string declaration;
78     // This function re-calculates uniforms for specific input dimensions.
79     UniformsFunction data_function;
80   };
81 
82   // Unique ID to match the graph compilation errors.
83   int id;
84   bool is_linkable;
85   // A linkable function or a full shader source with 3 parameters $ for
86   // substitute function. Example of linkable: "(FLT4 linkable$0(FLT4 value, int
87   // linear_index) { return value; })" Example of non-linkable function:
88   // #include <metal_stdlib>
89   // using namespace metal;
90   // $0
91   // kernel void ComputeFunction(
92   //                             $1
93   //                             uint3 gid[[thread_position_in_grid]]) {
94   //   if (int(gid.x) >= size.x || int(gid.y) >= size.y) {
95   //     return;
96   //   }
97   //   const int linear_index = (gid.z * size.y + gid.y) * size.x + gid.x;
98   //   FLT4 value = input_buffer[linear_index] + 1.0f;
99   //   $2
100   //   output_buffer[linear_index] = value;
101   // }
102   std::string shader_source;
103   std::vector<InputBufferDescriptor> input_buffers;
104   // A single per-operation output is supported now.
105   OutputBufferDescriptor output_buffer;
106   std::vector<ImmutableBufferDescriptor> immutable_buffers;
107   std::vector<UniformBufferDescriptor> uniform_buffers;
108   // Dynamic resizing of input tensor is supported. User-defined functions to
109   // calculate new parameters for GPU compute task dispatching. A leading
110   // unlinkable task must provide this.
111   DispatchParamsFunction resize_function;
112 };
113 
114 using ComputeTaskDescriptorPtr = std::shared_ptr<ComputeTaskDescriptor>;
115 
/// Copies the object representation of every element of `input_vector` into
/// a flat byte array, in element order.
/// @param input_vector Elements to serialize.
/// @return A vector of `input_vector.size() * sizeof(T)` bytes.
template <typename T>
std::vector<uint8_t> GetByteBuffer(const std::vector<T>& input_vector) {
  const uint8_t* const first_byte =
      reinterpret_cast<const uint8_t*>(input_vector.data());
  const size_t byte_count = input_vector.size() * sizeof(T);
  return std::vector<uint8_t>(first_byte, first_byte + byte_count);
}
126 
127 /// Converts float to destination type (if needed) and stores as bytes array.
128 std::vector<uint8_t> GetByteBufferConverted(
129     const std::vector<float>& input_vector,
130     RuntimeOptions::Precision destination_type);
131 
132 /// Resizes, Converts float to destination type (if needed) and stores as bytes
133 /// array.
134 std::vector<uint8_t> GetByteBufferConvertedResized(
135     const std::vector<float>& input_vector,
136     RuntimeOptions::Precision destination_type, size_t elements_count);
137 
138 }  // namespace metal
139 }  // namespace gpu
140 }  // namespace tflite
141 
142 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_DESCRIPTOR_H_
143