/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_

#import <Metal/Metal.h>

#include <functional>
#include <list>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_hints.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/profiling_info.h"
#include "tensorflow/lite/delegates/gpu/common/task/tuning_type.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_device.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"

namespace tflite {
namespace gpu {
namespace metal {

struct MetalNode {
  ComputeTask task;
  std::vector<ValueId> inputs;
  std::vector<ValueId> outputs;

  // Mostly for debug purposes.
  std::string name;

  MetalNode() = default;

  MetalNode(MetalNode&& node) = default;
  MetalNode& operator=(MetalNode&& node) = default;
  MetalNode(const MetalNode&) = delete;
  MetalNode& operator=(const MetalNode&) = delete;
};

class InferenceContext {
 public:
  struct CreateInferenceInfo {
    CalculationsPrecision precision;
    TensorStorageType storage_type;
    ModelHints hints;
  };

  InferenceContext() = default;

  // IMPORTANT: If InitFromGraph is used, RunGraphTransforms must be applied
  // to the graph beforehand; otherwise correct behavior is not guaranteed.
  absl::Status InitFromGraph(const CreateInferenceInfo& create_info,
                             const GraphFloat32& graph,
                             id<MTLDevice> device_id);

  // Applies specific transformations to the graph before initialization.
  // These transformations are either impossible or useless in other
  // backends.
  absl::Status InitFromGraphWithTransforms(
      const CreateInferenceInfo& create_info, GraphFloat32* graph,
      id<MTLDevice> device_id);
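
  // Example: a minimal initialization sketch. `graph` (a GraphFloat32) and
  // `device` (an id<MTLDevice>) are assumed to be provided by the caller;
  // error handling is elided.
  //
  //   InferenceContext::CreateInferenceInfo info;
  //   info.precision = CalculationsPrecision::F16;
  //   info.storage_type = TensorStorageType::BUFFER;
  //   InferenceContext context;
  //   absl::Status status =
  //       context.InitFromGraphWithTransforms(info, &graph, device);
  //   if (!status.ok()) {
  //     // Handle initialization failure.
  //   }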

  // Updates MTLBuffer handles in MetalSpatialTensors and in the kernels that
  // use these tensors.
  void UpdatePreallocatedTensors(
      const std::map<ValueId, id<MTLBuffer>>& preallocated);

  /// Inserts all GPU compute tasks into the command encoder.
  /// @discussion Input/output buffers must be provided as ID:buffer pairs via
  ///     UpdatePreallocatedTensors() before encoding. No GPU synchronization
  ///     functions are used inside. All GPU resources must be created with
  ///     the same device that was used in
  ///     InitFromGraph()/InitFromGraphWithTransforms().
  void EncodeWithEncoder(id<MTLComputeCommandEncoder> command_encoder);

  /// Inserts all GPU compute tasks into the command buffer. A separate
  /// encoder is used for every task.
  /// @discussion Input/output buffers must be provided as ID:buffer pairs via
  ///     UpdatePreallocatedTensors() before encoding. No GPU synchronization
  ///     functions are used inside. All GPU resources must be created with
  ///     the same device that was used in
  ///     InitFromGraph()/InitFromGraphWithTransforms().
  void EncodeWithCommandBuffer(id<MTLCommandBuffer> command_buffer);

  /// Adds all GPU compute tasks to the command queue. A separate encoder is
  /// used for every task. Every flush_period encoders are batched into a
  /// command buffer that is committed for execution.
  /// @discussion Input/output buffers must be provided as ID:buffer pairs via
  ///     UpdatePreallocatedTensors() before encoding. No GPU synchronization
  ///     functions are used inside. All GPU resources must be created with
  ///     the same device that was used in
  ///     InitFromGraph()/InitFromGraphWithTransforms().
  void EncodeWithCommandQueue(id<MTLCommandQueue> command_queue,
                              int flush_period);
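
  // Example: one possible encode flow using EncodeWithEncoder(). `context` is
  // an initialized InferenceContext; `command_queue` is assumed to be created
  // from the same MTLDevice that was used at initialization.
  //
  //   id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
  //   id<MTLComputeCommandEncoder> encoder =
  //       [command_buffer computeCommandEncoder];
  //   context.EncodeWithEncoder(encoder);
  //   [encoder endEncoding];
  //   [command_buffer commit];
  //   [command_buffer waitUntilCompleted];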

  void Profile(id<MTLDevice> device, ProfilingInfo* result);

 private:
  enum class TensorMemoryType {
    kStrongShape,
    kBuffer,
    kVariable,
    kConst,
    kPreallocated
  };

  absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info,
                       ModelHints hints);

  void ReserveGraphTensors(const CreateInferenceInfo& create_info,
                           const GpuInfo& gpu_info, const GraphFloat32& graph,
                           const std::set<ValueId>& preallocated_ids);

  absl::Status CompileOperations(MetalDevice* device);

  absl::Status Merge();
  absl::Status AllocateTensors(MetalDevice* device,
                               const std::set<ValueId>& preallocated_ids);
  absl::Status AllocateMemoryForConstTensors(MetalDevice* device);
  absl::Status AllocateMemoryForBuffers(MetalDevice* device);
  absl::Status AllocateMemoryForStrongShapes(MetalDevice* device);
  void BindTensorsToOperations();
  absl::Status UpdateParams(const GpuInfo& gpu_info);
  MetalSpatialTensor* GetTensor(ValueId tensor_id);
  void GetUsages(const std::function<bool(ValueId)>& functor,
                 std::map<ValueId, int2>* usages);
  TensorMemoryType GetTensorMemoryType(ValueId id);
  absl::Status Tune(TuningType tuning_type, MetalDevice* device);

  struct DummyTensor {
    BHWC shape;
    TensorDescriptor descriptor;

    bool operator==(const DummyTensor& b) const {
      return shape == b.shape && descriptor == b.descriptor;
    }
  };

  class TensorReserver {
   public:
    TensorReserver() : next_(0) {}

    // Reserves the next free id for the tensor and returns it.
    ValueId Add(const DummyTensor& dummy) {
      reservations_[next_] = dummy;
      return next_++;
    }

    void Add(ValueId id, const DummyTensor& dummy) {
      reservations_[id] = dummy;
    }

    void SetNext(ValueId id) { next_ = id; }

    DummyTensor Get(ValueId id) { return reservations_[id]; }

    std::vector<std::pair<ValueId, TensorDescriptor>> GetTensorDescs() const {
      std::vector<std::pair<ValueId, TensorDescriptor>> result;
      for (auto& v : reservations_) {
        TensorDescriptor desc = v.second.descriptor;
        desc.shape.b = v.second.shape.b;
        desc.shape.h = v.second.shape.h;
        desc.shape.w = v.second.shape.w;
        desc.shape.d = 1;
        desc.shape.c = v.second.shape.c;
        result.push_back({v.first, desc});
      }
      return result;
    }

    void Add(
        const std::vector<std::pair<ValueId, TensorDescriptor>>& tensors) {
      for (auto& v : tensors) {
        DummyTensor dummy;
        dummy.descriptor = v.second;
        dummy.shape.b = v.second.shape.b;
        dummy.shape.h = v.second.shape.h;
        dummy.shape.w = v.second.shape.w;
        dummy.shape.c = v.second.shape.c;
        Add(v.first, dummy);
      }
    }

   private:
    absl::flat_hash_map<ValueId, DummyTensor> reservations_;
    ValueId next_;
  };
  TensorReserver tensor_reserver_;

  std::vector<MetalNode> nodes_;
  // Contains indices into nodes_ for tasks that use preallocated tensors.
  std::vector<int> task_ids_with_preallocated_tensors_;
  std::vector<ValueId> input_ids_;
  std::vector<ValueId> output_ids_;
  CalculationsPrecision precision_;
  std::map<ValueId, MetalSpatialTensor> preallocated_tensors_;

  std::map<ValueId, TensorDescriptor> const_tensors_descs_;
  std::map<ValueId, MetalSpatialTensor> const_tensors_;

  std::map<ValueId, int> graph_ids_to_shared_buffer_tensors_;
  std::vector<id<MTLBuffer>> shared_buffers_;
  // These tensors reference memory owned by shared_buffers_.
  std::vector<MetalSpatialTensor> shared_buffer_tensors_;

  std::map<ValueId, MetalSpatialTensor> strong_shape_tensors_;
  std::map<ValueId, ValueId> graph_ids_to_strong_shape_tensors_;
};

// Runs specific transforms for the graph.
absl::Status RunGraphTransforms(GraphFloat32* graph);

}  // namespace metal
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_