1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
18
19 #include <iosfwd>
20 #include <string>
21
22 #include "absl/types/optional.h"
23 #include "tensorflow/compiler/xla/comparison_util.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27
28 namespace xla {
29
// High-level optimizer instruction opcodes -- these are linear-algebra level
// opcodes. They are a flattened form of the UnaryOp, BinaryOp, ... opcodes
// present in the XLA service protobuf.
//
// See the XLA documentation for the semantics of each opcode.
//
// HLO_OPCODE_LIST is an X-macro: it takes the name of a macro V and expands to
// one V(...) invocation per opcode. Each entry has the format:
//   V(enum_name, opcode_name, arity)
// where
//   - enum_name   is the HloOpcode enumerator generated for the opcode,
//   - opcode_name is the opcode's textual name (presumably what
//                 HloOpcodeString returns; that function is defined out of
//                 line),
//   - arity       is the fixed number of operands, or kHloOpcodeIsVariadic for
//                 opcodes whose operand count is not fixed.
// kHloOpcodeIsVariadic is declared further down in this header; that is fine
// because the name is only looked up where the list is expanded.
//
// The HloOpcode enum below is generated from this list in list order, so the
// position of an entry determines its enumerator's numeric value. Reordering
// or inserting entries renumbers every enumerator after that point.
//
// Note: Do not use ':' in opcode names. It is used as a special character
// in these places:
//  - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to
//    separate the opcode from the fusion kind
//  - In fully qualified names (HloInstruction::FullyQualifiedName()), to
//    separate the qualifiers (name of the computation and potentially the
//    fusion instruction) from the name
#define HLO_OPCODE_LIST(V) \
  V(kAbs, "abs", 1) \
  V(kAdd, "add", 2) \
  V(kAddDependency, "add-dependency", 2) \
  V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \
  V(kAllGather, "all-gather", 1) \
  V(kAllReduce, "all-reduce", kHloOpcodeIsVariadic) \
  V(kAllToAll, "all-to-all", kHloOpcodeIsVariadic) \
  V(kAtan2, "atan2", 2) \
  V(kBatchNormGrad, "batch-norm-grad", 5) \
  V(kBatchNormInference, "batch-norm-inference", 5) \
  V(kBatchNormTraining, "batch-norm-training", 3) \
  V(kBitcast, "bitcast", 1) \
  V(kBitcastConvert, "bitcast-convert", 1) \
  V(kBroadcast, "broadcast", 1) \
  V(kCall, "call", kHloOpcodeIsVariadic) \
  V(kCeil, "ceil", 1) \
  V(kCholesky, "cholesky", 1) \
  V(kClamp, "clamp", 3) \
  V(kCollectivePermute, "collective-permute", 1) \
  V(kCollectivePermuteStart, "collective-permute-start", 1) \
  V(kCollectivePermuteDone, "collective-permute-done", 1) \
  V(kClz, "count-leading-zeros", 1) \
  V(kCompare, "compare", 2) \
  V(kComplex, "complex", 2) \
  V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
  V(kConditional, "conditional", kHloOpcodeIsVariadic) \
  V(kConstant, "constant", 0) \
  V(kConvert, "convert", 1) \
  V(kConvolution, "convolution", 2) \
  V(kCopy, "copy", 1) \
  V(kCopyDone, "copy-done", 1) \
  V(kCopyStart, "copy-start", 1) \
  V(kCos, "cosine", 1) \
  V(kCustomCall, "custom-call", kHloOpcodeIsVariadic) \
  V(kDivide, "divide", 2) \
  V(kDomain, "domain", 1) \
  V(kDot, "dot", 2) \
  V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \
  V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \
  V(kExp, "exponential", 1) \
  V(kExpm1, "exponential-minus-one", 1) \
  V(kFft, "fft", 1) \
  V(kFloor, "floor", 1) \
  V(kFusion, "fusion", kHloOpcodeIsVariadic) \
  V(kGather, "gather", 2) \
  V(kGetDimensionSize, "get-dimension-size", 1) \
  V(kSetDimensionSize, "set-dimension-size", 2) \
  V(kGetTupleElement, "get-tuple-element", 1) \
  V(kImag, "imag", 1) \
  V(kInfeed, "infeed", 1) \
  V(kIota, "iota", 0) \
  V(kIsFinite, "is-finite", 1) \
  V(kLog, "log", 1) \
  V(kLog1p, "log-plus-one", 1) \
  V(kLogistic, "logistic", 1) \
  V(kAnd, "and", 2) \
  V(kNot, "not", 1) \
  V(kOr, "or", 2) \
  V(kXor, "xor", 2) \
  V(kMap, "map", kHloOpcodeIsVariadic) \
  V(kMaximum, "maximum", 2) \
  V(kMinimum, "minimum", 2) \
  V(kMultiply, "multiply", 2) \
  V(kNegate, "negate", 1) \
  V(kOutfeed, "outfeed", 2) \
  V(kPad, "pad", 2) \
  V(kParameter, "parameter", 0) \
  V(kPartitionId, "partition-id", 0) \
  V(kPopulationCount, "popcnt", 1) \
  V(kPower, "power", 2) \
  V(kReal, "real", 1) \
  V(kRecv, "recv", 1) \
  V(kRecvDone, "recv-done", 1) \
  V(kReduce, "reduce", kHloOpcodeIsVariadic) \
  V(kReducePrecision, "reduce-precision", 1) \
  V(kReduceWindow, "reduce-window", kHloOpcodeIsVariadic) \
  V(kRemainder, "remainder", 2) \
  V(kReplicaId, "replica-id", 0) \
  V(kReshape, "reshape", 1) \
  V(kDynamicReshape, "dynamic-reshape", kHloOpcodeIsVariadic) \
  V(kReverse, "reverse", 1) \
  V(kRng, "rng", kHloOpcodeIsVariadic) \
  V(kRngGetAndUpdateState, "rng-get-and-update-state", 0) \
  V(kRngBitGenerator, "rng-bit-generator", 1) \
  V(kRoundNearestAfz, "round-nearest-afz", 1) \
  V(kRsqrt, "rsqrt", 1) \
  V(kScatter, "scatter", 3) \
  V(kSelect, "select", 3) \
  V(kSelectAndScatter, "select-and-scatter", 3) \
  V(kSend, "send", 2) \
  V(kSendDone, "send-done", 1) \
  V(kShiftLeft, "shift-left", 2) \
  V(kShiftRightArithmetic, "shift-right-arithmetic", 2) \
  V(kShiftRightLogical, "shift-right-logical", 2) \
  V(kSign, "sign", 1) \
  V(kSin, "sine", 1) \
  V(kSlice, "slice", 1) \
  V(kSort, "sort", kHloOpcodeIsVariadic) \
  V(kSqrt, "sqrt", 1) \
  V(kCbrt, "cbrt", 1) \
  V(kSubtract, "subtract", 2) \
  V(kTanh, "tanh", 1) \
  V(kTrace, "trace", 1) \
  V(kTranspose, "transpose", 1) \
  V(kTriangularSolve, "triangular-solve", 2) \
  V(kTuple, "tuple", kHloOpcodeIsVariadic) \
  V(kTupleSelect, "tuple-select", 3) \
  V(kWhile, "while", 1)
155
156 enum class HloOpcode {
157 #define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name,
158 HLO_OPCODE_LIST(DECLARE_ENUM)
159 #undef DECLARE_ENUM
160 };
161
// Sentinel used in the arity column of HLO_OPCODE_LIST for opcodes that do not
// take a fixed number of operands. Negative, so it can never be confused with
// a real operand count.
enum { kHloOpcodeIsVariadic = -1 };
166
167 // Returns a string representation of the opcode.
168 string HloOpcodeString(HloOpcode opcode);
169
170 // Retrieves the opcode enum by name if the opcode exists.
171 StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name);
172
173 inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) {
174 return os << HloOpcodeString(opcode);
175 }
176
// Returns true iff the given opcode is a comparison operation.
// NOTE(review): defined out of line; in this opcode set that is presumably
// only kCompare -- confirm against the implementation.
bool HloOpcodeIsComparison(HloOpcode opcode);

// Returns true iff the given opcode has variadic operands, i.e. it does not
// take a fixed number of operands (presumably the opcodes whose arity in
// HLO_OPCODE_LIST is kHloOpcodeIsVariadic).
bool HloOpcodeIsVariadic(HloOpcode opcode);

// Returns the arity (fixed operand count) of opcode. If the opcode is
// variadic, returns nullopt instead of a count.
// Presumably derived from the arity field of the opcode's HLO_OPCODE_LIST
// entry; the implementation is out of line.
absl::optional<int> HloOpcodeArity(HloOpcode opcode);
186
187 // Returns the number of HloOpcode values.
HloOpcodeCount()188 inline const uint32_t HloOpcodeCount() {
189 #define HLO_COUNT_ONE(...) +1
190 #define HLO_XLIST_LENGTH(list) list(HLO_COUNT_ONE)
191 return HLO_XLIST_LENGTH(HLO_OPCODE_LIST);
192 }
193
194 } // namespace xla
195
196 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
197