• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
16 #define TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
17 
#include <vector>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
20 
21 namespace tensorflow {
22 
23 using Labels = gtl::InlinedVector<int, 8>;
24 using OperandLabels = gtl::InlinedVector<Labels, 2>;
25 using LabelCounts = gtl::InlinedVector<int, 8>;
26 using OperandLabelCounts = gtl::InlinedVector<LabelCounts, 2>;
27 
28 // Dummy axis label used to denote an ellipsis in an input or output subscript.
29 constexpr int kEllipsisLabel = -1;
30 
31 // Each dimension is categorized into exactly one of five types based on
32 // whether its corresponding label is present in the input and/or the output
33 // subscripts.
34 enum EinsumDimensionType {
35   // Batch dimensions are those present in two inputs as well as the output.
36   // They are part of the batch dimensions during Tensor contraction. Such
37   // dimensions may be broadcasting dimensions (those mapping to ellipsis)
38   // or explicit batch dimensions corresponding to named axis labels.
39   kBroadcasting = 0,
40   kBatch = 1,
41   // Free dimensions are present in exactly one of the inputs, and also the
42   // output. These are non-contracted axes in the Tensor contraction.
43   kFree = 2,
44   // Contract dimensions are present in two inputs, but not the output. These
45   // dimensions are contracted in Tensor contraction.
46   kContract = 3,
47   // Reduce dimensions are present in exactly one input; and not in the output
48   // and are summed over prior to Tensor contraction.
49   kReduce = 4,
50 };
51 
52 // Parses and validates an einsum equation in explicit form.
53 Status ValidateEinsumEquation(const string& equation,
54                               gtl::InlinedVector<string, 2>* input_subscripts,
55                               string* output_subscript);
56 
57 // Parses and validates the equation and the input shapes. Single character
58 // labels are integerized and we populate input and output label subscripts
59 // and corresponding counts. Also create the mapping from (named) labels to
60 // their EinsumDimensionType.
61 Status ParseEinsumEquation(const string& equation, OperandLabels* input_labels,
62                            Labels* output_labels,
63                            std::vector<EinsumDimensionType>* label_types,
64                            OperandLabelCounts* input_label_counts,
65                            LabelCounts* output_label_counts,
66                            gtl::InlinedVector<bool, 2>* input_has_ellipsis,
67                            bool* output_has_ellipsis);
68 
69 }  // namespace tensorflow
70 
71 #endif  // TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
72