• Home
  • Raw
  • Download

Lines Matching full:names

12 // Returns "Tensor['N', 'C', 'H', 'W']" for a tensor with names ('N', 'C', 'H', 'W').
15 os << "Tensor" << tensor.names(); in toDimnameRepr()
25 const auto names = tensor.names(); in dimname_to_position() local
27 const auto it = std::find(names.begin(), names.end(), dim); in dimname_to_position()
28 TORCH_CHECK(it != names.end(), in dimname_to_position()
31 return std::distance(names.begin(), it); in dimname_to_position()
46 DimnameList names, in report_positional_error() argument
49 …// TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds in report_positional_error()
51 "Error when attempting to ", action, " dims ", names, " and dims ", in report_positional_error()
58 DimnameList names, in check_for_misalignment() argument
65 …// TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds in check_for_misalignment()
67 "Misaligned dims when attempting to ", action, " dims ", names, " and dims ", in check_for_misalignment()
72 // Assumption: A DimnameList can have no duplicate full names with
75 DimnameList names, in unify_from_right() argument
79 const auto size = std::max(names.size(), other_names.size()); in unify_from_right()
82 auto names_it = names.rbegin(); in unify_from_right()
85 while (names_it != names.rend() || other_it != other_names.rend()) { in unify_from_right()
86 const auto& name = names_it == names.rend() ? wildcard : *names_it; in unify_from_right()
89 // Step 1: Check that the names match in unify_from_right()
92 report_positional_error(name, other_name, names, other_names, action); in unify_from_right()
96 // Step 2: Check that the names are not misaligned in unify_from_right()
98 // Let: N = max(len(names), len(other_names)) in unify_from_right()
99 // K = # of special names among names and other_names. in unify_from_right()
101 check_for_misalignment(name, names, other_names, action); in unify_from_right()
102 check_for_misalignment(other_name, other_names, names, action); in unify_from_right()
105 if (names_it != names.rend()) { in unify_from_right()
127 "Name mismatch: specified out tensor with names ", a, in assert_names_equal()
128 " are not the same as the computed output names ", b, in assert_names_equal()
156 const Tensor& propagate_names(const Tensor& result, DimnameList names, bool validate_names) { in propagate_names() argument
157 propagate_names(result.unsafeGetTensorImpl(), names, validate_names); in propagate_names()
161 TensorImpl* propagate_names(TensorImpl* result, DimnameList names, bool validate_names) { in propagate_names() argument
164 !names.empty(), in propagate_names()
165 "propagate_names: passed in empty names to propagate to result with", in propagate_names()
166 " shape ", result->sizes(), ". Empty names means that name inference did", in propagate_names()
170 impl::internal_set_names_inplace(result, names, validate_names); in propagate_names()
172 assert_names_equal(impl::get_names(result), names); in propagate_names()
181 const auto src_names = src.names(); in propagate_names_except()
237 auto tensor_names = tensor.names(); in compute_squeeze_outnames()
251 auto tensor_names = tensor.names(); in compute_squeeze_outnames()
268 auto tensor_names = tensor.names(); in compute_diagonal_outnames()
293 " would produce output tensor with duplicate names ", in check_feature_names_are_distinct()
298 static int64_t num_batch_dims(DimnameList names) { in num_batch_dims() argument
299 if (names.size() <= 2) { in num_batch_dims()
302 return static_cast<int64_t>(names.size() - 2); in num_batch_dims()
318 // To compute output names, we unify the batch dimensions because those are in compute_matmul_outnames()
321 // After that, we append some names that are equal to the result of the matmul in compute_matmul_outnames()
322 // without batch dimensions. Those names are computed by removing the names in compute_matmul_outnames()
326 // Get the output's batch dimension names in compute_matmul_outnames()
333 // completely contracted away during matmul so we don't take any names from them. in compute_matmul_outnames()
354 auto mv_outnames = compute_matmul_outnames(mat.names(), vec.names()); in propagate_names_for_addmv()
355 return unify_from_right(mv_outnames, bias.names()); in propagate_names_for_addmv()
367 auto mm_outnames = compute_matmul_outnames(m1.names(), m2.names()); in propagate_names_for_addmm()
368 return unify_from_right(mm_outnames, bias.names()); in propagate_names_for_addmm()
407 return unify_from_right(self.names(), other.names()); in compute_broadcast_outnames()
417 auto reference_names = reference_tensor.names(); in broadcast_to_outnames()
418 auto tensor_names = tensor.names(); in broadcast_to_outnames()
434 const auto tensor_names = tensor.names(); in compute_cat_outnames()
450 return compute_matmul_outnames(self.names(), other.names()); in compute_matmul_outnames()
459 const auto self_names = self.names(); in compute_cdist_outnames()
460 const auto other_names = other.names(); in compute_cdist_outnames()
470 // distance values. We propagate the names of the dimension of size M (in self) in compute_cdist_outnames()
486 return compute_matmul_outnames(self.names(), other.names()); in compute_bmm_outnames()
498 auto bmm_names = compute_matmul_outnames(self.names(), other.names()); in compute_baddbmm_outnames()
499 auto baddbmm_names = unify_from_right(bias.names(), bmm_names); in compute_baddbmm_outnames()