/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_split.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/einsum_op.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/einsum_op_util.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/reduction_ops_common_gpu.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

using ShapeVec = gtl::InlinedVector<int64, 8>;
using Labels = gtl::InlinedVector<int, 8>;
using OperandLabels = gtl::InlinedVector<Labels, 2>;
using LabelCounts = gtl::InlinedVector<int, 8>;
using OperandLabelCounts = gtl::InlinedVector<LabelCounts, 2>;
using LabelToDimSizes = gtl::InlinedVector<int64, 8>;

// Dummy axis label used to denote an ellipsis in an input or output subscript.
constexpr int kEllipsisLabel = -1;

struct EinsumHelper {
  // Each dimension is categorized into exactly one of five types based on
  // whether its corresponding label is present in the input and/or the output
  // subscripts.
  enum DimensionType {
    // Batch dimensions are those present in two inputs as well as the output.
    // They are part of the batch dimensions during Tensor contraction. Such
    // dimensions may be broadcasting dimensions (those mapping to ellipsis)
    // or explicit batch dimensions corresponding to named axis labels.
    kBroadcasting = 0,
    kBatch = 1,
    // Free dimensions are present in exactly one of the inputs, and also the
    // output. These are non-contracted axes in the Tensor contraction.
    kFree = 2,
    // Contract dimensions are present in two inputs, but not the output. These
    // dimensions are contracted in Tensor contraction.
    kContract = 3,
    // Reduce dimensions are present in exactly one input, but not the output,
    // and are summed over prior to Tensor contraction.
    kReduce = 4,
  };

  // Returns the DimensionType given whether the corresponding label is present
  // in exactly one input subscript (is_unique) and whether it is absent from
  // the output subscripts (is_removed). Does not handle broadcasting
  // dimensions.
  static DimensionType GetDimensionType(bool is_removed, bool is_unique) {
    if (!is_removed && !is_unique)
      return kBatch;
    else if (!is_removed && is_unique)
      return kFree;
    else if (is_removed && !is_unique)
      return kContract;
    else  // is_removed && is_unique
      return kReduce;
  }

  // Maps the character labels to consecutive integers.
  static void MapToLabels(const string& subscript, Labels* labels,
                          absl::flat_hash_map<char, int>* label_mapping) {
    for (int i = 0; i < subscript.size(); ++i) {
      const char label_char = subscript[i];
      if (label_char == '.') {
        labels->push_back(kEllipsisLabel);
        i += 2;  // Skip next 2 characters as well.
        continue;
      }
      if (!label_mapping->contains(label_char)) {
        const int next_label = label_mapping->size();
        (*label_mapping)[label_char] = next_label;
      }
      const int mapped_label = (*label_mapping)[label_char];
      labels->push_back(mapped_label);
    }
  }

  // Parses and validates the equation and the input shapes. Single character
  // labels are integerized and we populate input and output label subscripts
  // and corresponding counts. Also create the mapping from (named) labels to
  // their DimensionType.
  static Status ParseEquation(const string& equation,
                              OperandLabels* input_labels,
                              Labels* output_labels,
                              std::vector<DimensionType>* label_types,
                              OperandLabelCounts* input_label_counts,
                              LabelCounts* output_label_counts,
                              gtl::InlinedVector<bool, 2>* input_has_ellipsis,
                              bool* output_has_ellipsis) {
    gtl::InlinedVector<string, 2> input_str;
    string output_str;
    TF_RETURN_IF_ERROR(ParseEinsumEquation(equation, &input_str, &output_str));

    // Temporary map from single character labels to (consecutive) integer
    // labels.
    absl::flat_hash_map<char, int> label_mapping;
    int num_inputs = input_str.size();
    input_labels->resize(num_inputs);

    // Map from single characters to integer labels.
    for (int i = 0; i < num_inputs; ++i) {
      MapToLabels(input_str[i], &input_labels->at(i), &label_mapping);
    }
    MapToLabels(output_str, output_labels, &label_mapping);
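    // For example, for the equation "ij,jk->ik" the mapping at this point is
    // {'i' -> 0, 'j' -> 1, 'k' -> 2}, input_labels is {{0, 1}, {1, 2}} and
    // output_labels is {0, 2}.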

    // Compute counts for input and output labels.
    int num_labels = label_mapping.size();
    input_label_counts->resize(num_inputs);
    input_has_ellipsis->resize(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      input_label_counts->at(i).resize(num_labels);
      for (const int label : input_labels->at(i)) {
        if (label != kEllipsisLabel)
          input_label_counts->at(i)[label] += 1;
        else
          input_has_ellipsis->at(i) = true;
      }
    }
    output_label_counts->resize(num_labels);
    for (const int label : *output_labels) {
      if (label != kEllipsisLabel)
        output_label_counts->at(label) += 1;
      else
        *output_has_ellipsis = true;
    }

    // Map each label to a unique DimensionType.
    label_types->resize(num_labels);
    for (int label = 0; label < num_labels; ++label) {
      if (label == kEllipsisLabel) continue;
      bool removed = (*output_label_counts)[label] == 0;
      bool unique = num_inputs == 1 || (*input_label_counts)[0][label] == 0 ||
                    (*input_label_counts)[1][label] == 0;
      (*label_types)[label] = GetDimensionType(removed, unique);
    }
    return Status::OK();
  }

  // Insert new (unnamed) broadcasting labels at the location of ellipsis.
  static void InsertBroadcastLabels(int num_bcast_dims, int num_named_labels,
                                    int ellipsis_axis, Labels* labels,
                                    LabelCounts* label_counts) {
    labels->erase(labels->begin() + ellipsis_axis);
    labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0);
    std::iota(labels->begin() + ellipsis_axis,
              labels->begin() + ellipsis_axis + num_bcast_dims,
              num_named_labels);
    // Increment label counts. Since these are new labels, the count is set
    // to 1.
    label_counts->resize(num_named_labels + num_bcast_dims, 1);
  }

  // Record and validate the label to dimension mapping. Must be a named
  // (non-broadcasting) label as broadcasting labels don't have a fixed
  // dimension.
  static Status RecordLabelToDimension(const int label, const int axis,
                                       const Tensor& input,
                                       LabelToDimSizes* label_to_dim_sizes) {
    const int64 input_dim = input.dim_size(axis);
    // We know that label_to_dim_sizes has the size to accommodate named
    // labels.
    if (label_to_dim_sizes->at(label) != 0 &&
        label_to_dim_sizes->at(label) != input_dim) {
      return errors::InvalidArgument(
          "Expected dimension ", label_to_dim_sizes->at(label), " at axis ",
          axis, " of the input shaped ", input.shape().DebugString(),
          " but got dimension ", input_dim);
    }
    (*label_to_dim_sizes)[label] = input_dim;
    return Status::OK();
  }

  // Validate input dimensions and populate unnamed labels and their label
  // counts.
  static Status ProcessDimensions(
      const OpInputList& inputs,
      const gtl::InlinedVector<bool, 2>& input_has_ellipsis,
      const bool output_has_ellipsis, OperandLabels* input_labels,
      Labels* output_labels, std::vector<DimensionType>* label_types,
      OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts,
      LabelToDimSizes* label_to_dim_sizes) {
    if (inputs.size() != input_labels->size()) {
      return errors::InvalidArgument("Expected ", input_labels->size(),
                                     " inputs but got: ", inputs.size());
    }
    const int num_inputs = inputs.size();

    // We infer the number of broadcasting dimensions by taking the maximum
    // rank among the broadcasting subshapes of the input.
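    // For example, for the equation "...ij,...jk->...ik" with operand shapes
    // [2, 3, 5, 7] and [7, 11], the first operand has two broadcasting
    // dimensions and the second has none, so max_bcast_dims ends up as 2.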
    int max_bcast_dims = 0;
    const int num_named_labels = label_types->size();
    label_to_dim_sizes->resize(num_named_labels);
    for (int i = 0; i < num_inputs; ++i) {
      Labels* labels = &(*input_labels)[i];

      if (!input_has_ellipsis[i]) {
        if (inputs[i].dims() != labels->size()) {
          return errors::InvalidArgument("Expected input ", i, " to have rank ",
                                         labels->size(),
                                         " but got: ", inputs[i].dims());
        }
        for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
          const int label = (*labels)[label_idx];
          TF_RETURN_IF_ERROR(RecordLabelToDimension(label, label_idx, inputs[i],
                                                    label_to_dim_sizes));
        }
        continue;
      }

      // Input has an ellipsis.
      if (inputs[i].dims() + 1 < labels->size()) {
        return errors::InvalidArgument(
            "Expected input ", i, " to have rank at least ",
            labels->size() - 1, " but got: ", inputs[i].dims());
      }
      int ellipsis_axis = -1;
      const int num_bcast_dims = inputs[i].dims() - labels->size() + 1;
      for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
        const int label = (*labels)[label_idx];
        if (label == kEllipsisLabel) {
          ellipsis_axis = label_idx;
          continue;
        }
        // Current label is not an ellipsis.
        const int axis =
            label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1);
        TF_RETURN_IF_ERROR(
            RecordLabelToDimension(label, axis, inputs[i], label_to_dim_sizes));
      }
      // Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting
      // dimensions.
      if (ellipsis_axis != -1) {
        InsertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis,
                              labels, &input_label_counts->at(i));
        max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims);
      }
    }
    if (!absl::c_linear_search(input_has_ellipsis, true) &&
        !output_has_ellipsis) {
      return Status::OK();
    }
    // Insert broadcasting dimensions in the output labels.
    auto it =
        std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel);
    if (it != output_labels->end()) {
      const int ellipsis_axis = it - output_labels->begin();
      InsertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis,
                            output_labels, output_label_counts);
    } else if (max_bcast_dims > 0) {
      return errors::InvalidArgument(
          "Output contains ", max_bcast_dims,
          " broadcasting dimension(s) but no ellipsis "
          "(...) was found in the output subscripts.");
    }
    // Populate DimensionType for the new broadcasting labels.
    label_types->resize(num_named_labels + max_bcast_dims, kBroadcasting);
    return Status::OK();
  }

  // Permutes the labels according to the given permutation. For example,
  // permutation {2, 0, 1} maps labels {a, b, c} to {c, a, b}.
  static void PermuteLabels(const std::vector<int>& permutation,
                            Labels* labels) {
    Labels permuted_labels(labels->size());
    for (int i = 0; i < labels->size(); ++i) {
      permuted_labels[i] = (*labels)[permutation[i]];
    }
    labels->swap(permuted_labels);
  }

  // Returns a reshaped input Tensor. The underlying buffer is not copied.
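  // Tensor::CopyFrom only succeeds when the requested shape has the same
  // number of elements as the input; otherwise the Internal error below is
  // returned.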
  static Status CopyFrom(const Tensor& input, const TensorShape& shape,
                         Tensor* output) {
    if (output->CopyFrom(input, shape)) return Status::OK();
    return errors::Internal(
        "Encountered error while reshaping a Tensor of shape ",
        input.shape().DebugString(), " to shape ", shape.DebugString());
  }

  // Returns whether a transpose is needed, i.e. the input has rank >= 2 and
  // the permutation is not the identity permutation.
  static bool ShouldTranspose(const TensorShape& input_shape,
                              const std::vector<int>& permutation) {
    if (input_shape.dims() < 2) return false;
    for (int i = 0; i < permutation.size(); ++i) {
      if (permutation[i] != i) return true;
    }
    return false;
  }

  // Transpose the input given a permutation. Returns a reference to the input
  // if transposing is not necessary.
  template <typename Device, typename T>
  static Status TransposeOperand(OpKernelContext* ctx, const Tensor& input,
                                 const std::vector<int>& permutation,
                                 Tensor* output) {
    if (!ShouldTranspose(input.shape(), permutation)) {
      return CopyFrom(input, input.shape(), output);
    }
    TensorShape transposed_shape;
    for (int i = 0; i < input.dims(); ++i) {
      transposed_shape.AddDim(input.dim_size(permutation[i]));
    }
    // For empty Tensors, just change the shape. E.g. we may need to transpose
    // from shape [1, 0, 5] to [5, 1, 0].
    if (input.NumElements() == 0) {
      return CopyFrom(input, transposed_shape, output);
    }
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                          transposed_shape, output));
    const Device& device = ctx->eigen_device<Device>();
    TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output));
    return Status::OK();
  }

  // If there are repeated labels in either the input or output, then this
  // strides the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively.
  template <typename Device, typename T>
  static Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input,
                                const Labels& labels,
                                const LabelCounts& label_counts,
                                const bool should_inflate, Tensor* output) {
    // Return early if there are no repeated indices.
    if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) {
      return CopyFrom(input, input.shape(), output);
    }
    // We reshape so that each repeated label is compressed to one dimension.
    // E.g. for iiij -> ij, the shape [3, 3, 3, 5] would be compressed to
    // [27, 5]. Striding appropriately (in this case with strides 13 (=1+3+9)
    // and 1) recovers the generalized diagonal of shape [3, 5].
    ShapeVec reshape;
    ShapeVec strides;
    // Strided and inflated shapes correspond to input and output shapes,
    // respectively, if should_inflate is true (vice-versa if should_inflate is
    // false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example.
    ShapeVec strided_shape;
    ShapeVec inflated_shape;
    for (int label : labels) {
      const int count = label_counts[label];
      const int current_axis =
          should_inflate ? strided_shape.size() : inflated_shape.size();
      const int64 dim = input.dim_size(current_axis);
      strided_shape.push_back(dim);
      inflated_shape.insert(inflated_shape.end(), count, dim);
      const int64 reshape_dim = MathUtil::IPow(dim, count);
      reshape.push_back(reshape_dim);
      // While taking the d-diagonal in a rank k Tensor, we take d
      // equally-spaced elements including the first and last element. Then,
      // (d - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1).
      const int64 stride =
          (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1;
      strides.push_back(stride);
    }

    TensorShape output_shape =
        TensorShape(should_inflate ? inflated_shape : strided_shape);
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
    const Device& device = ctx->eigen_device<Device>();
    switch (reshape.size()) {
#define NDIMS_CASE(N)                                                 \
  case N: {                                                           \
    if (should_inflate) {                                             \
      auto output_map = output->shaped<T, N>(reshape);                \
      auto input_map = input.shaped<T, N>(strided_shape);             \
      functor::InflateFunctor<Device, T, N>()(                        \
          device, input_map, TensorShape(strides).AsEigenDSizes<N>(), \
          output_map);                                                \
    } else {                                                          \
      auto input_map = input.shaped<T, N>(reshape);                   \
      auto output_map = output->shaped<T, N>(strided_shape);          \
      functor::StrideFunctor<Device, T, N>()(                         \
          device, input_map, TensorShape(strides).AsEigenDSizes<N>(), \
          output_map);                                                \
    }                                                                 \
  } break;
      NDIMS_CASE(1);
      NDIMS_CASE(2);
      NDIMS_CASE(3);
      NDIMS_CASE(4);
      NDIMS_CASE(5);
      NDIMS_CASE(6);
      default:
        return errors::Unimplemented(
            "Unsupported rank: ", reshape.size(),
            " while handling repeated indices. Up to rank 6 is supported.");
#undef NDIMS_CASE
    }
    return Status::OK();
  }

  // Returns true if the input dimensions are already sorted in the order
  // [batch, contract, free, reduce]. Used to implement an optimization that
  // avoids an extra transpose and instead uses (adj_x and adj_y) in
  // BatchMatMul.
  static bool ShouldSwapFreeAndContract(
      const Labels& labels, const std::vector<DimensionType>& label_types) {
    // Check that ordering is according to dimension type, with the role of
    // free and contract dimensions swapped.
    gtl::InlinedVector<int, 5> remap = {0, 1, 3, 2, 4};
    for (int i = 0; i + 1 < labels.size(); ++i) {
      const int dimtype_a = remap[label_types[labels[i]]];
      const int dimtype_b = remap[label_types[labels[i + 1]]];
      if (dimtype_a > dimtype_b ||
          (dimtype_a == dimtype_b && labels[i] > labels[i + 1])) {
        return false;
      }
    }
    return true;
  }

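  // Transposes (if necessary), dedupes repeated labels and sum-reduces a
  // single operand so that its dimensions are laid out as
  // [broadcasting + batch shape, free size, contract size] (with the last two
  // swapped when *swap_free_and_contract is set), which is the form expected
  // by ContractOperands below.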
  template <typename Device, typename T>
  static Status ReduceOperand(OpKernelContext* ctx, const Tensor& input,
                              const std::vector<DimensionType>& label_types,
                              const LabelCounts& label_counts, Labels* labels,
                              Labels* free_labels,
                              bool* swap_free_and_contract, Tensor* output) {
    // Find the permutation to transpose the input dimensions in the order of
    // DimensionType; i.e. batch, free, contract and reduce dimensions. This
    // makes it more convenient to invoke Reduce/Contract operations.
    std::vector<int> permutation(input.dims());
    absl::c_iota(permutation, 0);
    Tensor input_transposed;
    // Check if we can avoid the transpose. We need to flip the adj_x (or
    // adj_y) flag during BatchMatMul. This is an extra optimization not
    // necessary for correctness.
    if (ShouldSwapFreeAndContract(*labels, label_types)) {
      *swap_free_and_contract = true;
    } else {
      absl::c_sort(permutation, [&](int i, int j) {
        int label_i = (*labels)[i];
        int label_j = (*labels)[j];
        return std::tie(label_types[label_i], label_i) <
               std::tie(label_types[label_j], label_j);
      });
    }
    // Transpose the input so that DimensionTypes are in order.
    TF_RETURN_IF_ERROR(TransposeOperand<Device, T>(ctx, input, permutation,
                                                   &input_transposed));
    PermuteLabels(permutation, labels);

    // Take the generalized diagonal for dimensions with repeated axis labels.
    Tensor input_deduped;
    labels->erase(std::unique(labels->begin(), labels->end()), labels->end());
    TF_RETURN_IF_ERROR(StrideOrInflate<Device, T>(
        ctx, input_transposed, *labels, label_counts,
        false /* should_inflate */, &input_deduped));

    // Reshape denotes the rank-5 shape [broadcast, batch, free, contract,
    // reduce] where we've compacted the dimensions of each DimensionType.
    gtl::InlinedVector<int64, 5> reshape(5, 1);
    // The output shape is [batch shape] + [free size, contract size].
    // That is, the batch shape is preserved (for broadcasting while
    // contracting) while the free dims and contract dims are compressed to one
    // dimension each.
    TensorShape output_shape;
    for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
      const int label = labels->at(label_idx);
      int64 dim = input_deduped.dim_size(label_idx);
      if (label_types[label] == kBroadcasting || label_types[label] == kBatch) {
        output_shape.AddDim(dim);
      } else if (label_types[label] == kFree) {
        free_labels->push_back(label);
      }
      reshape[label_types[label]] *= dim;
    }
    if (*swap_free_and_contract) std::swap(reshape[kFree], reshape[kContract]);
    output_shape.AddDim(reshape[kFree]);
    output_shape.AddDim(reshape[kContract]);

    if (reshape[kReduce] == 1) {  // No need to actually reduce.
      return CopyFrom(input_deduped, output_shape, output);
    }
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
    using Reducer = Eigen::internal::SumReducer<T>;
    using Index = typename TTypes<T>::Tensor::Index;
    // Reduce along the last axis (i.e. axis 1) of the rank-2 Tensor.
    const int64 output_size = reshape[kBroadcasting] * reshape[kBatch] *
                              reshape[kFree] * reshape[kContract];
    functor::ReduceFunctor<Device, Reducer>::Reduce(
        ctx, output->shaped<T, 1>({output_size}),
        const_cast<const Tensor&>(input_deduped)
            .shaped<T, 2>({output_size, reshape[kReduce]}),
        Eigen::array<Index, 1>({1}), Reducer());
    return Status::OK();
  }

  // Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M].
  static Status ReshapeToRank3(const Tensor& input, int batch_size,
                               Tensor* output) {
    const int rank = input.dims();
    TensorShape output_shape = {batch_size, input.dim_size(rank - 2),
                                input.dim_size(rank - 1)};
    return CopyFrom(input, output_shape, output);
  }

  // Contracts the inputs along the last axis (or the second last if the
  // corresponding value of swap_free_and_contract is true). The batch
  // dimensions are broadcast to the output shape.
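  // For example, operands of shape [B, F0, C] and [B, C, F1] (the latter with
  // swap_free_and_contract set) contract to an output of shape [B, F0, F1],
  // where the batch dimension B may itself be broadcast between the two
  // operands.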
  // TODO(anudhyan): BatchMatMul might devolve into a component-wise
  // multiplication when the matrix shape is [1,1]; in this case the
  // BatchMatMul functor would be very inefficient. The functor should detect
  // if this is the case and perform componentwise multiplication instead.
  template <typename Device, typename T>
  static Status ContractOperands(OpKernelContext* ctx,
                                 absl::Span<const Tensor> inputs,
                                 absl::Span<const bool> swap_free_and_contract,
                                 Tensor* output) {
    if (inputs.size() == 1)
      return CopyFrom(inputs[0], inputs[0].shape(), output);
    MatMulBCast bcast(inputs[0].shape().dim_sizes(),
                      inputs[1].shape().dim_sizes());
    if (!bcast.IsValid()) {
      return errors::InvalidArgument(
          "Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(),
          " vs. ", inputs[1].shape().DebugString());
    }
    Tensor lhs;
    TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs));
    Tensor rhs;
    TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs));
    TensorShape output_shape = bcast.output_batch_shape();
    for (int i = 0; i < inputs.size(); ++i) {
      const int64 free_axis =
          inputs[i].dims() - (swap_free_and_contract[i] ? 1 : 2);
      output_shape.AddDim(inputs[i].dim_size(free_axis));
    }
    bool trans_x = swap_free_and_contract[0];
    bool trans_y = !swap_free_and_contract[1];
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
    if (lhs.NumElements() == 0 || rhs.NumElements() == 0) {
      functor::SetZeroFunctor<Device, T> set_zero;
      set_zero(ctx->eigen_device<Device>(), output->flat<T>());
      return Status::OK();
    }
    Tensor output_reshaped;
    TF_RETURN_IF_ERROR(
        ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
    LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, /*adj_x=*/false,
                                         /*adj_y=*/false, trans_x, trans_y,
                                         bcast, &output_reshaped);
    return Status::OK();
  }
};

template <typename Device, typename T>
class EinsumOp : public OpKernel {
 public:
  explicit EinsumOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("equation", &equation_));
    OP_REQUIRES_OK(
        c, EinsumHelper::ParseEquation(
               equation_, &input_labels_, &output_labels_, &label_types_,
               &input_label_counts_, &output_label_counts_,
               &input_has_ellipsis_, &output_has_ellipsis_));
  }

  void Compute(OpKernelContext* ctx) override {
    OpInputList inputs;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs));

    OperandLabels input_labels(input_labels_);
    Labels output_labels(output_labels_);
    std::vector<EinsumHelper::DimensionType> label_types(label_types_);
    OperandLabelCounts input_label_counts(input_label_counts_);
    LabelCounts output_label_counts(output_label_counts_);
    LabelToDimSizes label_to_dim_sizes;

    OP_REQUIRES_OK(ctx, EinsumHelper::ProcessDimensions(
                            inputs, input_has_ellipsis_, output_has_ellipsis_,
                            &input_labels, &output_labels, &label_types,
                            &input_label_counts, &output_label_counts,
                            &label_to_dim_sizes));

    // The reduction phase (a) sums across reduction dimensions, (b) takes
    // generalized diagonals, and (c) reshapes it into shape
    //   [(broadcasting) batch shape] + [F, C]
    // where F and C denote the total (compacted) size of free and contract
    // dimensions, respectively.
    const int num_inputs = inputs.size();
    OperandLabels free_labels(num_inputs);
    gtl::InlinedVector<Tensor, 2> inputs_reduced(num_inputs);
    gtl::InlinedVector<bool, 2> swap_free_and_contract(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      OP_REQUIRES_OK(ctx,
                     EinsumHelper::ReduceOperand<Device, T>(
                         ctx, inputs[i], label_types, input_label_counts[i],
                         &input_labels[i], &free_labels[i],
                         &swap_free_and_contract[i], &inputs_reduced[i]));
    }

    // After reduction, the inputs should be reshaped to Tensors suitable for
    // contraction. If num_inputs is 1, the reduced input is simply forwarded
    // to the output.
    Tensor contraction_output_reshaped;
    OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands<Device, T>(
                            ctx, inputs_reduced, swap_free_and_contract,
                            &contraction_output_reshaped));

    // Copy the batch labels from the contraction output. Recover the batch
    // shape, which may have been broadcast.
    TensorShape result_shape = contraction_output_reshaped.shape();
    result_shape.RemoveLastDims(2);

    int num_labels = label_types.size();
    Labels result_labels;
    // All batch dimensions should be present in the contracted result. First
    // the broadcasting dimensions, then the named batch dimensions.
    for (int label = 0; label < num_labels; ++label) {
      if (label_types[label] == EinsumHelper::kBroadcasting)
        result_labels.push_back(label);
    }
    for (int label = 0; label < num_labels; ++label) {
      if (label_types[label] == EinsumHelper::kBatch)
        result_labels.push_back(label);
    }
    for (int i = 0; i < num_inputs; ++i) {
      for (int label : free_labels[i]) {
        result_labels.push_back(label);
        result_shape.AddDim(label_to_dim_sizes[label]);
      }
    }

    // Reshape the contraction (or reduction) result to its expanded shape:
    //   [(broadcast) batch shape] + [free shape 0] + [free shape 1].
    Tensor contraction_output;
    OP_REQUIRES_OK(
        ctx, EinsumHelper::CopyFrom(contraction_output_reshaped, result_shape,
                                    &contraction_output));

    // Inflate the output if necessary. (E.g. for the equation 'i->iii' which
    // may arise while computing the gradient of a regular Einsum.)
    // TODO(anudhyan): It's possible that Eigen's contract and inflate can be
    // chained here to avoid materializing an intermediate.
    Tensor output_inflated;
    OP_REQUIRES_OK(
        ctx, EinsumHelper::StrideOrInflate<Device, T>(
                 ctx, contraction_output, result_labels, output_label_counts,
                 true /* should_inflate */, &output_inflated));
    if (output_inflated.dims() > contraction_output.dims()) {
      // We inflated the output. Modify result labels accordingly.
      Labels inflated_labels;
      for (int label : result_labels) {
        inflated_labels.insert(inflated_labels.end(),
                               output_label_counts[label], label);
      }
      result_labels.swap(inflated_labels);
    }
    // Find the permutation to map the result labels to the output labels. Note
    // that both the result and the final output may have repeated labels, in
    // which case the permutation preserves the left-to-right ordering.
    // E.g. if the result labels are [0, 0, 1] and the output is [0, 1, 0] then
    // the permutation should be [0, 2, 1]. We also use the fact that repeated
    // labels in the result are adjacent to each other.
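    // Continuing that example: label_to_position records 0 -> 0 and 1 -> 2
    // for the result labels [0, 0, 1]; scanning the output labels [0, 1, 0]
    // then yields the permutation [0, 2, 1].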
    std::vector<int> output_permutation(output_labels.size());
    std::vector<int> label_to_position(num_labels, -1);
    for (int i = 0; i < result_labels.size(); ++i) {
      // Remember the position of only the leftmost result label.
      if (label_to_position[result_labels[i]] == -1) {
        label_to_position[result_labels[i]] = i;
      }
    }
    for (int i = 0; i < output_labels.size(); ++i) {
      output_permutation[i] = label_to_position[output_labels[i]];
      // We have found the leftmost occurrence. The next one would be adjacent.
      label_to_position[output_labels[i]] += 1;
    }
    Tensor output;
    OP_REQUIRES_OK(ctx, EinsumHelper::TransposeOperand<Device, T>(
                            ctx, output_inflated, output_permutation, &output));
    ctx->set_output(0, output);
  }

  string TraceString(const OpKernelContext& ctx, bool verbose) const override {
    string op = profiler::TraceMeOp(name_view(), type_string_view());
    string equation = strings::StrCat("(", equation_, ")");
    if (verbose) {
      string shape = ShapeTraceString(ctx);
      if (!shape.empty()) {
        return profiler::TraceMeEncode(
            std::move(op), {{"equation", equation}, {"shape", shape}});
      }
    }
    return profiler::TraceMeEncode(std::move(op), {{"equation", equation}});
  }

 private:
  string equation_;
  OperandLabels input_labels_;
  Labels output_labels_;
  std::vector<EinsumHelper::DimensionType> label_types_;
  OperandLabelCounts input_label_counts_;
  LabelCounts output_label_counts_;
  gtl::InlinedVector<bool, 2> input_has_ellipsis_;
  bool output_has_ellipsis_ = false;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, N)                                         \
  template <>                                                          \
  void StrideFunctor<GPUDevice, T, N>::operator()(                     \
      const GPUDevice& d, typename TTypes<T, N>::ConstTensor input,    \
      const Eigen::DSizes<Eigen::DenseIndex, N>& strides,              \
      typename TTypes<T, N>::Tensor output);                           \
  extern template struct StrideFunctor<GPUDevice, T, N>;               \
  template <>                                                          \
  void InflateFunctor<GPUDevice, T, N>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, N>::ConstTensor input,    \
      const Eigen::DSizes<Eigen::DenseIndex, N>& strides,              \
      typename TTypes<T, N>::Tensor output);                           \
  extern template struct InflateFunctor<GPUDevice, T, N>;

#define DECLARE_GPU_SPECS(T) \
  DECLARE_GPU_SPEC(T, 1);    \
  DECLARE_GPU_SPEC(T, 2);    \
  DECLARE_GPU_SPEC(T, 3);    \
  DECLARE_GPU_SPEC(T, 4);    \
  DECLARE_GPU_SPEC(T, 5);    \
  DECLARE_GPU_SPEC(T, 6);

DECLARE_GPU_SPECS(Eigen::half);
DECLARE_GPU_SPECS(double);
DECLARE_GPU_SPECS(float);
// TODO(rocm): Enable once complex types are supported.
#if GOOGLE_CUDA
DECLARE_GPU_SPECS(complex64);
DECLARE_GPU_SPECS(complex128);
#endif
#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPECS
}  // namespace functor
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_