1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
// Copyright (C) 2014 Navdeep Jaitly <ndjaitly@google.com> and
5 // Benoit Steiner <benoit.steiner.goog@gmail.com>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11 #include "main.h"
12
13 #include <Eigen/CXX11/Tensor>
14
15 using Eigen::Tensor;
16 using Eigen::array;
17
18 template <int DataLayout>
test_simple_reverse()19 static void test_simple_reverse()
20 {
21 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
22 tensor.setRandom();
23
24 array<bool, 4> dim_rev;
25 dim_rev[0] = false;
26 dim_rev[1] = true;
27 dim_rev[2] = true;
28 dim_rev[3] = false;
29
30 Tensor<float, 4, DataLayout> reversed_tensor;
31 reversed_tensor = tensor.reverse(dim_rev);
32
33 VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2);
34 VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3);
35 VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5);
36 VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7);
37
38 for (int i = 0; i < 2; ++i) {
39 for (int j = 0; j < 3; ++j) {
40 for (int k = 0; k < 5; ++k) {
41 for (int l = 0; l < 7; ++l) {
42 VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(i,2-j,4-k,l));
43 }
44 }
45 }
46 }
47
48 dim_rev[0] = true;
49 dim_rev[1] = false;
50 dim_rev[2] = false;
51 dim_rev[3] = false;
52
53 reversed_tensor = tensor.reverse(dim_rev);
54
55 VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2);
56 VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3);
57 VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5);
58 VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7);
59
60
61 for (int i = 0; i < 2; ++i) {
62 for (int j = 0; j < 3; ++j) {
63 for (int k = 0; k < 5; ++k) {
64 for (int l = 0; l < 7; ++l) {
65 VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(1-i,j,k,l));
66 }
67 }
68 }
69 }
70
71 dim_rev[0] = true;
72 dim_rev[1] = false;
73 dim_rev[2] = false;
74 dim_rev[3] = true;
75
76 reversed_tensor = tensor.reverse(dim_rev);
77
78 VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2);
79 VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3);
80 VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5);
81 VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7);
82
83
84 for (int i = 0; i < 2; ++i) {
85 for (int j = 0; j < 3; ++j) {
86 for (int k = 0; k < 5; ++k) {
87 for (int l = 0; l < 7; ++l) {
88 VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(1-i,j,k,6-l));
89 }
90 }
91 }
92 }
93 }
94
95
96 template <int DataLayout>
test_expr_reverse(bool LValue)97 static void test_expr_reverse(bool LValue)
98 {
99 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
100 tensor.setRandom();
101
102 array<bool, 4> dim_rev;
103 dim_rev[0] = false;
104 dim_rev[1] = true;
105 dim_rev[2] = false;
106 dim_rev[3] = true;
107
108 Tensor<float, 4, DataLayout> expected(2, 3, 5, 7);
109 if (LValue) {
110 expected.reverse(dim_rev) = tensor;
111 } else {
112 expected = tensor.reverse(dim_rev);
113 }
114
115 Tensor<float, 4, DataLayout> result(2,3,5,7);
116
117 array<ptrdiff_t, 4> src_slice_dim;
118 src_slice_dim[0] = 2;
119 src_slice_dim[1] = 3;
120 src_slice_dim[2] = 1;
121 src_slice_dim[3] = 7;
122 array<ptrdiff_t, 4> src_slice_start;
123 src_slice_start[0] = 0;
124 src_slice_start[1] = 0;
125 src_slice_start[2] = 0;
126 src_slice_start[3] = 0;
127 array<ptrdiff_t, 4> dst_slice_dim = src_slice_dim;
128 array<ptrdiff_t, 4> dst_slice_start = src_slice_start;
129
130 for (int i = 0; i < 5; ++i) {
131 if (LValue) {
132 result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
133 tensor.slice(src_slice_start, src_slice_dim);
134 } else {
135 result.slice(dst_slice_start, dst_slice_dim) =
136 tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev);
137 }
138 src_slice_start[2] += 1;
139 dst_slice_start[2] += 1;
140 }
141
142 VERIFY_IS_EQUAL(result.dimension(0), 2);
143 VERIFY_IS_EQUAL(result.dimension(1), 3);
144 VERIFY_IS_EQUAL(result.dimension(2), 5);
145 VERIFY_IS_EQUAL(result.dimension(3), 7);
146
147 for (int i = 0; i < expected.dimension(0); ++i) {
148 for (int j = 0; j < expected.dimension(1); ++j) {
149 for (int k = 0; k < expected.dimension(2); ++k) {
150 for (int l = 0; l < expected.dimension(3); ++l) {
151 VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
152 }
153 }
154 }
155 }
156
157 dst_slice_start[2] = 0;
158 result.setRandom();
159 for (int i = 0; i < 5; ++i) {
160 if (LValue) {
161 result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
162 tensor.slice(dst_slice_start, dst_slice_dim);
163 } else {
164 result.slice(dst_slice_start, dst_slice_dim) =
165 tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim);
166 }
167 dst_slice_start[2] += 1;
168 }
169
170 for (int i = 0; i < expected.dimension(0); ++i) {
171 for (int j = 0; j < expected.dimension(1); ++j) {
172 for (int k = 0; k < expected.dimension(2); ++k) {
173 for (int l = 0; l < expected.dimension(3); ++l) {
174 VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
175 }
176 }
177 }
178 }
179 }
180
181
// Top-level test function for the tensor reverse() operation, presumably
// discovered by the Eigen test driver (main.h) by its test_<name> naming —
// confirm against the build setup.
// Every subtest runs for both ColMajor and RowMajor storage. The expression
// test additionally runs with reverse() on the assignment's left-hand side
// (true) and on its right-hand side (false).
// Note: CALL_SUBTEST stringifies its argument into the test-name stack, so
// that text is what appears in failure reports.
void test_cxx11_tensor_reverse()
{
  CALL_SUBTEST(test_simple_reverse<ColMajor>());
  CALL_SUBTEST(test_simple_reverse<RowMajor>());
  CALL_SUBTEST(test_expr_reverse<ColMajor>(true));
  CALL_SUBTEST(test_expr_reverse<RowMajor>(true));
  CALL_SUBTEST(test_expr_reverse<ColMajor>(false));
  CALL_SUBTEST(test_expr_reverse<RowMajor>(false));
}
191