#pragma once #include #include #include #include #include // This file includes utilities for dynamic_casting done by TensorIterator, see CUDALoops.cuh and Loops.h. // dynamic_casting handles when the types expected by the iterator do not match the types of the arguments // to the function that is being called. // On CUDA, the cast is currently pushed down into the kernel (for performance reasons). // On CPU, there is currently an internal assert that a dynamic_cast is not needed. namespace at::native { // `needs_dynamic_casting` compares the types expected by iterator // (i.e. dtypes of the operands) with the actual type of the arguments // (and returns) of func_t template::arity> struct needs_dynamic_casting { static bool check(TensorIteratorBase& iter) { using traits = function_traits; using cpp_type = typename traits::template arg::type; using cpp_map = c10::CppTypeToScalarType; if (iter.input_dtype(nargs-1) != cpp_map::value) { return true; } return needs_dynamic_casting::check(iter); } }; template struct needs_dynamic_casting { static bool check(TensorIteratorBase& iter) { using traits = function_traits; using cpp_type = typename traits::result_type; // we could assert output numbers are correct here, but checks // (including arity) are currently pushed outside of this struct. if constexpr (std::is_void_v) { return false; } else { return iter.dtype(0) != c10::CppTypeToScalarType::value; } } }; } //namespace at::native