#ifdef USE_XNNPACK

#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/utils/Factory.h>
#include <ATen/native/xnnpack/Linear.h>

namespace at::native::xnnpack {
namespace internal::linear {
namespace {

// Supports NHWC and NCHW FP32 linear operators.

// TODO: Decouple and improve error handling and messages.
bool available(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const float output_min,
    const float output_max) {
         // XNNPACK
  return xnnpack::available() &&
         // Weight
         (2 == weight.ndimension()) &&
         (weight.device().is_cpu()) &&
         (kFloat == weight.scalar_type()) &&
         !weight.requires_grad() &&
         // Bias
         ((bias && bias->defined())
              ? ((1 == bias->ndimension()) &&
                 (bias->device().is_cpu()) &&
                 (kFloat == bias->scalar_type()) &&
                 (weight.size(Layout::Filter::output) == bias->size(0)) &&
                 !bias->requires_grad())
              : true) &&
         // Output Min / Max
         (output_max > output_min) &&
         true;
}

// TODO: Decouple and improve error handling and messages.
bool usable(const Tensor& input) {
         // Input
  return (1 <= input.ndimension()) &&
         (input.device().is_cpu()) &&
         (kFloat == input.scalar_type()) &&
         !input.requires_grad() &&
         true;
}

Tensor create_and_run(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const float output_min,
    const float output_max) {
  return run(
      create(
          weight,
          bias,
          output_min,
          output_max),
      input);
}

} // anonymous namespace

ContextLinear create(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const float output_min,
    const float output_max) {
  const Tensor weight_contig = weight.contiguous();

  TORCH_CHECK(
      available(
          weight_contig,
          bias,
          output_min,
          output_max),
      "XNNPACK Linear not available! "
      "Reason: The provided (weight, bias, output_min, output_max) parameters are "
      "either invalid individually or their combination is not supported by XNNPACK.");

  xnn_operator_t linear_op{};

  const xnn_status create_status = xnn_create_fully_connected_nc_f32(
      weight_contig.size(Layout::Filter::input),   // input_channels
      weight_contig.size(Layout::Filter::output),  // output_channels
      weight_contig.size(Layout::Filter::input),   // input_pixel_stride
      weight_contig.size(Layout::Filter::output),  // output_pixel_stride
      weight_contig.data_ptr<float>(),             // kernel
      (bias && bias->defined())
          ? bias->contiguous().data_ptr<float>()
          : nullptr,                               // bias
      output_min,                                  // output_min
      output_max,                                  // output_max
      0u,                                          // flags
      nullptr,                                     // xnn_caches_t
      nullptr,                                     // xnn_weights_cache_t
      &linear_op);                                 // operator

  TORCH_CHECK(
      xnn_status_success == create_status,
      "xnn_create_fully_connected_nc_f32 failed!");

  return ContextLinear(
      Operator(linear_op),
      weight_contig.size(Layout::Filter::output));
}

Tensor run(
    const ContextLinear& context,
    const Tensor& input) {
  using namespace internal;

  // For compatibility with aten::linear
  auto ip = input;
  if (input.ndimension() == 1) {
    ip = input.unsqueeze(0);
  }

  const Tensor padded_input = mobile::allocate_padded_contiguous_if_needed(
      ip, ip.suggest_memory_format());

  TORCH_CHECK(
      usable(padded_input),
      "XNNPACK Linear not usable! "
      "Reason: The provided input tensor is either invalid or unsupported by XNNPACK.");

  const IntArrayRef input_size = padded_input.sizes();
  std::vector<int64_t> output_size(input_size.cbegin(), input_size.cend());
  output_size.back() = context.output_channels;

  Tensor output = mobile::empty_with_tail_padding(
      output_size,
      padded_input.options().dtype(),
      padded_input.suggest_memory_format(),
      padded_input.opt_names());

  const xnn_status reshape_status = xnn_reshape_fully_connected_nc_f32(
      context.op.get(),                                   // operator
      Layout::ActivationND::batch(padded_input.sizes()),  // batch_size
      caffe2::pthreadpool_());                            // threadpool

  TORCH_CHECK(
      xnn_status_success == reshape_status,
      "xnn_reshape_fully_connected_nc_f32 failed!");

  const xnn_status setup_status = xnn_setup_fully_connected_nc_f32(
      context.op.get(),                // operator
      padded_input.data_ptr<float>(),  // input
      output.data_ptr<float>());       // output

  TORCH_CHECK(
      xnn_status_success == setup_status,
      "xnn_setup_fully_connected_nc_f32 failed!");

  const xnn_status run_status = xnn_run_operator(
      context.op.get(),         // operator
      caffe2::pthreadpool_());  // threadpool

  TORCH_INTERNAL_ASSERT(
      xnn_status_success == run_status,
      "xnn_run_operator failed!");

  // For compatibility with aten::linear
  if (input.ndimension() == 1) {
    output.squeeze_(0);
  }

  return output;
}

c10::intrusive_ptr<xnnpack::LinearOpContext> createLinearClampPrePackOpContext(
    Tensor weight,
    std::optional<Tensor> bias,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return xnnpack::XNNPackLinearOpContext::create_context(
      std::move(weight), std::move(bias), output_min, output_max);
}

Tensor linear_clamp_run(
    const Tensor& input,
    const c10::intrusive_ptr<xnnpack::LinearOpContext>& op_context) {
  return op_context->run(input);
}

IValue unpack_prepacked_sizes_linear(const IValue& ivalue) {
  auto op_context = ivalue.toCustomClass<xnnpack::LinearOpContext>();
  const auto tuple = op_context->unpack();
  const auto& bias = std::get<1>(tuple);
  return IValue(std::make_tuple(
      std::get<0>(tuple).sizes(),
      (bias && bias->defined()) ? at::OptionalIntArrayRef(bias->sizes())
                                : std::nullopt));
}

} // namespace internal::linear

bool use_linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias) {
  return internal::linear::available(
             weight,
             bias,
             ContextLinear::kMin,
             ContextLinear::kMax) &&
         internal::linear::usable(input);
}

Tensor linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias) {
  return internal::linear::create_and_run(
      input,
      weight,
      bias,
      ContextLinear::kMin,
      ContextLinear::kMax);
}

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */