//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

// clang-format off
#pragma once

// NOTE(review): in the reviewed copy of this header every `<...>` span was
// missing (include targets, ObjC generics, std template arguments). The
// values below are reconstructed from the names this header uses and from
// the leftover `>` characters -- confirm against the build before merging.
#include <executorch/backends/apple/mps/runtime/MPSStream.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <memory>
#include <unordered_map>
#include <vector>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

/// Holds a compiled MPSGraph executable together with the input/output
/// shapes, tensor-data arrays and CPU/GPU buffer bookkeeping used to run it.
/// MPSCompiler is a friend and may access the private state directly.
class MPSExecutor {
 private:
  MPSGraphExecutable* _executable;
  NSArray<MPSGraphShapedType *> * _inputShapes;
  NSArray<MPSGraphShapedType *> * _outputShapes;

  // Owned by this object: both arrays are released in the destructor.
  NSMutableArray<MPSGraphTensorData *>* _inputsArray;
  NSMutableArray<MPSGraphTensorData *>* _outputsArray;

  // Flag whether to use shared memory or not.
  // Shared memory flag will be set as following (based on HW and target config):
  // - True: Apple Silicon and macOS15+/iOS17+/iPadOS17+
  // - False: Simulator or x86 or pre-macOS15/iOS17/iPadOS17
  bool _use_shared_mem;
  bool _buffers_initialized;

  // Input/Output GPU buffer pointers.
  std::vector<id<MTLBuffer>> _inputGPUBuffers;
  std::vector<id<MTLBuffer>> _outputGPUBuffers;

  // Input/Output CPU buffer pointers.
  std::vector<CPUBufferWrapper> _inputCPUBuffers;
  std::vector<CPUBufferWrapper> _outputCPUBuffers;

  std::unordered_map<MPSGraphTensor*, int32_t> _mpsGraphTensorToId;

 public:
  MPSExecutor();

  // The destructor releases raw ObjC pointers, so a member-wise copy would
  // release _inputsArray/_outputsArray twice (once per copy). Copying is
  // therefore disabled; declaring these also suppresses the implicit move
  // operations, which would have the same problem.
  MPSExecutor(const MPSExecutor&) = delete;
  MPSExecutor& operator=(const MPSExecutor&) = delete;

  // Releases the two tensor-data arrays this object owns. The explicit
  // -release calls mean this file is built with manual reference counting
  // (not ARC). Messaging nil is a no-op, so no nil checks are needed.
  // NOTE(review): _executable, _inputShapes and _outputShapes are not
  // released here -- confirm who owns them (MPSCompiler is a friend).
  ~MPSExecutor() {
    [_inputsArray release];
    [_outputsArray release];
    _inputsArray = nil;
    _outputsArray = nil;
  }

  /// Number of graph inputs, i.e. the count of recorded input shapes.
  inline size_t getNumInputs() const {
    return [_inputShapes count];
  }

  /// Number of graph outputs, i.e. the count of recorded output shapes.
  inline size_t getNumOutputs() const {
    return [_outputShapes count];
  }

  /// The compiled executable (not retained or copied by this accessor).
  inline MPSGraphExecutable* getMPSGraphExecutable() const {
    return _executable;
  }

  // The methods below are implemented out of line; their exact semantics are
  // not visible from this header. Descriptions are inferred from names --
  // confirm against the implementation.

  /// Presumably runs the executable and produces results into `outputs`.
  ET_NODISCARD executorch::runtime::Error forward(std::vector<const executorch::aten::Tensor*>& outputs);

  /// Presumably binds the caller's input and output tensors to this executor.
  ET_NODISCARD executorch::runtime::Error set_inputs_outputs(std::vector<const executorch::aten::Tensor*>& inputs, std::vector<const executorch::aten::Tensor*>& outputs);

  /// Presumably allocates the CPU/GPU buffer bookkeeping (see _buffers_initialized).
  executorch::runtime::Error initDataBuffers();

  /// Presumably refreshes the buffers from the given input/output tensors.
  executorch::runtime::Error updateDataBuffers(std::vector<const executorch::aten::Tensor*>& inputs, std::vector<const executorch::aten::Tensor*>& outputs);

  /// Presumably copies GPU results back into the given output tensors.
  executorch::runtime::Error syncOutputBuffers(std::vector<const executorch::aten::Tensor*>& outputs);

  friend class MPSCompiler;
};

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch
// clang-format on