• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <gtest/gtest.h>
2 
3 #include <torch/torch.h>
4 
5 #include <test/cpp/api/support.h>
6 
7 using namespace torch::indexing;
8 using namespace torch::test;
9 
// A Slice built from explicit start/stop/step values must report those values
// through its accessors and print as `start:stop:step`.
TEST(TensorIndexingTest, Slice) {
  Slice range(1, 2, 3);

  ASSERT_EQ(range.start(), 1);
  ASSERT_EQ(range.stop(), 2);
  ASSERT_EQ(range.step(), 3);

  ASSERT_EQ(c10::str(range), "1:2:3");
}
18 
TEST(TensorIndexingTest,TensorIndex)19 TEST(TensorIndexingTest, TensorIndex) {
20   {
21     std::vector<TensorIndex> indices = {
22         None,
23         "...",
24         Ellipsis,
25         0,
26         true,
27         Slice(1, None, 2),
28         torch::tensor({1, 2})};
29     ASSERT_TRUE(indices[0].is_none());
30     ASSERT_TRUE(indices[1].is_ellipsis());
31     ASSERT_TRUE(indices[2].is_ellipsis());
32     ASSERT_TRUE(indices[3].is_integer());
33     ASSERT_TRUE(indices[3].integer() == 0);
34     ASSERT_TRUE(indices[4].is_boolean());
35     ASSERT_TRUE(indices[4].boolean() == true);
36     ASSERT_TRUE(indices[5].is_slice());
37     ASSERT_TRUE(indices[5].slice().start() == 1);
38     ASSERT_TRUE(indices[5].slice().stop() == INDEX_MAX);
39     ASSERT_TRUE(indices[5].slice().step() == 2);
40     ASSERT_TRUE(indices[6].is_tensor());
41     ASSERT_TRUE(torch::equal(indices[6].tensor(), torch::tensor({1, 2})));
42   }
43 
44   ASSERT_THROWS_WITH(
45       TensorIndex(".."),
46       "Expected \"...\" to represent an ellipsis index, but got \"..\"");
47 
48   {
49     std::vector<TensorIndex> indices = {
50         None, "...", Ellipsis, 0, true, Slice(1, None, 2)};
51     ASSERT_EQ(
52         c10::str(indices),
53         c10::str("(None, ..., ..., 0, true, 1:", INDEX_MAX, ":2)"));
54     ASSERT_EQ(c10::str(indices[0]), "None");
55     ASSERT_EQ(c10::str(indices[1]), "...");
56     ASSERT_EQ(c10::str(indices[2]), "...");
57     ASSERT_EQ(c10::str(indices[3]), "0");
58     ASSERT_EQ(c10::str(indices[4]), "true");
59     ASSERT_EQ(c10::str(indices[5]), c10::str("1:", INDEX_MAX, ":2"));
60   }
61 
62   ASSERT_EQ(
63       c10::str(std::vector<TensorIndex>({Slice()})),
64       c10::str("(0:", INDEX_MAX, ":1)"));
65   ASSERT_EQ(
66       c10::str(std::vector<TensorIndex>({Slice(None, None)})),
67       c10::str("(0:", INDEX_MAX, ":1)"));
68   ASSERT_EQ(
69       c10::str(std::vector<TensorIndex>({Slice(None, None, None)})),
70       c10::str("(0:", INDEX_MAX, ":1)"));
71 
72   ASSERT_EQ(
73       c10::str(std::vector<TensorIndex>({Slice(1, None)})),
74       c10::str("(1:", INDEX_MAX, ":1)"));
75   ASSERT_EQ(
76       c10::str(std::vector<TensorIndex>({Slice(1, None, None)})),
77       c10::str("(1:", INDEX_MAX, ":1)"));
78   ASSERT_EQ(
79       c10::str(std::vector<TensorIndex>({Slice(None, 3)})),
80       c10::str("(0:3:1)"));
81   ASSERT_EQ(
82       c10::str(std::vector<TensorIndex>({Slice(None, 3, None)})),
83       c10::str("(0:3:1)"));
84   ASSERT_EQ(
85       c10::str(std::vector<TensorIndex>({Slice(None, None, 2)})),
86       c10::str("(0:", INDEX_MAX, ":2)"));
87   ASSERT_EQ(
88       c10::str(std::vector<TensorIndex>({Slice(None, None, -1)})),
89       c10::str("(", INDEX_MAX, ":", INDEX_MIN, ":-1)"));
90 
91   ASSERT_EQ(
92       c10::str(std::vector<TensorIndex>({Slice(1, 3)})), c10::str("(1:3:1)"));
93   ASSERT_EQ(
94       c10::str(std::vector<TensorIndex>({Slice(1, None, 2)})),
95       c10::str("(1:", INDEX_MAX, ":2)"));
96   ASSERT_EQ(
97       c10::str(std::vector<TensorIndex>({Slice(1, None, -1)})),
98       c10::str("(1:", INDEX_MIN, ":-1)"));
99   ASSERT_EQ(
100       c10::str(std::vector<TensorIndex>({Slice(None, 3, 2)})),
101       c10::str("(0:3:2)"));
102   ASSERT_EQ(
103       c10::str(std::vector<TensorIndex>({Slice(None, 3, -1)})),
104       c10::str("(", INDEX_MAX, ":3:-1)"));
105 
106   ASSERT_EQ(
107       c10::str(std::vector<TensorIndex>({Slice(1, 3, 2)})),
108       c10::str("(1:3:2)"));
109 }
110 
// An empty index list must be rejected by both `index` and `index_put_`,
// whether it is passed as an empty braced list or as an empty vector.
TEST(TensorIndexingTest, TestNoIndices) {
  const char* const kIndexError =
      "Passing an empty index list to Tensor::index() is not valid syntax";
  const char* const kIndexPutError =
      "Passing an empty index list to Tensor::index_put_() is not valid syntax";

  torch::Tensor target = torch::randn({20, 20});
  torch::Tensor replacement = torch::randn({20, 20});
  std::vector<TensorIndex> no_indices;

  // Empty braced list.
  ASSERT_THROWS_WITH(target.index({}), kIndexError);
  ASSERT_THROWS_WITH(target.index_put_({}, 1), kIndexPutError);
  ASSERT_THROWS_WITH(target.index_put_({}, replacement), kIndexPutError);

  // Empty vector of TensorIndex.
  ASSERT_THROWS_WITH(target.index(no_indices), kIndexError);
  ASSERT_THROWS_WITH(target.index_put_(no_indices, 1), kIndexPutError);
  ASSERT_THROWS_WITH(
      target.index_put_(no_indices, replacement), kIndexPutError);
}
136 
// The free functions `at::index` / `at::index_put_` and the corresponding
// Tensor methods called with a braced list holding one tensor must give equal
// results.
TEST(TensorIndexingTest, TestAdvancedIndexingWithListOfTensor) {
  // Reading through a tensor index.
  {
    torch::Tensor data = torch::randn({20, 20});
    torch::Tensor rows = torch::arange(10, torch::kLong).cpu();

    torch::Tensor via_function = at::index(data, {rows});
    torch::Tensor via_method = data.index({rows});

    ASSERT_TRUE(via_function.equal(via_method));
  }
  // Writing a value of shape {20}.
  {
    torch::Tensor data = torch::randn({20, 20});
    torch::Tensor rows = torch::arange(10, torch::kLong).cpu();

    torch::Tensor via_function =
        at::index_put_(data, {rows}, torch::ones({20}));
    torch::Tensor via_method = data.index_put_({rows}, torch::ones({20}));

    ASSERT_TRUE(via_function.equal(via_method));
  }
  // Writing a value of shape {1, 20}.
  {
    torch::Tensor data = torch::randn({20, 20});
    torch::Tensor rows = torch::arange(10, torch::kLong).cpu();

    torch::Tensor via_function =
        at::index_put_(data, {rows}, torch::ones({1, 20}));
    torch::Tensor via_method = data.index_put_({rows}, torch::ones({1, 20}));

    ASSERT_TRUE(via_function.equal(via_method));
  }
}
163 
// Indexing a {5, 7, 3} tensor with one integer drops the leading dimension.
TEST(TensorIndexingTest, TestSingleInt) {
  auto cube = torch::randn({5, 7, 3});
  auto last_plane = cube.index({4});
  ASSERT_EQ(last_plane.sizes(), torch::IntArrayRef({7, 3}));
}
168 
// Integer indices mixed with a full slice each remove one dimension, and
// `index_put_` with three integers writes a single element.
TEST(TensorIndexingTest, TestMultipleInt) {
  auto cube = torch::randn({5, 7, 3});

  ASSERT_EQ(cube.index({4}).sizes(), torch::IntArrayRef({7, 3}));
  ASSERT_EQ(cube.index({4, Slice(), 1}).sizes(), torch::IntArrayRef({7}));

  // Write one element, then read it back.
  cube.index_put_({4, 3, 1}, 0);
  ASSERT_EQ(cube.index({4, 3, 1}).item<double>(), 0);
}
178 
// Each `None` in an index list inserts a size-1 dimension at its position.
TEST(TensorIndexingTest, TestNone) {
  auto cube = torch::randn({5, 7, 3});

  // Leading None.
  ASSERT_EQ(cube.index({None}).sizes(), torch::IntArrayRef({1, 5, 7, 3}));

  // None after a full slice, once and twice.
  ASSERT_EQ(
      cube.index({Slice(), None}).sizes(), torch::IntArrayRef({5, 1, 7, 3}));
  ASSERT_EQ(
      cube.index({Slice(), None, None}).sizes(),
      torch::IntArrayRef({5, 1, 1, 7, 3}));

  // None after an ellipsis lands at the end.
  ASSERT_EQ(
      cube.index({"...", None}).sizes(), torch::IntArrayRef({5, 7, 3, 1}));
}
188 
// Slices with a positive step select every step-th element of arange(10).
TEST(TensorIndexingTest, TestStep) {
  auto seq = torch::arange(10);

  // Step 1 reproduces the input.
  assert_tensor_equal(seq.index({Slice(None, None, 1)}), seq);

  assert_tensor_equal(
      seq.index({Slice(None, None, 2)}), torch::tensor({0, 2, 4, 6, 8}));
  assert_tensor_equal(
      seq.index({Slice(None, None, 3)}), torch::tensor({0, 3, 6, 9}));

  // A step larger than the length keeps only the first element.
  assert_tensor_equal(seq.index({Slice(None, None, 11)}), torch::tensor({0}));

  // Explicit start and stop combined with a step.
  assert_tensor_equal(seq.index({Slice(1, 6, 2)}), torch::tensor({1, 3, 5}));
}
199 
// Assigning through a stepped slice of row 0 must only touch the selected
// columns of that row and leave every other row at zero.
TEST(TensorIndexingTest, TestStepAssignment) {
  auto grid = torch::zeros({4, 4});

  grid.index_put_({0, Slice(1, None, 2)}, torch::tensor({3., 4.}));

  assert_tensor_equal(grid.index({0}), torch::tensor({0., 3., 0., 4.}));
  assert_tensor_equal(grid.index({Slice(1, None)}).sum(), torch::tensor(0));
}
206 
// Boolean-mask indexing selects the entries where the mask is true; a uint8
// mask gives the same result but emits a deprecation warning on each use.
TEST(TensorIndexingTest, TestBoolIndices) {
  // A 1-D bool mask over the first dimension of a 3-D tensor.
  {
    auto cube = torch::randn({5, 7, 3});
    auto keep = torch::tensor({true, false, true, true, false}, torch::kBool);

    ASSERT_EQ(cube.index({keep}).sizes(), torch::IntArrayRef({3, 7, 3}));
    assert_tensor_equal(
        cube.index({keep}),
        torch::stack({cube.index({0}), cube.index({2}), cube.index({3})}));
  }
  // Bool mask versus the equivalent uint8 mask.
  {
    auto flags = torch::tensor({true, false, true}, torch::kBool);
    auto bool_mask = torch::tensor({true, false, false}, torch::kBool);
    auto byte_mask = torch::tensor({1, 0, 0}, torch::kUInt8);

    WarningCapture captured;

    ASSERT_EQ(
        flags.index({bool_mask}).sizes(), flags.index({byte_mask}).sizes());
    assert_tensor_equal(flags.index({bool_mask}), flags.index({byte_mask}));
    assert_tensor_equal(
        flags.index({bool_mask}), torch::tensor({true}, torch::kBool));

    // The uint8 mask was used twice above.
    ASSERT_EQ(
        count_substr_occurrences(
            captured.str(),
            "indexing with dtype torch.uint8 is now deprecated"),
        2);
  }
}
239 
// Accumulating through an all-false bool mask must leave the tensor unchanged.
TEST(TensorIndexingTest, TestBoolIndicesAccumulate) {
  auto none_selected = torch::zeros({10}, torch::kBool);
  auto ones = torch::ones({10, 10});

  ones.index_put_(
      {none_selected}, {ones.index({none_selected})}, /*accumulate=*/true);

  assert_tensor_equal(ones, torch::ones({10, 10}));
}
246 
// Two bool masks separated by a full slice: the asserted result shape is
// {3, 7}.
TEST(TensorIndexingTest, TestMultipleBoolIndices) {
  auto cube = torch::randn({5, 7, 3});
  // note: these broadcast together and are transposed to the first dim
  auto first_dim_mask = torch::tensor({1, 0, 1, 1, 0}, torch::kBool);
  auto last_dim_mask = torch::tensor({1, 1, 1}, torch::kBool);

  auto picked = cube.index({first_dim_mask, Slice(), last_dim_mask});
  ASSERT_EQ(picked.sizes(), torch::IntArrayRef({3, 7}));
}
255 
// Indexing with a uint8 (byte) mask selects the rows where the mask is 1 and
// emits a deprecation warning per use; an all-false comparison mask yields an
// empty result.
TEST(TensorIndexingTest, TestByteMask) {
  {
    auto cube = torch::randn({5, 7, 3});
    auto byte_mask = torch::tensor({1, 0, 1, 1, 0}, torch::kByte);

    WarningCapture captured;

    ASSERT_EQ(cube.index({byte_mask}).sizes(), torch::IntArrayRef({3, 7, 3}));
    assert_tensor_equal(
        cube.index({byte_mask}), torch::stack({cube[0], cube[2], cube[3]}));

    // The byte mask was used twice above.
    ASSERT_EQ(
        count_substr_occurrences(
            captured.str(),
            "indexing with dtype torch.uint8 is now deprecated"),
        2);
  }
  {
    auto single = torch::tensor({1.});
    assert_tensor_equal(single.index({single == 0}), torch::randn({0}));
  }
}
278 
// Accumulating through an all-zero uint8 mask must leave the tensor unchanged
// and emit the uint8 deprecation warning twice (one read, one write).
TEST(TensorIndexingTest, TestByteMaskAccumulate) {
  auto none_selected = torch::zeros({10}, torch::kUInt8);
  auto ones = torch::ones({10, 10});

  WarningCapture captured;

  ones.index_put_(
      {none_selected}, ones.index({none_selected}), /*accumulate=*/true);
  assert_tensor_equal(ones, torch::ones({10, 10}));

  ASSERT_EQ(
      count_substr_occurrences(
          captured.str(), "indexing with dtype torch.uint8 is now deprecated"),
      2);
}
295 
// Two uint8 masks separated by a full slice: the asserted result shape is
// {3, 7}, and one deprecation warning is expected per byte mask.
TEST(TensorIndexingTest, TestMultipleByteMask) {
  auto cube = torch::randn({5, 7, 3});
  // note: these broadcast together and are transposed to the first dim
  auto first_dim_mask = torch::tensor({1, 0, 1, 1, 0}, torch::kByte);
  auto last_dim_mask = torch::tensor({1, 1, 1}, torch::kByte);

  WarningCapture captured;

  ASSERT_EQ(
      cube.index({first_dim_mask, Slice(), last_dim_mask}).sizes(),
      torch::IntArrayRef({3, 7}));

  ASSERT_EQ(
      count_substr_occurrences(
          captured.str(), "indexing with dtype torch.uint8 is now deprecated"),
      2);
}
314 
// A 2-D comparison mask over the first two dimensions of a {5, 7, 3} tensor
// yields one row of length 3 per true entry in the mask.
TEST(TensorIndexingTest, TestByteMask2d) {
  auto cube = torch::randn({5, 7, 3});
  auto selector = torch::randn({5, 7});

  auto positive = selector > 0;
  int64_t expected_rows = positive.sum().item().to<int64_t>();

  auto picked = cube.index({positive});
  ASSERT_EQ(picked.sizes(), torch::IntArrayRef({expected_rows, 3}));
}
322 
// Indexing with an integer tensor replaces the indexed dimension with the
// shape of the index tensor.
TEST(TensorIndexingTest, TestIntIndices) {
  auto cube = torch::randn({5, 7, 3});

  // 1-D index on dimension 0.
  ASSERT_EQ(
      cube.index({torch::tensor({0, 4, 2})}).sizes(),
      torch::IntArrayRef({3, 7, 3}));

  // 1-D index on dimension 1.
  ASSERT_EQ(
      cube.index({Slice(), torch::tensor({0, 4, 2})}).sizes(),
      torch::IntArrayRef({5, 3, 3}));

  // 2-D index on dimension 1.
  ASSERT_EQ(
      cube.index({Slice(), torch::tensor({{0, 1}, {4, 3}})}).sizes(),
      torch::IntArrayRef({5, 2, 2, 3}));
}
335 
// Selecting the four corners of a 4x3 matrix with paired 2-D row and column
// index tensors (taken from the NumPy indexing example).
TEST(TensorIndexingTest, TestIntIndices2d) {
  auto matrix = torch::arange(0, 12, torch::kLong).view({4, 3});
  auto row_idx = torch::tensor({{0, 0}, {3, 3}});
  auto col_idx = torch::tensor({{0, 2}, {0, 2}});

  auto corners = matrix.index({row_idx, col_idx});
  assert_tensor_equal(corners, torch::tensor({{0, 2}, {9, 11}}));
}
344 
// Same corner selection as the NumPy indexing example, but with a column of
// row indices (made by indexing with `{Slice(), None}`) paired with a 1-D
// column index.
TEST(TensorIndexingTest, TestIntIndicesBroadcast) {
  auto matrix = torch::arange(0, 12, torch::kLong).view({4, 3});
  auto row_idx = torch::tensor({0, 3});
  auto col_idx = torch::tensor({0, 2});

  auto row_column = row_idx.index({Slice(), None});
  auto corners = matrix.index({row_column, col_idx});

  assert_tensor_equal(corners, torch::tensor({{0, 2}, {9, 11}}));
}
353 
// Reading through an empty index tensor yields no elements, and writing
// through an empty index or an all-false mask changes nothing.
TEST(TensorIndexingTest, TestEmptyIndex) {
  auto original = torch::arange(0, 12).view({4, 3});
  auto empty_idx = torch::tensor({}, torch::kLong);

  ASSERT_EQ(original.index({empty_idx}).numel(), 0);

  // empty assignment should have no effect but not throw an exception
  auto copy = original.clone();
  copy.index_put_({empty_idx}, -1);
  assert_tensor_equal(original, copy);

  auto all_false = torch::zeros({4, 3}, torch::kBool);
  copy.index_put_({all_false}, -1);
  assert_tensor_equal(original, copy);
}
368 
TEST(TensorIndexingTest,TestEmptyNdimIndex)369 TEST(TensorIndexingTest, TestEmptyNdimIndex) {
370   torch::Device device(torch::kCPU);
371   {
372     auto x = torch::randn({5}, device);
373     assert_tensor_equal(
374         torch::empty({0, 2}, device),
375         x.index({torch::empty(
376             {0, 2}, torch::TensorOptions(torch::kInt64).device(device))}));
377   }
378   {
379     auto x = torch::randn({2, 3, 4, 5}, device);
380     assert_tensor_equal(
381         torch::empty({2, 0, 6, 4, 5}, device),
382         x.index(
383             {Slice(),
384              torch::empty(
385                  {0, 6}, torch::TensorOptions(torch::kInt64).device(device))}));
386   }
387   {
388     auto x = torch::empty({10, 0});
389     ASSERT_EQ(
390         x.index({torch::tensor({1, 2})}).sizes(), torch::IntArrayRef({2, 0}));
391     ASSERT_EQ(
392         x.index(
393              {torch::tensor({}, torch::kLong), torch::tensor({}, torch::kLong)})
394             .sizes(),
395         torch::IntArrayRef({0}));
396     ASSERT_THROWS_WITH(
397         x.index({Slice(), torch::tensor({0, 1})}), "for dimension with size 0");
398   }
399 }
400 
TEST(TensorIndexingTest,TestEmptyNdimIndex_CUDA)401 TEST(TensorIndexingTest, TestEmptyNdimIndex_CUDA) {
402   torch::Device device(torch::kCUDA);
403   {
404     auto x = torch::randn({5}, device);
405     assert_tensor_equal(
406         torch::empty({0, 2}, device),
407         x.index({torch::empty(
408             {0, 2}, torch::TensorOptions(torch::kInt64).device(device))}));
409   }
410   {
411     auto x = torch::randn({2, 3, 4, 5}, device);
412     assert_tensor_equal(
413         torch::empty({2, 0, 6, 4, 5}, device),
414         x.index(
415             {Slice(),
416              torch::empty(
417                  {0, 6}, torch::TensorOptions(torch::kInt64).device(device))}));
418   }
419 }
420 
TEST(TensorIndexingTest,TestEmptyNdimIndexBool)421 TEST(TensorIndexingTest, TestEmptyNdimIndexBool) {
422   torch::Device device(torch::kCPU);
423   auto x = torch::randn({5}, device);
424   // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
425   ASSERT_THROW(
426       x.index({torch::empty(
427           {0, 2}, torch::TensorOptions(torch::kUInt8).device(device))}),
428       c10::Error);
429 }
430 
TEST(TensorIndexingTest,TestEmptyNdimIndexBool_CUDA)431 TEST(TensorIndexingTest, TestEmptyNdimIndexBool_CUDA) {
432   torch::Device device(torch::kCUDA);
433   auto x = torch::randn({5}, device);
434   // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
435   ASSERT_THROW(
436       x.index({torch::empty(
437           {0, 2}, torch::TensorOptions(torch::kUInt8).device(device))}),
438       c10::Error);
439 }
440 
TEST(TensorIndexingTest,TestEmptySlice)441 TEST(TensorIndexingTest, TestEmptySlice) {
442   torch::Device device(torch::kCPU);
443   auto x = torch::randn({2, 3, 4, 5}, device);
444   auto y = x.index({Slice(), Slice(), Slice(), 1});
445   auto z = y.index({Slice(), Slice(1, 1), Slice()});
446   ASSERT_EQ(z.sizes(), torch::IntArrayRef({2, 0, 4}));
447   // this isn't technically necessary, but matches NumPy stride calculations.
448   ASSERT_EQ(z.strides(), torch::IntArrayRef({60, 20, 5}));
449   ASSERT_TRUE(z.is_contiguous());
450 }
451 
TEST(TensorIndexingTest,TestEmptySlice_CUDA)452 TEST(TensorIndexingTest, TestEmptySlice_CUDA) {
453   torch::Device device(torch::kCUDA);
454   auto x = torch::randn({2, 3, 4, 5}, device);
455   auto y = x.index({Slice(), Slice(), Slice(), 1});
456   auto z = y.index({Slice(), Slice(1, 1), Slice()});
457   ASSERT_EQ(z.sizes(), torch::IntArrayRef({2, 0, 4}));
458   // this isn't technically necessary, but matches NumPy stride calculations.
459   ASSERT_EQ(z.strides(), torch::IntArrayRef({60, 20, 5}));
460   ASSERT_TRUE(z.is_contiguous());
461 }
462 
// For both a 2-D tensor and a 0-dim tensor: indexing with `true` (literal or
// uint8 tensor) returns data at a different address, indexing with `false`
// returns an empty tensor with a leading 0 dimension, and `None` / "..."
// return data at the same address.
TEST(TensorIndexingTest, TestIndexGetitemCopyBoolsSlices) {
  auto true_tensor = torch::tensor(1, torch::kUInt8);
  auto false_tensor = torch::tensor(0, torch::kUInt8);

  std::vector<torch::Tensor> inputs = {torch::randn({2, 3}), torch::tensor(3)};

  for (auto& a : inputs) {
    // Expected shape for a `false` index: the input shape prefixed with 0.
    std::vector<int64_t> empty_shape = {0};
    empty_shape.insert(empty_shape.end(), a.sizes().begin(), a.sizes().end());

    ASSERT_NE(a.data_ptr(), a.index({true}).data_ptr());
    assert_tensor_equal(torch::empty(empty_shape), a.index({false}));

    ASSERT_NE(a.data_ptr(), a.index({true_tensor}).data_ptr());
    assert_tensor_equal(torch::empty(empty_shape), a.index({false_tensor}));

    ASSERT_EQ(a.data_ptr(), a.index({None}).data_ptr());
    ASSERT_EQ(a.data_ptr(), a.index({"..."}).data_ptr());
  }
}
486 
// For both a 2-D tensor and a 0-dim tensor: assignment through `true`, a true
// uint8 tensor, `None` and "..." overwrites the whole tensor, assignment
// through `false` / a false uint8 tensor changes nothing, and a full slice on
// a 0-dim tensor throws.
TEST(TensorIndexingTest, TestIndexSetitemBoolsSlices) {
  auto true_tensor = torch::tensor(1, torch::kUInt8);
  auto false_tensor = torch::tensor(0, torch::kUInt8);

  std::vector<torch::Tensor> inputs = {torch::randn({2, 3}), torch::tensor(3)};

  for (auto& a : inputs) {
    // prefix with a 1,1, to ensure we are compatible with numpy which cuts off
    // prefix 1s (some of these ops already prefix a 1 to the size)
    auto minus_one = torch::ones_like(a) * -1;
    auto minus_one_prefixed = minus_one.unsqueeze(0).unsqueeze(0);

    // Literal booleans.
    a.index_put_({true}, minus_one_prefixed);
    assert_tensor_equal(a, minus_one);
    a.index_put_({false}, 5);
    assert_tensor_equal(a, minus_one);

    // uint8 scalar tensors.
    a.index_put_({true_tensor}, minus_one_prefixed * 2);
    assert_tensor_equal(a, minus_one * 2);
    a.index_put_({false_tensor}, 5);
    assert_tensor_equal(a, minus_one * 2);

    // None and ellipsis.
    a.index_put_({None}, minus_one_prefixed * 3);
    assert_tensor_equal(a, minus_one * 3);
    a.index_put_({"..."}, minus_one_prefixed * 4);
    assert_tensor_equal(a, minus_one * 4);

    if (a.dim() == 0) {
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
      ASSERT_THROW(
          a.index_put_({Slice()}, minus_one_prefixed * 5), c10::Error);
    }
  }
}
516 
TEST(TensorIndexingTest,TestIndexScalarWithBoolMask)517 TEST(TensorIndexingTest, TestIndexScalarWithBoolMask) {
518   torch::Device device(torch::kCPU);
519 
520   auto a = torch::tensor(1, device);
521   auto uintMask =
522       torch::tensor(true, torch::TensorOptions(torch::kUInt8).device(device));
523   auto boolMask =
524       torch::tensor(true, torch::TensorOptions(torch::kBool).device(device));
525   assert_tensor_equal(a.index({uintMask}), a.index({boolMask}));
526   ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype());
527 
528   a = torch::tensor(true, torch::TensorOptions(torch::kBool).device(device));
529   assert_tensor_equal(a.index({uintMask}), a.index({boolMask}));
530   ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype());
531 }
532 
TEST(TensorIndexingTest,TestIndexScalarWithBoolMask_CUDA)533 TEST(TensorIndexingTest, TestIndexScalarWithBoolMask_CUDA) {
534   torch::Device device(torch::kCUDA);
535 
536   auto a = torch::tensor(1, device);
537   auto uintMask =
538       torch::tensor(true, torch::TensorOptions(torch::kUInt8).device(device));
539   auto boolMask =
540       torch::tensor(true, torch::TensorOptions(torch::kBool).device(device));
541   assert_tensor_equal(a.index({uintMask}), a.index({boolMask}));
542   ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype());
543 
544   a = torch::tensor(true, torch::TensorOptions(torch::kBool).device(device));
545   assert_tensor_equal(a.index({uintMask}), a.index({boolMask}));
546   ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype());
547 }
548 
// Assigning a value whose shape has extra leading dimensions that are not all
// 1 must throw, for both a literal `true` index and a true bool tensor index.
TEST(TensorIndexingTest, TestSetitemExpansionError) {
  auto true_tensor = torch::tensor(true);
  auto target = torch::randn({2, 3});

  // check prefix with  non-1s doesn't work
  std::vector<int64_t> prefixed_shape{5, 1};
  prefixed_shape.insert(
      prefixed_shape.end(), target.sizes().begin(), target.sizes().end());
  auto oversized_value = target.expand(prefixed_shape);

  // NumPy: ValueError
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(target.index_put_({true}, oversized_value), c10::Error);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(target.index_put_({true_tensor}, oversized_value), c10::Error);
}
562 
// Indexing with 0-dim integer tensors must match indexing with plain integers
// and return data at the same address; a 0-dim tensor rejects slice and
// integer-tensor indices but accepts an ellipsis.
TEST(TensorIndexingTest, TestGetitemScalars) {
  auto zero = torch::tensor(0, torch::kInt64);
  auto one = torch::tensor(1, torch::kInt64);

  // non-scalar indexed with scalars
  auto matrix = torch::randn({2, 3});
  assert_tensor_equal(matrix.index({0}), matrix.index({zero}));
  assert_tensor_equal(
      matrix.index({0}).index({1}), matrix.index({zero}).index({one}));
  assert_tensor_equal(matrix.index({0, 1}), matrix.index({zero, one}));
  assert_tensor_equal(matrix.index({0, one}), matrix.index({zero, 1}));

  // indexing by a scalar should slice (not copy)
  ASSERT_EQ(
      matrix.index({0, 1}).data_ptr(), matrix.index({zero, one}).data_ptr());
  ASSERT_EQ(
      matrix.index({1}).data_ptr(),
      matrix.index({one.to(torch::kInt)}).data_ptr());
  ASSERT_EQ(
      matrix.index({1}).data_ptr(),
      matrix.index({one.to(torch::kShort)}).data_ptr());

  // scalar indexed with scalar
  auto scalar = torch::randn({});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(scalar.index({Slice()}), c10::Error);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(scalar.index({zero}), c10::Error);
  assert_tensor_equal(scalar, scalar.index({"..."}));
}
588 
// Assignment through a 0-dim integer tensor index must match assignment
// through a plain integer; a 0-dim tensor rejects slice and integer-tensor
// assignment but accepts assignment through an ellipsis.
TEST(TensorIndexingTest, TestSetitemScalars) {
  auto zero = torch::tensor(0, torch::kInt64);

  // non-scalar indexed with scalars
  auto matrix = torch::randn({2, 3});
  auto set_via_int = matrix.clone();
  auto set_via_tensor = matrix.clone();
  auto new_row = torch::randn({3});

  set_via_int.index_put_({0}, new_row);
  set_via_tensor.index_put_({zero}, new_row);
  assert_tensor_equal(set_via_int, set_via_tensor);

  matrix.index_put_({1, zero}, 7.7);
  ASSERT_TRUE(matrix.index({1, 0}).allclose(torch::tensor(7.7)));

  // scalar indexed with scalars
  auto scalar = torch::randn({});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(scalar.index_put_({Slice()}, 8.8), c10::Error);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(scalar.index_put_({zero}, 8.8), c10::Error);
  scalar.index_put_({"..."}, 9.9);
  ASSERT_TRUE(scalar.allclose(torch::tensor(9.9)));
}
613 
// Mixing a slice with a tensor index (from the NumPy indexing example): the
// result equals the pure-slice result, modifying the read result in place
// leaves the source untouched, and `index_put_` modifies the source.
TEST(TensorIndexingTest, TestBasicAdvancedCombined) {
  auto matrix = torch::arange(0, 12).to(torch::kLong).view({4, 3});

  assert_tensor_equal(
      matrix.index({Slice(1, 2), Slice(1, 3)}),
      matrix.index({Slice(1, 2), torch::tensor({1, 2})}));
  assert_tensor_equal(
      matrix.index({Slice(1, 2), Slice(1, 3)}), torch::tensor({{4, 5}}));

  // Check that it is a copy
  {
    auto snapshot = matrix.clone();
    matrix.index({Slice(1, 2), torch::tensor({1, 2})}).zero_();
    assert_tensor_equal(matrix, snapshot);
  }

  // But assignment should modify the original
  {
    auto snapshot = matrix.clone();
    matrix.index_put_({Slice(1, 2), torch::tensor({1, 2})}, 0);
    assert_tensor_not_equal(matrix, snapshot);
  }
}
637 
// Assigning to a row selected by an integer index, first with a single number
// for the whole row and then with a tensor of per-element values.
TEST(TensorIndexingTest, TestIntAssignment) {
  // One number fills the whole row.
  {
    auto matrix = torch::arange(0, 4).to(torch::kLong).view({2, 2});
    matrix.index_put_({1}, 5);
    assert_tensor_equal(matrix, torch::tensor({{0, 1}, {5, 5}}));
  }

  // A tensor value is written element by element.
  {
    auto matrix = torch::arange(0, 4).to(torch::kLong).view({2, 2});
    matrix.index_put_({1}, torch::arange(5, 7).to(torch::kLong));
    assert_tensor_equal(matrix, torch::tensor({{0, 1}, {5, 6}}));
  }
}
651 
// Assigning through a uint8 row mask overwrites the masked rows, leaves the
// other rows alone, and emits the uint8 deprecation warning once.
TEST(TensorIndexingTest, TestByteTensorAssignment) {
  auto matrix = torch::arange(0., 16).to(torch::kFloat).view({4, 4});
  auto row_mask = torch::tensor({true, false, true, false}, torch::kByte);
  auto new_row = torch::tensor({3., 4., 5., 6.});

  {
    WarningCapture captured;

    matrix.index_put_({row_mask}, new_row);

    ASSERT_EQ(
        count_substr_occurrences(
            captured.str(),
            "indexing with dtype torch.uint8 is now deprecated"),
        1);
  }

  // Rows 0 and 2 were selected by the mask; rows 1 and 3 keep their values.
  assert_tensor_equal(matrix.index({0}), new_row);
  assert_tensor_equal(
      matrix.index({1}), torch::arange(4, 8).to(torch::kLong));
  assert_tensor_equal(matrix.index({2}), new_row);
  assert_tensor_equal(
      matrix.index({3}), torch::arange(12, 16).to(torch::kLong));
}
674 
// Slice bounds read out of a tensor at runtime must behave the same as the
// equivalent literal bounds.
TEST(TensorIndexingTest, TestVariableSlicing) {
  auto matrix = torch::arange(0, 16).view({4, 4});
  auto bounds = torch::tensor({0, 1}, torch::kInt);

  int lower = bounds[0].item<int>();
  int upper = bounds[1].item<int>();

  assert_tensor_equal(
      matrix.index({Slice(lower, upper)}), matrix.index({Slice(0, 1)}));
}
682 
// An ellipsis combined with a tensor index: placed before the tensor it
// selects columns, placed after the tensor it selects rows.
TEST(TensorIndexingTest, TestEllipsisTensor) {
  auto matrix = torch::arange(0, 9).to(torch::kLong).view({3, 3});
  auto picks = torch::tensor({0, 2});

  // Columns 0 and 2.
  assert_tensor_equal(
      matrix.index({"...", picks}), torch::tensor({{0, 2}, {3, 5}, {6, 8}}));

  // Rows 0 and 2.
  assert_tensor_equal(
      matrix.index({picks, "..."}), torch::tensor({{0, 1, 2}, {6, 7, 8}}));
}
691 
// An out-of-range integer index must raise an error naming the offending
// index, the dimension it applies to, and that dimension's size.
TEST(TensorIndexingTest, TestOutOfBoundIndex) {
  auto cube = torch::arange(0, 100).view({2, 5, 10});

  // Out of range in dimension 1.
  ASSERT_THROWS_WITH(
      cube.index({0, 5}),
      "index 5 is out of bounds for dimension 1 with size 5");

  // Both indices are out of range; the expected message is for dimension 0.
  ASSERT_THROWS_WITH(
      cube.index({4, 5}),
      "index 4 is out of bounds for dimension 0 with size 2");

  // Out of range in dimension 2, reached through integers or full slices.
  ASSERT_THROWS_WITH(
      cube.index({0, 1, 15}),
      "index 15 is out of bounds for dimension 2 with size 10");
  ASSERT_THROWS_WITH(
      cube.index({Slice(), Slice(), 12}),
      "index 12 is out of bounds for dimension 2 with size 10");
}
705 
// Indexing a 0-dim tensor with an integer must raise an error whose message
// contains "invalid index".
TEST(TensorIndexingTest, TestZeroDimIndex) {
  auto x = torch::tensor(10);

  // The index expression is evaluated directly; no debug printing of the
  // result is needed to trigger the error.
  ASSERT_THROWS_WITH(x.index({0}), "invalid index");
}
716 
717 // The tests below are from NumPy test_indexing.py with some modifications to
718 // make them compatible with libtorch. It's licensed under the BSD license
719 // below:
720 //
721 // Copyright (c) 2005-2017, NumPy Developers.
722 // All rights reserved.
723 //
724 // Redistribution and use in source and binary forms, with or without
725 // modification, are permitted provided that the following conditions are
726 // met:
727 //
728 //     * Redistributions of source code must retain the above copyright
729 //        notice, this list of conditions and the following disclaimer.
730 //
731 //     * Redistributions in binary form must reproduce the above
732 //        copyright notice, this list of conditions and the following
733 //        disclaimer in the documentation and/or other materials provided
734 //        with the distribution.
735 //
736 //     * Neither the name of the NumPy Developers nor the names of any
737 //        contributors may be used to endorse or promote products derived
738 //        from this software without specific prior written permission.
739 //
740 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
741 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
742 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
743 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
744 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
745 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
746 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
747 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
748 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
749 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
750 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
751 
// Indexing with `None` adds exactly one dimension (NumPy's newaxis).
TEST(NumpyTests, TestNoneIndex) {
  auto vec = torch::tensor({1, 2, 3});
  auto expanded = vec.index({None});
  ASSERT_EQ(expanded.dim(), vec.dim() + 1);
}
757 
// Indexing with an empty integer tensor yields an empty result; indexing with
// an empty floating-point tensor throws.
TEST(NumpyTests, TestEmptyFancyIndex) {
  // Empty list index creates an empty array
  auto a = torch::tensor({1, 2, 3});
  assert_tensor_equal(
      a.index({torch::tensor({}, torch::kLong)}), torch::tensor({}));

  // An empty tensor converted to long is a valid index. `b` is used here
  // rather than a second inline tensor so the converted value is exercised.
  auto b = torch::tensor({}).to(torch::kLong);
  assert_tensor_equal(a.index({b}), torch::tensor({}, torch::kLong));

  // An empty float tensor is not a valid index.
  b = torch::tensor({}).to(torch::kFloat);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(a.index({b}), c10::Error);
}
773 
TEST(NumpyTests, TestEllipsisIndex) {
  const auto mat = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});

  // A bare ellipsis gives a different tensor object with equal contents...
  const auto whole = mat.index({"..."});
  ASSERT_FALSE(whole.is_same(mat));
  assert_tensor_equal(whole, mat);
  // ...whose data pointer matches the original's.
  // (`a[...]` was `a` in numpy <1.9.)
  ASSERT_EQ(whole.data_ptr(), mat.data_ptr());

  // An ellipsis stands in for any number of skipped dimensions.
  const auto first_row = mat.index({0, "..."});
  assert_tensor_equal(first_row, mat.index({0}));
  assert_tensor_equal(first_row, mat.index({0, Slice()}));
  const auto first_col = mat.index({"...", 0});
  assert_tensor_equal(first_col, mat.index({Slice(), 0}));

  // NumPy returns a 0-dim array when slicing with an ellipsis; PyTorch does
  // not distinguish 0-dim arrays from scalars.
  const auto single = mat.index({0, "...", 1});
  assert_tensor_equal(single, torch::tensor(2));

  // Assigning through `Ellipsis` on a 0-d tensor.
  auto scalar = torch::tensor(1);
  scalar.index_put_({Ellipsis}, 2);
  ASSERT_EQ(scalar.item<int64_t>(), 2);
}
796 
TEST(NumpyTests, TestSingleIntIndex) {
  // One integer index picks out a single row.
  const auto mat = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});

  const auto first_row = mat.index({0});
  assert_tensor_equal(first_row, torch::tensor({1, 2, 3}));
  const auto last_row = mat.index({-1});
  assert_tensor_equal(last_row, torch::tensor({7, 8, 9}));

  // An out-of-bounds index raises an error.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(mat.index({1 << 30}), c10::Error);
  // NOTE: Per the standard
  // (http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0543r0.html),
  // evaluating a signed-integer expression whose result is not mathematically
  // defined or not representable in its type is undefined behavior. The index
  // overflow case therefore cannot be checked, since it might not throw.
  // ASSERT_THROW(a(1 << 64), c10::Error);
}
815 
TEST(NumpyTests, TestSingleBoolIndex) {
  // A lone boolean index is compared against the `None`-indexed tensor:
  // `true` matches it as is, `false` matches its empty leading slice.
  const auto mat = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  const auto with_axis = mat.index({None});

  assert_tensor_equal(mat.index({true}), with_axis);
  assert_tensor_equal(mat.index({false}), with_axis.index({Slice(0, 0)}));
}
823 
TEST(NumpyTests, TestBooleanShapeMismatch) {
  const auto arr = torch::ones({5, 4, 3});

  // A length-1 mask against a leading dimension of 5.
  const auto short_mask = torch::tensor({true});
  ASSERT_THROWS_WITH(arr.index({short_mask}), "mask");

  // A length-6 mask against a leading dimension of 5.
  const auto long_mask =
      torch::tensor({false, false, false, false, false, false});
  ASSERT_THROWS_WITH(arr.index({long_mask}), "mask");

  {
    WarningCapture warnings;

    // A 4x4 uint8 mask: both lookups must fail with a "mask" error, and each
    // one is expected to emit the uint8 deprecation warning.
    const auto byte_mask = torch::empty({4, 4}, torch::kByte).zero_();
    ASSERT_THROWS_WITH(arr.index({byte_mask}), "mask");
    ASSERT_THROWS_WITH(arr.index({Slice(), byte_mask}), "mask");

    const auto deprecation_count = count_substr_occurrences(
        warnings.str(), "indexing with dtype torch.uint8 is now deprecated");
    ASSERT_EQ(deprecation_count, 2);
  }
}
847 
TEST(NumpyTests, TestBooleanIndexingOnedim) {
  // A 2-dimensional tensor indexed by a boolean tensor of length one.
  auto row = torch::tensor({{0., 0., 0.}});
  const auto mask = torch::tensor({true});

  const auto selected = row.index({mask});
  assert_tensor_equal(selected, row);

  // Assigning through the same mask overwrites every element.
  row.index_put_({mask}, 1.);
  const auto all_ones = torch::tensor({{1., 1., 1.}});
  assert_tensor_equal(row, all_ones);
}
858 
TEST(NumpyTests, TestBooleanAssignmentValueMismatch) {
  // Boolean-mask assignment must fail when the values cannot be broadcast to
  // the selected elements. (see also gh-3458)
  const auto base = torch::arange(0, 4);

  // Writes `values` into every element of `target` greater than -1.
  auto assign_masked = [](torch::Tensor target, std::vector<int64_t> values) {
    const auto mask = target > -1;
    target.index_put_({mask}, torch::tensor(values));
  };

  ASSERT_THROWS_WITH(assign_masked(base, {}), "shape mismatch");
  ASSERT_THROWS_WITH(assign_masked(base, {1, 2, 3}), "shape mismatch");

  const auto head = base.index({Slice(None, 1)});
  ASSERT_THROWS_WITH(assign_masked(head, {1, 2, 3}), "shape mismatch");
}
872 
TEST(NumpyTests, TestBooleanIndexingTwodim) {
  // A 2-dimensional tensor indexed by a 2-dimensional boolean tensor.
  auto mat = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  const auto mask = torch::tensor(
      {{true, false, true}, {false, true, false}, {true, false, true}});

  // The full mask picks the flagged elements.
  assert_tensor_equal(mat.index({mask}), torch::tensor({1, 3, 5, 7, 9}));

  // Single mask rows used as 1-d boolean indices.
  const auto mask_row0 = mask.index({0});
  const auto mask_row1 = mask.index({1});
  const auto mask_row2 = mask.index({2});
  assert_tensor_equal(mat.index({mask_row1}), torch::tensor({{4, 5, 6}}));
  assert_tensor_equal(mat.index({mask_row0}), mat.index({mask_row2}));

  // Assigning through the mask zeroes the flagged elements.
  mat.index_put_({mask}, 0);
  const auto expected = torch::tensor({{0, 2, 0}, {4, 0, 6}, {0, 8, 0}});
  assert_tensor_equal(mat, expected);
}
887 
TEST(NumpyTests, TestBooleanIndexingWeirdness) {
  // Odd corner cases of boolean scalars used as indices.
  const auto cube = torch::ones({2, 3, 4});

  // `false, true, ...` gives an empty leading dimension.
  const auto emptied = cube.index({false, true, "..."});
  ASSERT_EQ(emptied.sizes(), torch::IntArrayRef({0, 2, 3, 4}));

  // Boolean scalars mixed in among integer tensor indices.
  const auto mixed = cube.index(
      {true,
       torch::tensor({0, 1}),
       true,
       true,
       torch::tensor({1}),
       torch::tensor({{2}})});
  assert_tensor_equal(torch::ones({1, 2}), mixed);

  // `false` combined with a two-element tensor index must throw.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(cube.index({false, torch::tensor({0, 1}), "..."}), c10::Error);
}
905 
TEST(NumpyTests, TestBooleanIndexingWeirdnessTensors) {
  // Odd corner cases again, this time with 0-d boolean tensors as indices.
  const auto bool_false = torch::tensor(false);
  const auto bool_true = torch::tensor(true);
  const auto cube = torch::ones({2, 3, 4});

  // Literal `false, true, ...` gives an empty leading dimension.
  const auto emptied = cube.index({false, true, "..."});
  ASSERT_EQ(emptied.sizes(), torch::IntArrayRef({0, 2, 3, 4}));

  // Boolean tensors mixed in among integer tensor indices.
  const auto mixed = cube.index(
      {bool_true,
       torch::tensor({0, 1}),
       bool_true,
       bool_true,
       torch::tensor({1}),
       torch::tensor({{2}})});
  assert_tensor_equal(torch::ones({1, 2}), mixed);

  // A false boolean tensor combined with a two-element index must throw.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(
      cube.index({bool_false, torch::tensor({0, 1}), "..."}), c10::Error);
}
926 
TEST(NumpyTests, TestBooleanIndexingAlldims) {
  // Two `true` indices on a 2x3 tensor give shape {1, 2, 3}, whether they are
  // literals or 0-d boolean tensors.
  const auto bool_true = torch::tensor(true);
  const auto mat = torch::ones({2, 3});

  const auto via_literals = mat.index({true, true});
  ASSERT_EQ(via_literals.sizes(), torch::IntArrayRef({1, 2, 3}));

  const auto via_tensors = mat.index({bool_true, bool_true});
  ASSERT_EQ(via_tensors.sizes(), torch::IntArrayRef({1, 2, 3}));
}
935 
TEST(NumpyTests, TestBooleanListIndexing) {
  // A 2-dimensional tensor indexed by 1-d boolean tensors.
  const auto mat = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  const auto first_only = torch::tensor({true, false, false});
  const auto first_two = torch::tensor({true, true, false});

  // One mask selects rows; the same mask on both axes selects the elements
  // where both flags are set.
  assert_tensor_equal(mat.index({first_only}), torch::tensor({{1, 2, 3}}));
  assert_tensor_equal(mat.index({first_only, first_only}), torch::tensor({1}));
  assert_tensor_equal(
      mat.index({first_two}), torch::tensor({{1, 2, 3}, {4, 5, 6}}));
  assert_tensor_equal(mat.index({first_two, first_two}), torch::tensor({1, 5}));
}
947 
TEST(NumpyTests, TestEverythingReturnsViews) {
  // Previously `...` would hand back the tensor itself; now neither an
  // ellipsis nor a full slice may be the same tensor object as the source.
  const auto src = torch::tensor({5});

  const auto via_ellipsis = src.index({"..."});
  ASSERT_FALSE(src.is_same(via_ellipsis));

  const auto via_slice = src.index({Slice()});
  ASSERT_FALSE(src.is_same(via_slice));
}
955 
TEST(NumpyTests, TestBroaderrorsIndexing) {
  // Index tensors of lengths 2 and 3 must be rejected for both reads and
  // writes.
  auto grid = torch::zeros({5, 5});
  const auto two_idx = torch::tensor({0, 1});
  const auto three_idx = torch::tensor({0, 1, 2});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(grid.index({two_idx, three_idx}), c10::Error);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(grid.index_put_({two_idx, three_idx}, 0), c10::Error);
}
966 
TEST(NumpyTests, TestTrivialFancyOutOfBounds) {
  auto target = torch::zeros({5});

  // Index tensor of ones whose last entry (10) exceeds the length of 5.
  auto bad_tail = torch::ones({20}, torch::kInt64);
  bad_tail.index_put_({-1}, 10);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(target.index({bad_tail}), c10::Error);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(target.index_put_({bad_tail}, 0), c10::Error);

  // Index tensor of ones whose first entry (11) exceeds the length of 5.
  auto bad_head = torch::ones({20}, torch::kInt64);
  bad_head.index_put_({0}, 11);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(target.index({bad_head}), c10::Error);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(target.index_put_({bad_head}, 0), c10::Error);
}
982 
TEST(NumpyTests, TestIndexIsLarger) {
  // Basic fancy-index broadcasting: a 3x1 index paired with a length-3 index
  // addresses a 3x3 region that receives a length-3 value tensor.
  auto grid = torch::zeros({5, 5});
  const auto row_idx = torch::tensor({{0}, {1}, {2}});
  const auto col_idx = torch::tensor({0, 1, 2});
  grid.index_put_({row_idx, col_idx}, torch::tensor({2., 3., 4.}));

  // Every row of the top-left 3x3 block must equal the assigned values.
  const auto top_left = grid.index({Slice(None, 3), Slice(None, 3)});
  const auto matches = (top_left == torch::tensor({2., 3., 4.}));
  ASSERT_TRUE(matches.all().item<bool>());
}
995 
TEST(NumpyTests, TestBroadcastSubspace) {
  // Write a column of values (0..99 with a trailing axis added) into the rows
  // picked by a descending index tensor (99..0).
  auto grid = torch::zeros({100, 100});
  const auto column = torch::arange(0., 100).index({Slice(), None});
  const auto reversed_rows = torch::arange(99, -1, -1).to(torch::kLong);
  grid.index_put_({reversed_rows}, column);

  // Each row is expected to be filled with its own reversed index value.
  const auto as_double = reversed_rows.to(torch::kDouble);
  const auto expected = as_double.unsqueeze(1).expand({100, 100});
  assert_tensor_equal(grid, expected);
}
1004