• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_CC_DTENSOR_UTILS_H_
17 #define TENSORFLOW_DTENSOR_CC_DTENSOR_UTILS_H_
18 
namespace tensorflow {
namespace dtensor {

// Returns the DTensor client ID of this process, usually equal to the TF task
// ID on this host.
int ClientId();

// Returns the total number of DTensor clients, usually equal to the total
// number of TF tasks.
int NumClients();

// Returns whether to enable logging of passes and layouts on all tasks.
bool LogOnAllTasks();

// Returns whether to log op-by-op execution in addition to function execution
// when logging is enabled.
bool LogOpByOp();

// Returns the maximum number of steps to run layout propagation. If the number
// of steps exceeds this amount, layout propagation will fail.
int LayoutPropagationMaxSteps();

// Returns whether to upcast bfloat16 reduction inputs to float32 when the
// reduction group is large enough. See also ReduceInBfloat16MaxGroupSize().
bool EnableMixedPrecisionReduce();

// Returns whether *not* to fuse AllReduce + AllScatter into a single
// ReduceScatter op. The fused ReduceScatter op can be implemented more
// efficiently than the unfused pair.
bool DoNotFuseReduceScatter();

// Returns the maximum reduction group size for bfloat16 reduction. If the
// group size exceeds this, tensors are upcast to float32 before the reduce op.
int ReduceInBfloat16MaxGroupSize();

// Returns whether DTensor checkpointing version 2 is enabled.
bool DTensorCheckpointV2Enabled();

}  // namespace dtensor
}  // namespace tensorflow
58 
59 #endif  // TENSORFLOW_DTENSOR_CC_DTENSOR_UTILS_H_
60