/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_query.h"

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
namespace hlo_query {

namespace {
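// Returns true if `op` is one of the collective-communication opcodes.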
bool IsCollectiveCommunicationOp(HloOpcode op) {
  return op == HloOpcode::kAllReduce || op == HloOpcode::kAllGather ||
         op == HloOpcode::kAllToAll || op == HloOpcode::kCollectivePermute ||
         op == HloOpcode::kReduceScatter;
}
}  // namespace

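// If `instruction` is a scalar F32 constant, writes its value to `*out` and
// returns true; otherwise returns false.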
bool IsConstantR0F32(HloInstruction* instruction, float* out) {
  if (instruction->opcode() == HloOpcode::kConstant &&
      ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) {
    *out = instruction->literal().Get<float>({});
    return true;
  }

  return false;
}

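// Returns true if every operand of `instruction` is either a parameter or a
// constant.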
bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction) {
  for (const auto& operand : instruction.operands()) {
    if (operand->opcode() != HloOpcode::kParameter &&
        operand->opcode() != HloOpcode::kConstant) {
      return false;
    }
  }
  return true;
}

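// Returns true if every operand of `instruction` is a parameter.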
bool AllOperandsAreParameters(const HloInstruction& instruction) {
  for (const auto& operand : instruction.operands()) {
    if (operand->opcode() != HloOpcode::kParameter) {
      return false;
    }
  }
  return true;
}

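// Returns true if every operand of `instruction` is a constant.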
bool AllOperandsAreConstants(const HloInstruction& instruction) {
  for (const auto& operand : instruction.operands()) {
    if (operand->opcode() != HloOpcode::kConstant) {
      return false;
    }
  }
  return true;
}

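// Returns the first operand of `instruction` that satisfies `matcher`, or
// nullptr if no operand matches.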
HloInstruction* GetMatchingOperand(
    const std::function<bool(const HloInstruction*)>& matcher,
    HloInstruction* instruction) {
  for (HloInstruction* op : instruction->operands()) {
    if (matcher(op)) {
      return op;
    }
  }
  return nullptr;
}

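// If either operand of the binary `instruction` satisfies `matcher`, sets
// `*matching_operand` and `*other_operand` accordingly (preferring operand 0
// when both match) and returns true; otherwise returns false.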
bool MatchBinaryInstructionOperand(
    const std::function<bool(const HloInstruction*)>& matcher,
    HloInstruction* instruction, HloInstruction** matching_operand,
    HloInstruction** other_operand) {
  CHECK_EQ(instruction->operand_count(), 2);
  if (matcher(instruction->operand(0))) {
    *matching_operand = instruction->mutable_operand(0);
    *other_operand = instruction->mutable_operand(1);
    return true;
  }
  if (matcher(instruction->operand(1))) {
    *matching_operand = instruction->mutable_operand(1);
    *other_operand = instruction->mutable_operand(0);
    return true;
  }
  return false;
}

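// Like MatchBinaryInstructionOperand, but matches operands by opcode.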
bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
                                         HloInstruction* instruction,
                                         HloInstruction** matching_operand,
                                         HloInstruction** other_operand) {
  return MatchBinaryInstructionOperand(
      [opcode](const HloInstruction* instruction) {
        return instruction->opcode() == opcode;
      },
      instruction, matching_operand, other_operand);
}

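// Returns true if `instruction` is a constant with a scalar shape.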
bool IsScalarConstant(const HloInstruction* instruction) {
  return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape());
}

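// Returns true if `comp`, or any computation it calls (transitively), contains
// an instruction whose opcode is in `opcodes`.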
bool ContainsInstrWithOpcode(const HloComputation* comp,
                             const absl::flat_hash_set<HloOpcode>& opcodes) {
  for (const auto* instr : comp->instructions()) {
    if (opcodes.count(instr->opcode())) {
      return true;
    }
    for (const HloComputation* subcomp : instr->called_computations()) {
      if (ContainsInstrWithOpcode(subcomp, opcodes)) {
        return true;
      }
    }
  }
  return false;
}

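// Returns true if `module` contains an instruction with the given
// collective-communication opcode whose layout is constrained.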
bool ContainsLayoutConstrainedCollective(const HloModule& module,
                                         HloOpcode op) {
  CHECK(IsCollectiveCommunicationOp(op));

  for (auto computation : module.computations()) {
    for (auto hlo : computation->instructions()) {
      if (hlo->opcode() == op &&
          DynCast<HloCollectiveInstruction>(hlo)->constrain_layout()) {
        return true;
      }
    }
  }
  return false;
}

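// Returns one greater than the largest channel id used by any channel
// instruction in `module`, or 1 if no channel ids are in use.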
int64_t NextChannelId(const HloModule& module) {
  int64_t next_channel_id = 1;
  for (const HloComputation* comp : module.computations()) {
    for (const HloInstruction* hlo : comp->instructions()) {
      const HloChannelInstruction* channel_instr =
          DynCast<HloChannelInstruction>(hlo);
      if (channel_instr && channel_instr->channel_id()) {
        next_channel_id =
            std::max(next_channel_id, *channel_instr->channel_id() + 1);
      }
    }
  }
  return next_channel_id;
}

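// Returns true if `module` contains a host-transfer Send whose operand is a
// tuple, or a host-transfer Recv whose received data is a tuple (the pattern
// left by the X64 host-transfer transformation).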
bool HasX64TransformedHostTransfer(const HloModule& module) {
  for (auto computation : module.computations()) {
    for (auto hlo : computation->instructions()) {
      if (hlo->opcode() == HloOpcode::kSend) {
        auto send = DynCast<HloSendInstruction>(hlo);
        if (send->is_host_transfer() && send->operand(0)->shape().IsTuple()) {
          return true;
        }
      } else if (hlo->opcode() == HloOpcode::kRecv) {
        auto recv = DynCast<HloRecvInstruction>(hlo);
        if (recv->is_host_transfer() &&
            recv->shape().tuple_shapes(0).IsTuple()) {
          return true;
        }
      }
    }
  }
  return false;
}

}  // namespace hlo_query
}  // namespace xla