• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
18 
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
22 
23 namespace tensorflow {
24 
25 // Whether the input expression at `index` corresponds to a TensorList.
26 bool IsTensorListInput(XlaOpKernelContext* ctx, int index);
27 
28 // Whether the TensorList is initialized (has known data type and shape).
29 Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized);
30 
31 // Whether the TensorList is a nested TensorList.
32 // Input must be an initialized TensorList.
33 // Non-nested and nested TensorLists are both supported.
34 Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list);
35 
36 // Builds a non-nested TensorList from `buffer` and `push_index`.
37 Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index,
38                                 xla::XlaOp* output_list);
39 
40 // Returns buffer shape for the TensorList.
41 // Input must be an initialized TensorList.
42 // Non-nested and nested TensorLists are both supported.
43 Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape);
44 
45 // Returns buffer for the TensorList.
46 // Input must be an initialized TensorList.
47 // Non-nested and nested TensorLists are both supported.
48 Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer);
49 
50 // Returns push index for the TensorList.
51 // Input must be an initialized TensorList.
52 // Non-nested and nested TensorLists are both supported.
53 Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index);
54 
55 // Returns a new TensorList with given push_index.
56 // Input must be an initialized TensorList.
57 // Non-nested and nested TensorLists are both supported.
58 Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
59                               xla::XlaOp* result);
60 
61 // Returns an uninitialized TensorList.
62 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
63                                         int64 leading_dimension,
64                                         bool leading_size_is_dynamic,
65                                         xla::XlaOp leading_dim_size);
66 
67 // Returns leading dimension for the TensorList as well as a dynamic op
68 // representing the dynamic size. Input can be initialized or uninitialized
69 // TensorList. Non-nested and nested TensorLists are both supported.
70 Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim,
71                                   bool* leading_dim_is_dynamic,
72                                   xla::XlaOp* leading_dim_dynamic_size);
73 
74 // Returns TensorList shape for the element shape.
75 // Element shape must be a normal tensor shape.
76 Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
77                                           int64 leading_dim,
78                                           bool leading_dim_is_dynamic,
79                                           xla::Shape* tensor_list_shape);
80 
81 // Returns a TensorList filled by zeros with the given shape.
82 Status CreateZerosTensorListWithShape(
83     xla::XlaBuilder* b, const xla::Shape& list_shape,
84     const std::vector<std::vector<xla::XlaOp>>& dynamic_dims, xla::XlaOp* list);
85 
86 // If the TensorList is initialized, check that its shape matches element shape;
87 // If the TensorList is uninitialized, initialize it with the element shape.
88 // Input can be initialized or uninitialized TensorList.
89 // "element" can be normal tensor or TensorList.
90 Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
91                                           bool element_is_tensor_list,
92                                           xla::XlaOp* initialized_list);
93 
94 // Executes TensorListPushBack with given TensorList and element.
95 // Input must be an initialized TensorList.
96 // Non-nested and nested TensorLists are both supported.
97 Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element,
98                                  bool element_is_tensor_list,
99                                  xla::XlaOp* result);
100 
101 // Executes TensorListPopBack with given TensorList.
102 // Input must be an initialized TensorList.
103 // Non-nested and nested TensorLists are both supported.
104 Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result,
105                                 xla::XlaOp* element_result,
106                                 bool* element_is_tensor_list);
107 
108 // Executes TensorListSetItem with given TensorList, index and element.
109 // Input must be an initialized TensorList.
110 // Only non-nested TensorList is supported.
111 Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index,
112                                 xla::XlaOp element, xla::XlaOp* result);
113 
114 // Executes TensorListGetItem with given TensorList and index.
115 // Input must be an initialized TensorList.
116 // Only non-nested TensorList is supported.
117 Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index,
118                                 xla::XlaOp* result);
119 
120 // Executes TensorListPushBack with given tensor and push index.
121 // "tensor" must be a normal tensor.
122 Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor,
123                                    xla::XlaOp* result);
124 
125 }  // namespace tensorflow
126 
127 #endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
128