1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ 16 #define TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ 17 18 #include "tensorflow/core/lib/core/status.h" 19 #include "tensorflow/core/lib/gtl/inlined_vector.h" 20 21 namespace tensorflow { 22 23 using Labels = gtl::InlinedVector<int, 8>; 24 using OperandLabels = gtl::InlinedVector<Labels, 2>; 25 using LabelCounts = gtl::InlinedVector<int, 8>; 26 using OperandLabelCounts = gtl::InlinedVector<LabelCounts, 2>; 27 28 // Dummy axis label used to denote an ellipsis in an input or output subscript. 29 constexpr int kEllipsisLabel = -1; 30 31 // Each dimension is categorized into exactly one of five types based on 32 // whether its corresponding label is present in the input and/or the output 33 // subscripts. 34 enum EinsumDimensionType { 35 // Batch dimensions are those present in two inputs as well as the output. 36 // They are part of the batch dimensions during Tensor contraction. Such 37 // dimensions may be broadcasting dimensions (those mapping to ellipsis) 38 // or explicit batch dimensions corresponding to named axis labels. 39 kBroadcasting = 0, 40 kBatch = 1, 41 // Free dimensions are present in exactly one of the inputs, and also the 42 // output. These are non-contracted axes in the Tensor contraction. 43 kFree = 2, 44 // Contract dimensions are present in two inputs, but not the output. These 45 // dimensions are contracted in Tensor contraction. 46 kContract = 3, 47 // Reduce dimensions are present in exactly one input; and not in the output 48 // and are summed over prior to Tensor contraction. 49 kReduce = 4, 50 }; 51 52 // Parses and validates an einsum equation in explicit form. 53 Status ValidateEinsumEquation(const string& equation, 54 gtl::InlinedVector<string, 2>* input_subscripts, 55 string* output_subscript); 56 57 // Parses and validates the equation and the input shapes. Single character 58 // labels are integerized and we populate input and output label subscripts 59 // and corresponding counts. Also create the mapping from (named) labels to 60 // their EinsumDimensionType. 61 Status ParseEinsumEquation(const string& equation, OperandLabels* input_labels, 62 Labels* output_labels, 63 std::vector<EinsumDimensionType>* label_types, 64 OperandLabelCounts* input_label_counts, 65 LabelCounts* output_label_counts, 66 gtl::InlinedVector<bool, 2>* input_has_ellipsis, 67 bool* output_has_ellipsis); 68 69 } // namespace tensorflow 70 71 #endif // TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ 72