• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_
18 
19 #import <Metal/Metal.h>
20 
21 #include <map>
22 #include <string>
23 #include <vector>
24 
25 #include "tensorflow/lite/delegates/gpu/common/model.h"
26 #include "tensorflow/lite/delegates/gpu/common/shape.h"
27 #include "tensorflow/lite/delegates/gpu/common/status.h"
28 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
29 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
30 
31 @interface TFLComputeTask : NSObject
32 
33 /// Compiles the task's shader. Returns an error status if the shader can't be compiled.
34 - (::tflite::gpu::Status)compileWithDevice:(id<MTLDevice>)device
35                             taskDescriptor:(::tflite::gpu::metal::ComputeTaskDescriptorPtr)desc
36                             runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options;
37 
38 /// Updates dimensions for inputs/outputs/intermediate tensors.
39 - (::tflite::gpu::Status)
40     setInputDimensionsWithDevice:(id<MTLDevice>)device
41                       dimensions:(std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)dimensions;
42 
43 /// Updates buffers for intermediate tensors only. Returns error if out of memory or a buffer is
44 /// larger than MTLDevice can support.
45 /// @param buffers is a map from intermediate tensors' ValueId to metal handles with corresponding
46 ///        buffers.
47 /// @param outputIds must match the output of added operations.
48 /// @param usageRecordIds is a map from intermediate tensors' ValueId to corresponding tensor usage
49 ///        record ids.
50 /// @param sharedBufferIds contains the shared buffer id for each tensor usage record id.
51 /// @param sharedBuffers contains metal handles to the allocated buffers for each shared buffer id.
52 /// TODO(ypisarchyk): probably we can decrease the number of parameters here
53 - (::tflite::gpu::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id<MTLBuffer>>*)buffers
54                              outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds
55                         usageRecordIds:
56                             (const std::map<::tflite::gpu::ValueId, size_t>&)usageRecordIds
57                        sharedBufferIds:(const std::vector<size_t>&)sharedBufferIds
58                          sharedBuffers:(const std::vector<id<MTLBuffer>>&)sharedBuffers;
59 /// NOTE(review): presumably encodes this task into `encoder` using buffers keyed by ValueId; confirm.
60 - (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder
61        inputOutputBuffers:
62            (const std::map<::tflite::gpu::ValueId, id<MTLBuffer>>&)inputOutputBuffers;
63 
64 @end
65 
66 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_
67