1 #include <torch/csrc/utils/out_types.h>
2
3 namespace torch::utils {
4
5 // Used by python binding codegen to ensure any TensorOptions arguments are
6 // consistent with the out tensor's options
check_out_type_matches(const at::Tensor & result,std::optional<at::ScalarType> scalarType,bool scalarType_is_none,std::optional<at::Layout> layout,std::optional<at::Device> device,bool device_is_none)7 void check_out_type_matches(
8 const at::Tensor& result,
9 std::optional<at::ScalarType> scalarType,
10 bool scalarType_is_none,
11 std::optional<at::Layout> layout,
12 std::optional<at::Device> device,
13 bool device_is_none) {
14 if (scalarType_is_none && !layout && device_is_none) { // common case
15 return;
16 }
17 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
18 if (!scalarType_is_none && result.scalar_type() != scalarType.value()) {
19 AT_ERROR(
20 "dtype ",
21 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
22 *scalarType,
23 " does not match dtype of out parameter (",
24 result.scalar_type(),
25 ")");
26 }
27 if (layout && result.layout() != *layout) {
28 AT_ERROR(
29 "layout ",
30 *layout,
31 " does not match layout of out parameter (",
32 result.layout(),
33 ")");
34 }
35 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
36 if (!device_is_none && result.device().type() != device.value().type()) {
37 AT_ERROR(
38 "device type ",
39 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
40 device->type(),
41 " does not match device type of out parameter (",
42 result.device().type(),
43 ")");
44 }
45 }
46
47 } // namespace torch::utils
48