• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_
18 
19 #include "llvm/ADT/ArrayRef.h"
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/Location.h"  // from @llvm-project
22 #include "mlir/IR/Value.h"  // from @llvm-project
23 #include "mlir/Support/LLVM.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
26 
27 namespace mlir {
28 namespace TF {
29 namespace collection_ops_util {
30 
31 // This file includes utilities for decomposing collection ops (stack, tensor
32 // list, tensor array) in TF. We represent such a data structure as a buffer of
33 // shape [max_element_count, element_shape].
34 
35 // Creates an i32 scalar tf.Const for `value`, built with `builder` at `loc`.
36 Value CreateScalarConst(int32_t value, OpBuilder builder, Location loc);
37 
38 // Creates an integer vector (R1) tf.Const from `r1`; `bitwidth` (default 32) presumably sets the element width.
39 Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
40                  int bitwidth = 32);
41 
42 // Returns the type of the size tensor that tracks a data structure's element
43 // count. It is a tensor<1xi32>; R1 is used instead of a scalar because it is
44 // easier to concatenate with other offsets.
45 TensorType GetSizeType(OpBuilder builder);
46 
47 // Reshapes a scalar value to match the size type tensor<1xi32>.
48 Value ReshapeScalarToSizeType(OpBuilder builder, Value scalar, Location loc);
49 
50 // Creates ops that represent the indices of the slice for the element at
51 // `index` in `buffer`. Requires `index` to have tensor<1xi32> type.
52 Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder,
53                            Location loc);
54 
55 // Creates ops that slice the element out of `buffer` at `index`, which must
56 // have tensor<1xi32> type. `keep_slice_shape` presumably skips reshaping the slice (TODO confirm).
57 Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc,
58                  bool keep_slice_shape = false);
59 
60 // Creates ops that copy `buffer` and update the element at `index` with
61 // `element`. Requires `index` to have tensor<1xi32> type.
62 Value SetElement(Value index, Value buffer, Value element, OpBuilder builder,
63                  Location loc);
64 
65 // Creates the buffer for the data structure with the given element shape,
66 // element type and maximum size. Returns a LogicalResult; the buffer is presumably stored in `*buffer`.
67 LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
68                                     int64_t max_size, Operation* op,
69                                     Type element_dtype, OpBuilder builder,
70                                     Value* buffer);
71 
72 // Same as above, but takes `max_size` as a Value and checks if it is a constant.
73 LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
74                                     Value max_size, Operation* op,
75                                     Type element_dtype, OpBuilder builder,
76                                     Value* buffer);
77 
78 // Tries to infer the element type with full shape based on the write accesses
79 // to `collection`. `infer_from_op` should check if the provided op is an
80 // accessing op that could be used to infer the type.
81 llvm::Optional<RankedTensorType> GetElementTypeFromAccess(
82     Value collection, ModuleOp module,
83     llvm::function_ref<llvm::Optional<Type>(Operation*)> infer_from_op);
84 
85 // Creates a ReadVariableOp on the local variable `local_var`.
86 Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc);
87 
88 // Creates an AssignVariableOp on the local variable `local_var` for `value` and returns the op.
89 TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value,
90                                         OpBuilder builder, Location loc);
91 
92 // Adds `a` and `b`, or creates a logical-or of them if they are of boolean type.
93 Value AccumulateBuffers(Value a, Value b, OpBuilder builder, Location loc);
94 
95 // Gathers the elements of `buffer` at the given `indices`.
96 Value GatherElements(Value indices, Value buffer, OpBuilder builder,
97                      Location loc);
98 
99 // Scatters `updates` into `buffer` at `indices`, where each scattered element
100 // is accumulated with the old value in `buffer`.
101 Value ScatterAccumulateElements(Value indices, Value updates, Value buffer,
102                                 OpBuilder builder, Location loc);
103 
104 }  // namespace collection_ops_util
105 }  // namespace TF
106 }  // namespace mlir
107 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_
108