#pragma once

#ifdef USE_XNNPACK

#include <ATen/core/ivalue.h>
#include <ATen/native/xnnpack/Common.h>
#include <ATen/Tensor.h>

#include <mutex>

namespace at::native::xnnpack {

using SerializationTypeLinearPrePack = std::tuple<
    Tensor,
    std::optional<Tensor>,
    std::optional<Scalar>,
    std::optional<Scalar>>;
using SerializationTypeConv2dPrePack = std::tuple<
    Tensor,
    std::optional<Tensor>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    int64_t,
    std::optional<Scalar>,
    std::optional<Scalar>>;
using SerializationTypeTransposeConv2dPrePack = std::tuple<
    Tensor,
    std::optional<Tensor>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    std::vector<int64_t>,
    int64_t,
    std::optional<Scalar>,
    std::optional<Scalar>>;

class LinearOpContext : public torch::jit::CustomClassHolder {
 protected:
  Tensor orig_weight_;
  std::optional<Tensor> orig_bias_;
  std::optional<Scalar> output_min_;
  std::optional<Scalar> output_max_;
  bool orig_weight_and_bias_freed_;

 public:
  SerializationTypeLinearPrePack unpack() {
    TORCH_CHECK(
        !orig_weight_and_bias_freed_,
        "Original weight and bias have been freed");
    return std::make_tuple(orig_weight_, orig_bias_, output_min_, output_max_);
  }

  virtual Tensor run(const Tensor& input) = 0;
  virtual void free_orig_weight_and_bias() = 0;
};

class XNNPackLinearOpContext final : public LinearOpContext {
 private:
  ContextLinear op_context_;

 public:
  XNNPackLinearOpContext(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      const std::optional<Scalar>& min,
      const std::optional<Scalar>& max,
      ContextLinear&& op_context)
      : op_context_(std::move(op_context)) {
    orig_weight_ = std::move(weight);
    orig_bias_ = std::move(bias);
    output_min_ = min;
    output_max_ = max;
    orig_weight_and_bias_freed_ = false;
  }

  Tensor run(const Tensor& input) override;
  void free_orig_weight_and_bias() override;

  static c10::intrusive_ptr<LinearOpContext> create_context(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      const std::optional<Scalar>& output_min,
      const std::optional<Scalar>& output_max);
};

class Conv2dOpContext : public torch::jit::CustomClassHolder {
 protected:
  Tensor orig_weight_;
  std::optional<Tensor> orig_bias_;
  std::vector<int64_t> stride_;
  std::vector<int64_t> padding_;
  std::vector<int64_t> dilation_;
  int64_t groups_;
  std::optional<Scalar> output_min_;
  std::optional<Scalar> output_max_;
  bool orig_weight_and_bias_freed_;

 public:
  SerializationTypeConv2dPrePack unpack() {
    TORCH_CHECK(
        !orig_weight_and_bias_freed_,
        "Original weight and bias have been freed");
    return std::make_tuple(
        orig_weight_,
        orig_bias_,
        stride_,
        padding_,
        dilation_,
        groups_,
        output_min_,
        output_max_);
  }

  virtual Tensor run(const Tensor& input) = 0;
  virtual void free_orig_weight_and_bias() = 0;
};

class TransposeConv2dOpContext : public torch::jit::CustomClassHolder {
 protected:
  Tensor orig_weight_;
  std::optional<Tensor> orig_bias_;
  std::vector<int64_t> stride_;
  std::vector<int64_t> padding_;
  std::vector<int64_t> output_padding_;
  std::vector<int64_t> dilation_;
  int64_t groups_;
  std::optional<Scalar> output_min_;
  std::optional<Scalar> output_max_;
  bool orig_weight_and_bias_freed_;

 public:
  SerializationTypeTransposeConv2dPrePack unpack() {
    TORCH_CHECK(
        !orig_weight_and_bias_freed_,
        "Original weight and bias have been freed");
    return std::make_tuple(
        orig_weight_,
        orig_bias_,
        stride_,
        padding_,
        output_padding_,
        dilation_,
        groups_,
        output_min_,
        output_max_);
  }

  virtual Tensor run(const Tensor& input) = 0;
  virtual void free_orig_weight_and_bias() = 0;
};

class XNNPackConv2dOpContext final : public Conv2dOpContext {
 private:
  ContextConv2D op_context_;
  // XNNPACK convolutions use an indirection buffer.
  // These buffers need to be set up at runtime and/or when the input
  // dims change. If we are running the same model on multiple
  // threads, this can lead to contention, where the indirection buffer
  // is accessed and updated at the same time from two different
  // threads.
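  //
  // A minimal sketch of how the out-of-line run() definition is expected
  // to serialize access to the shared op_context_ behind xnnp_mutex_.
  // This is illustrative only; `conv2d_run_impl` is a hypothetical helper
  // standing in for the actual internal call made in the .cpp file:
  //
  //   Tensor XNNPackConv2dOpContext::run(const Tensor& input) {
  //     // One thread at a time: the call below may rebuild the
  //     // indirection buffer if the input dims changed.
  //     std::lock_guard<std::mutex> lock(xnnp_mutex_);
  //     return conv2d_run_impl(op_context_, input);
  //   }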
  std::mutex xnnp_mutex_;

 public:
  XNNPackConv2dOpContext(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      uint64_t groups,
      const std::optional<Scalar>& min,
      const std::optional<Scalar>& max,
      ContextConv2D&& op_context)
      : op_context_(std::move(op_context)) {
    orig_weight_ = std::move(weight);
    orig_bias_ = std::move(bias);
    padding_ = std::move(padding);
    stride_ = std::move(stride);
    dilation_ = std::move(dilation);
    groups_ = groups;
    output_min_ = min;
    output_max_ = max;
    orig_weight_and_bias_freed_ = false;
  }

  Tensor run(const Tensor& input) override;
  void free_orig_weight_and_bias() override;

  static c10::intrusive_ptr<Conv2dOpContext> create_context(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      int64_t groups,
      const std::optional<Scalar>& output_min,
      const std::optional<Scalar>& output_max);
};

class XNNPackTransposeConv2dOpContext final : public TransposeConv2dOpContext {
 private:
  ContextConv2D op_context_;
  // XNNPACK convolutions use an indirection buffer.
  // These buffers need to be set up at runtime and/or when the input
  // dims change. If we are running the same model on multiple
  // threads, this can lead to contention, where the indirection buffer
  // is accessed and updated at the same time from two different
  // threads.
  std::mutex xnnp_mutex_;

 public:
  XNNPackTransposeConv2dOpContext(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& output_padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      uint64_t groups,
      const std::optional<Scalar>& min,
      const std::optional<Scalar>& max,
      ContextConv2D&& op_context)
      : op_context_(std::move(op_context)) {
    orig_weight_ = std::move(weight);
    orig_bias_ = std::move(bias);
    padding_ = std::move(padding);
    output_padding_ = std::move(output_padding);
    stride_ = std::move(stride);
    dilation_ = std::move(dilation);
    groups_ = groups;
    output_min_ = min;
    output_max_ = max;
    orig_weight_and_bias_freed_ = false;
  }

  Tensor run(const Tensor& input) override;
  void free_orig_weight_and_bias() override;

  static c10::intrusive_ptr<TransposeConv2dOpContext> create_context(
      Tensor&& weight,
      std::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& output_padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      int64_t groups,
      const std::optional<Scalar>& output_min,
      const std::optional<Scalar>& output_max);
};

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */
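// Usage sketch (illustrative only, not part of this header): a prepack/run
// pair over one of the op contexts above. Registration of these classes as
// TorchScript custom classes and as ops lives elsewhere (e.g. in
// RegisterOpContextClass.cpp); `weight`, `bias`, and `input` are assumed to
// be caller-provided tensors:
//
//   c10::intrusive_ptr<at::native::xnnpack::LinearOpContext> ctx =
//       at::native::xnnpack::XNNPackLinearOpContext::create_context(
//           std::move(weight),
//           std::move(bias),
//           /*output_min=*/std::nullopt,
//           /*output_max=*/std::nullopt);
//   at::Tensor output = ctx->run(input); // reuses the prepacked weights
//   ctx->free_orig_weight_and_bias();    // afterwards, unpack() would throw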