// This file defines classes for registering standard lowerings from JIT to TE // IR. #pragma once #include #include #include #include #include namespace torch::jit::tensorexpr { using ArgNone = std::monostate; using BufList = std::vector; using DoubleList = std::vector; using IntList = std::vector; using ArgValue = std::variant< tensorexpr::BufHandle, tensorexpr::VarHandle, double, int64_t, bool, BufList, DoubleList, IntList, std::string, ArgNone>; using NNCLoweringFunction = std::function&, const std::vector&, const std::vector&, const std::optional&, at::Device)>; TORCH_API FunctionSchemaMap& getNNCLoweringRegistry(); TORCH_API NNCLoweringFunction getStandardLoweringFor(const std::string& op); struct RegisterNNCLoweringsFunction { RegisterNNCLoweringsFunction( const std::vector& schemas, const NNCLoweringFunction& fn); }; } // namespace torch::jit::tensorexpr