1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_query.h"
17
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
20 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23
24 namespace xla {
25 namespace hlo_query {
26
27 namespace {
IsCollectiveCommunicationOp(HloOpcode op)28 bool IsCollectiveCommunicationOp(HloOpcode op) {
29 return op == HloOpcode::kAllReduce || op == HloOpcode::kAllGather ||
30 op == HloOpcode::kAllToAll || op == HloOpcode::kCollectivePermute ||
31 op == HloOpcode::kReduceScatter;
32 }
33 } // namespace
34
IsConstantR0F32(HloInstruction * instruction,float * out)35 bool IsConstantR0F32(HloInstruction* instruction, float* out) {
36 if (instruction->opcode() == HloOpcode::kConstant &&
37 ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) {
38 *out = instruction->literal().Get<float>({});
39 return true;
40 }
41
42 return false;
43 }
44
AllOperandsAreParametersOrConstants(const HloInstruction & instruction)45 bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction) {
46 for (const auto& operand : instruction.operands()) {
47 if (operand->opcode() != HloOpcode::kParameter &&
48 operand->opcode() != HloOpcode::kConstant) {
49 return false;
50 }
51 }
52 return true;
53 }
54
AllOperandsAreParameters(const HloInstruction & instruction)55 bool AllOperandsAreParameters(const HloInstruction& instruction) {
56 for (const auto& operand : instruction.operands()) {
57 if (operand->opcode() != HloOpcode::kParameter) {
58 return false;
59 }
60 }
61 return true;
62 }
63
AllOperandsAreConstants(const HloInstruction & instruction)64 bool AllOperandsAreConstants(const HloInstruction& instruction) {
65 for (const auto& operand : instruction.operands()) {
66 if (operand->opcode() != HloOpcode::kConstant) {
67 return false;
68 }
69 }
70 return true;
71 }
72
// Returns the first operand of `instruction`, in operand order, for which
// `matcher` returns true, or nullptr if no operand matches. `matcher` is not
// invoked on operands after the first match.
HloInstruction* GetMatchingOperand(
    const std::function<bool(const HloInstruction*)>& matcher,
    HloInstruction* instruction) {
  HloInstruction* match = nullptr;
  for (HloInstruction* candidate : instruction->operands()) {
    if (matcher(candidate)) {
      match = candidate;
      break;
    }
  }
  return match;
}
83
MatchBinaryInstructionOperand(const std::function<bool (const HloInstruction *)> & matcher,HloInstruction * instruction,HloInstruction ** matching_operand,HloInstruction ** other_operand)84 bool MatchBinaryInstructionOperand(
85 const std::function<bool(const HloInstruction*)>& matcher,
86 HloInstruction* instruction, HloInstruction** matching_operand,
87 HloInstruction** other_operand) {
88 CHECK_EQ(instruction->operand_count(), 2);
89 if (matcher(instruction->operand(0))) {
90 *matching_operand = instruction->mutable_operand(0);
91 *other_operand = instruction->mutable_operand(1);
92 return true;
93 }
94 if (matcher(instruction->operand(1))) {
95 *matching_operand = instruction->mutable_operand(1);
96 *other_operand = instruction->mutable_operand(0);
97 return true;
98 }
99 return false;
100 }
101
MatchBinaryInstructionOperandOpcode(HloOpcode opcode,HloInstruction * instruction,HloInstruction ** matching_operand,HloInstruction ** other_operand)102 bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
103 HloInstruction* instruction,
104 HloInstruction** matching_operand,
105 HloInstruction** other_operand) {
106 return MatchBinaryInstructionOperand(
107 [opcode](const HloInstruction* instruction) {
108 return instruction->opcode() == opcode;
109 },
110 instruction, matching_operand, other_operand);
111 }
112
IsScalarConstant(const HloInstruction * instruction)113 bool IsScalarConstant(const HloInstruction* instruction) {
114 return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape());
115 }
116
ContainsInstrWithOpcode(const HloComputation * comp,const absl::flat_hash_set<HloOpcode> & opcodes)117 bool ContainsInstrWithOpcode(const HloComputation* comp,
118 const absl::flat_hash_set<HloOpcode>& opcodes) {
119 for (const auto* instr : comp->instructions()) {
120 if (opcodes.count(instr->opcode())) {
121 return true;
122 }
123 for (const HloComputation* subcomp : instr->called_computations()) {
124 if (ContainsInstrWithOpcode(subcomp, opcodes)) {
125 return true;
126 }
127 }
128 }
129 return false;
130 }
131
ContainsLayoutConstrainedCollective(const HloModule & module,HloOpcode op)132 bool ContainsLayoutConstrainedCollective(const HloModule& module,
133 HloOpcode op) {
134 CHECK(IsCollectiveCommunicationOp(op));
135
136 for (auto computation : module.computations()) {
137 for (auto hlo : computation->instructions()) {
138 if (hlo->opcode() == op &&
139 DynCast<HloCollectiveInstruction>(hlo)->constrain_layout()) {
140 return true;
141 }
142 }
143 }
144 return false;
145 }
146
NextChannelId(const HloModule & module)147 int64 NextChannelId(const HloModule& module) {
148 int64_t next_channel_id = 1;
149 for (const HloComputation* comp : module.computations()) {
150 for (const HloInstruction* hlo : comp->instructions()) {
151 const HloChannelInstruction* channel_instr =
152 DynCast<HloChannelInstruction>(hlo);
153 if (channel_instr && channel_instr->channel_id()) {
154 next_channel_id =
155 std::max(next_channel_id, *channel_instr->channel_id() + 1);
156 }
157 }
158 }
159 return next_channel_id;
160 }
161
HasX64TransformedHostTransfer(const HloModule & module)162 bool HasX64TransformedHostTransfer(const HloModule& module) {
163 for (auto computation : module.computations()) {
164 for (auto hlo : computation->instructions()) {
165 if (hlo->opcode() == HloOpcode::kSend) {
166 auto send = DynCast<HloSendInstruction>(hlo);
167 if (send->is_host_transfer() && send->operand(0)->shape().IsTuple()) {
168 return true;
169 }
170 } else if (hlo->opcode() == HloOpcode::kRecv) {
171 auto recv = DynCast<HloRecvInstruction>(hlo);
172 if (recv->is_host_transfer() &&
173 recv->shape().tuple_shapes(0).IsTuple()) {
174 return true;
175 }
176 }
177 }
178 }
179 return false;
180 }
181
182 } // namespace hlo_query
183 } // namespace xla
184