• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <torch/csrc/utils/out_types.h>
2 
3 namespace torch::utils {
4 
5 // Used by python binding codegen to ensure any TensorOptions arguments are
6 // consistent with the out tensor's options
check_out_type_matches(const at::Tensor & result,std::optional<at::ScalarType> scalarType,bool scalarType_is_none,std::optional<at::Layout> layout,std::optional<at::Device> device,bool device_is_none)7 void check_out_type_matches(
8     const at::Tensor& result,
9     std::optional<at::ScalarType> scalarType,
10     bool scalarType_is_none,
11     std::optional<at::Layout> layout,
12     std::optional<at::Device> device,
13     bool device_is_none) {
14   if (scalarType_is_none && !layout && device_is_none) { // common case
15     return;
16   }
17   // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
18   if (!scalarType_is_none && result.scalar_type() != scalarType.value()) {
19     AT_ERROR(
20         "dtype ",
21         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
22         *scalarType,
23         " does not match dtype of out parameter (",
24         result.scalar_type(),
25         ")");
26   }
27   if (layout && result.layout() != *layout) {
28     AT_ERROR(
29         "layout ",
30         *layout,
31         " does not match layout of out parameter (",
32         result.layout(),
33         ")");
34   }
35   // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
36   if (!device_is_none && result.device().type() != device.value().type()) {
37     AT_ERROR(
38         "device type ",
39         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
40         device->type(),
41         " does not match device type of out parameter (",
42         result.device().type(),
43         ")");
44   }
45 }
46 
47 } // namespace torch::utils
48