1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
18
19 #include <iosfwd>
20 #include <string>
21 #include "tensorflow/compiler/xla/statusor.h"
22 #include "tensorflow/compiler/xla/types.h"
23
24 namespace xla {
25
// High-level optimizer instruction opcodes -- these are linear-algebra level
// opcodes. They are a flattened form of the UnaryOp, BinaryOp, ... opcodes
// present in the XLA service protobuf.
//
// See the XLA documentation for the semantics of each opcode.
//
// HLO_OPCODE_LIST is an X-macro: HLO_OPCODE_LIST(V) invokes V once per
// opcode, and each user supplies its own V to generate what it needs (for
// example the HloOpcode enumerators and the count returned by
// HloOpcodeCount()).
//
// Each entry has the format:
//   (enum_name, opcode_name)
// or
//   (enum_name, opcode_name, p1 | p2 | ...)
//
// where p1, p2, ... are members of HloOpcodeProperty, combined using
// bitwise-or.
//
// NOTE: the HloOpcode enumerators are declared in the order of this list, so
// inserting, removing, or reordering entries changes their numeric values.
// The list is mostly alphabetical, but kAnd, kNot and kOr currently sit
// between kLog and kLt.
//
// Note: Do not use ':' in opcode names. It is used as a special character
// in these places:
//  - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to
//    separate the opcode from the fusion kind
//  - In fully qualified names (HloInstruction::FullyQualifiedName()), to
//    separate the qualifiers (name of the computation and potentially the
//    fusion instruction) from the name
#define HLO_OPCODE_LIST(V) \
  V(kAbs, "abs") \
  V(kAdd, "add") \
  V(kAtan2, "atan2") \
  V(kBatchNormGrad, "batch-norm-grad") \
  V(kBatchNormInference, "batch-norm-inference") \
  V(kBatchNormTraining, "batch-norm-training") \
  V(kBitcast, "bitcast") \
  V(kBitcastConvert, "bitcast-convert") \
  V(kBroadcast, "broadcast") \
  V(kCall, "call", kHloOpcodeIsVariadic) \
  V(kCeil, "ceil") \
  V(kClamp, "clamp") \
  V(kComplex, "complex") \
  V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
  V(kConditional, "conditional") \
  V(kConstant, "constant") \
  V(kConvert, "convert") \
  V(kConvolution, "convolution") \
  V(kCopy, "copy") \
  V(kCos, "cosine") \
  V(kCrossReplicaSum, "cross-replica-sum") \
  V(kCustomCall, "custom-call") \
  V(kDivide, "divide") \
  V(kDot, "dot") \
  V(kDynamicSlice, "dynamic-slice") \
  V(kDynamicUpdateSlice, "dynamic-update-slice") \
  V(kEq, "equal-to", kHloOpcodeIsComparison) \
  V(kExp, "exponential") \
  V(kFft, "fft") \
  V(kFloor, "floor") \
  V(kFusion, "fusion", kHloOpcodeIsVariadic) \
  V(kGather, "gather") \
  V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
  V(kGetTupleElement, "get-tuple-element") \
  V(kGt, "greater-than", kHloOpcodeIsComparison) \
  V(kHostCompute, "host-compute") \
  V(kImag, "imag") \
  V(kInfeed, "infeed") \
  V(kIsFinite, "is-finite") \
  V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \
  V(kLog, "log") \
  V(kAnd, "and") \
  V(kNot, "not") \
  V(kOr, "or") \
  V(kLt, "less-than", kHloOpcodeIsComparison) \
  V(kMap, "map", kHloOpcodeIsVariadic) \
  V(kMaximum, "maximum") \
  V(kMinimum, "minimum") \
  V(kMultiply, "multiply") \
  V(kNe, "not-equal-to", kHloOpcodeIsComparison) \
  V(kNegate, "negate") \
  V(kOutfeed, "outfeed") \
  V(kPad, "pad") \
  V(kParameter, "parameter") \
  V(kPower, "power") \
  V(kReal, "real") \
  V(kRecv, "recv") \
  V(kRecvDone, "recv-done") \
  V(kReduce, "reduce") \
  V(kReducePrecision, "reduce-precision") \
  V(kReduceWindow, "reduce-window") \
  V(kRemainder, "remainder") \
  V(kReshape, "reshape") \
  V(kReverse, "reverse") \
  V(kRng, "rng") \
  V(kRoundNearestAfz, "round-nearest-afz") \
  V(kSelect, "select") \
  V(kSelectAndScatter, "select-and-scatter") \
  V(kSend, "send") \
  V(kSendDone, "send-done") \
  V(kShiftLeft, "shift-left") \
  V(kShiftRightArithmetic, "shift-right-arithmetic") \
  V(kShiftRightLogical, "shift-right-logical") \
  V(kSign, "sign") \
  V(kSin, "sine") \
  V(kSlice, "slice") \
  V(kSort, "sort") \
  V(kSubtract, "subtract") \
  V(kTanh, "tanh") \
  V(kTrace, "trace") \
  V(kTranspose, "transpose") \
  V(kTuple, "tuple", kHloOpcodeIsVariadic) \
  V(kWhile, "while")
131
132 enum class HloOpcode {
133 #define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name,
134 HLO_OPCODE_LIST(DECLARE_ENUM)
135 #undef DECLARE_ENUM
136 };
137
// Property flags that HLO_OPCODE_LIST entries may carry as their optional
// third argument. Each flag occupies a distinct bit, so several properties can
// be merged with `|` and an individual one checked with `&`. This is a plain
// (unscoped) enum so that the `|` combination works without casts.
enum HloOpcodeProperty {
  kHloOpcodeIsComparison = 0x1,
  kHloOpcodeIsVariadic = 0x2,
};
145
// Returns a string representation of the opcode.
string HloOpcodeString(HloOpcode opcode);

// Converts an opcode name (such as "add" or "get-tuple-element") to the
// corresponding HloOpcode. The StatusOr return type means the conversion can
// fail; presumably an error status is returned for a name that matches no
// opcode -- confirm against the implementation.
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name);
151
152 inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) {
153 return os << HloOpcodeString(opcode);
154 }
155
// Returns true iff the given opcode is a comparison operation.
// NOTE(review): presumably true exactly for the HLO_OPCODE_LIST entries tagged
// kHloOpcodeIsComparison (currently kEq, kGe, kGt, kLe, kLt, kNe); confirm
// against the implementation.
bool HloOpcodeIsComparison(HloOpcode opcode);

// Returns true iff the given opcode has variadic operands.
// NOTE(review): presumably true exactly for the HLO_OPCODE_LIST entries tagged
// kHloOpcodeIsVariadic (currently kCall, kConcatenate, kFusion, kMap,
// kTuple); confirm against the implementation.
bool HloOpcodeIsVariadic(HloOpcode opcode);
161
162 // Returns the number of HloOpcode values.
HloOpcodeCount()163 inline const uint32_t HloOpcodeCount() {
164 #define HLO_COUNT_ONE(...) +1
165 #define HLO_XLIST_LENGTH(list) list(HLO_COUNT_ONE)
166 return HLO_XLIST_LENGTH(HLO_OPCODE_LIST);
167 }
168
169 } // namespace xla
170
171 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
172