1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_query.h"
17
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
20 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23
24 namespace xla {
25 namespace hlo_query {
26
IsConstantR0F32(HloInstruction * instruction,float * out)27 bool IsConstantR0F32(HloInstruction* instruction, float* out) {
28 if (instruction->opcode() == HloOpcode::kConstant &&
29 ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) {
30 *out = instruction->literal().Get<float>({});
31 return true;
32 }
33
34 return false;
35 }
36
AllOperandsAreParametersOrConstants(const HloInstruction & instruction)37 bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction) {
38 for (const auto& operand : instruction.operands()) {
39 if (operand->opcode() != HloOpcode::kParameter &&
40 operand->opcode() != HloOpcode::kConstant) {
41 return false;
42 }
43 }
44 return true;
45 }
46
AllOperandsAreParameters(const HloInstruction & instruction)47 bool AllOperandsAreParameters(const HloInstruction& instruction) {
48 for (const auto& operand : instruction.operands()) {
49 if (operand->opcode() != HloOpcode::kParameter) {
50 return false;
51 }
52 }
53 return true;
54 }
55
AllOperandsAreConstants(const HloInstruction & instruction)56 bool AllOperandsAreConstants(const HloInstruction& instruction) {
57 for (const auto& operand : instruction.operands()) {
58 if (operand->opcode() != HloOpcode::kConstant) {
59 return false;
60 }
61 }
62 return true;
63 }
64
// Scans the operands of `instruction` in order and returns the first one for
// which `matcher` returns true, or nullptr when no operand matches. Operands
// after the first match are not passed to `matcher`.
HloInstruction* GetMatchingOperand(
    const std::function<bool(const HloInstruction*)>& matcher,
    HloInstruction* instruction) {
  HloInstruction* found = nullptr;
  for (HloInstruction* candidate : instruction->operands()) {
    if (matcher(candidate)) {
      found = candidate;
      break;
    }
  }
  return found;
}
75
MatchBinaryInstructionOperand(const std::function<bool (const HloInstruction *)> & matcher,HloInstruction * instruction,HloInstruction ** matching_operand,HloInstruction ** other_operand)76 bool MatchBinaryInstructionOperand(
77 const std::function<bool(const HloInstruction*)>& matcher,
78 HloInstruction* instruction, HloInstruction** matching_operand,
79 HloInstruction** other_operand) {
80 CHECK_EQ(instruction->operand_count(), 2);
81 if (matcher(instruction->operand(0))) {
82 *matching_operand = instruction->mutable_operand(0);
83 *other_operand = instruction->mutable_operand(1);
84 return true;
85 }
86 if (matcher(instruction->operand(1))) {
87 *matching_operand = instruction->mutable_operand(1);
88 *other_operand = instruction->mutable_operand(0);
89 return true;
90 }
91 return false;
92 }
93
MatchBinaryInstructionOperandOpcode(HloOpcode opcode,HloInstruction * instruction,HloInstruction ** matching_operand,HloInstruction ** other_operand)94 bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
95 HloInstruction* instruction,
96 HloInstruction** matching_operand,
97 HloInstruction** other_operand) {
98 return MatchBinaryInstructionOperand(
99 [opcode](const HloInstruction* instruction) {
100 return instruction->opcode() == opcode;
101 },
102 instruction, matching_operand, other_operand);
103 }
104
IsScalarConstant(const HloInstruction * instruction)105 bool IsScalarConstant(const HloInstruction* instruction) {
106 return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape());
107 }
108
ContainsInstrWithOpcode(const HloComputation * comp,const absl::flat_hash_set<HloOpcode> & opcodes)109 bool ContainsInstrWithOpcode(const HloComputation* comp,
110 const absl::flat_hash_set<HloOpcode>& opcodes) {
111 for (const auto* instr : comp->instructions()) {
112 if (opcodes.count(instr->opcode())) {
113 return true;
114 }
115 for (const HloComputation* subcomp : instr->called_computations()) {
116 if (ContainsInstrWithOpcode(subcomp, opcodes)) {
117 return true;
118 }
119 }
120 }
121 return false;
122 }
123
ContainsLayoutConstrainedAllReduce(const HloModule & module)124 bool ContainsLayoutConstrainedAllReduce(const HloModule& module) {
125 for (auto computation : module.computations()) {
126 for (auto hlo : computation->instructions()) {
127 if (hlo->opcode() == HloOpcode::kAllReduce &&
128 DynCast<HloAllReduceInstruction>(hlo)->constrain_layout()) {
129 return true;
130 }
131 }
132 }
133 return false;
134 }
135
NextChannelId(const HloModule & module)136 int64 NextChannelId(const HloModule& module) {
137 int64 next_channel_id = 1;
138 for (const HloComputation* comp : module.computations()) {
139 for (const HloInstruction* hlo : comp->instructions()) {
140 const HloChannelInstruction* channel_instr =
141 DynCast<HloChannelInstruction>(hlo);
142 if (channel_instr && channel_instr->channel_id()) {
143 next_channel_id =
144 std::max(next_channel_id, *channel_instr->channel_id() + 1);
145 }
146 }
147 }
148 return next_channel_id;
149 }
150
HasX64TransformedHostTransfer(const HloModule & module)151 bool HasX64TransformedHostTransfer(const HloModule& module) {
152 for (auto computation : module.computations()) {
153 for (auto hlo : computation->instructions()) {
154 if (hlo->opcode() == HloOpcode::kSend) {
155 auto send = DynCast<HloSendInstruction>(hlo);
156 if (send->is_host_transfer() && send->operand(0)->shape().IsTuple()) {
157 return true;
158 }
159 } else if (hlo->opcode() == HloOpcode::kRecv) {
160 auto recv = DynCast<HloRecvInstruction>(hlo);
161 if (recv->is_host_transfer() &&
162 recv->shape().tuple_shapes(0).IsTuple()) {
163 return true;
164 }
165 }
166 }
167 }
168 return false;
169 }
170
171 } // namespace hlo_query
172 } // namespace xla
173