/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_

#import <Metal/Metal.h>

#include <map>
#include <string>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"

/// Wraps a single GPU compute operation: a compiled Metal shader plus the
/// buffer bindings needed to encode it into a command encoder.
@interface TFLComputeTask : NSObject

/// Compiles the Metal shader described by the task descriptor for the given device.
/// @param device Metal device used to compile the shader.
/// @param desc Descriptor providing the shader source and task configuration.
/// @param options Runtime options applied during compilation.
/// @return An error status if the shader can't be compiled; OK status otherwise.
- (::tflite::gpu::Status)compileWithDevice:(id<MTLDevice>)device
                            taskDescriptor:(::tflite::gpu::metal::ComputeTaskDescriptorPtr)desc
                            runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options;

/// Updates dimensions for inputs/outputs/intermediate tensors.
/// @param device Metal device used for any device-side updates.
/// @param dimensions In/out map from tensor ValueId to its BHWC shape; updated in place.
- (::tflite::gpu::Status)
    setInputDimensionsWithDevice:(id<MTLDevice>)device
                      dimensions:(std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)dimensions;

/// Updates buffers for intermediate tensors only. Returns error if out of memory or a buffer is
/// larger than MTLDevice can support.
/// @param buffers is a map from intermediate tensors' ValueId to metal handles with corresponding
/// buffers.
/// @param outputIds must match the output of added operations.
/// @param usageRecordIds is a map from intermediate tensors' ValueId to corresponding tensor usage
/// records ids.
/// @param sharedBufferIds contain shared buffer id for each tensor usage record id.
/// @param sharedBuffers contain metal handles to the allocated buffers for each shared buffer id.
/// TODO(ypisarchyk): probably we can decrease the number of parameters here
- (::tflite::gpu::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id<MTLBuffer>>*)buffers
                             outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds
                        usageRecordIds:
                            (const std::map<::tflite::gpu::ValueId, size_t>&)usageRecordIds
                       sharedBufferIds:(const std::vector<size_t>&)sharedBufferIds
                         sharedBuffers:(const std::vector<id<MTLBuffer>>&)sharedBuffers;

/// Encodes the task into the given compute command encoder.
/// @param encoder Encoder that receives this task's compute commands.
/// @param inputOutputBuffers Map from input/output tensor ValueId to its Metal buffer.
- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder
       inputOutputBuffers:
           (const std::map<::tflite::gpu::ValueId, id<MTLBuffer>>&)inputOutputBuffers;

@end

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_