/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_

#include <utility>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// A XlaExpression represents a symbolic TensorFlow value in a TF->XLA
// compilation.
// An expression is one of:
// * a constant tensor.
// * an xla::XlaOp, representing a symbolic XLA value.
// * a resource, e.g., a variable, represented as an XlaResource pointer.
// * a tensor list, represented by a tuple of tensors and the list length.
//
// Constant tensors are mostly an optimization to avoid passing large constants
// to XLA, but are also sometimes used to represent tensors that have no XLA
// representation, for example, DT_STRING tensors. A canonical use case might be
// an error message string.
41 // 42 // Tensor lists are very similar to xla::XlaOp, however they require some 43 // specific logic around shape management since the tuples are not supported by 44 // TensorFlow. 45 class XlaExpression { 46 public: 47 enum class Kind { 48 kInvalid, 49 kConstant, 50 kXlaOp, 51 kResource, 52 kTensorList, 53 }; 54 55 XlaExpression(); 56 XlaExpression(const XlaExpression&) = default; 57 XlaExpression& operator=(const XlaExpression&) = default; 58 59 // Builds an invalid expression. (Same as the default constructor, but makes 60 // the intent clearer.) 61 static XlaExpression Invalid(); 62 63 // Builds a constant XLA expression. 64 static XlaExpression Constant(Tensor value); 65 66 // Builds a XlaOp expression. Since the mapping from TF data types to XLA 67 // types is not 1-1, the TF type must also be provided; in general it cannot 68 // be derived from the XLA type. 69 static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); 70 71 // Builds a tensor list expression. 72 static XlaExpression TensorList(xla::XlaOp tensor_list); 73 74 // Builds a resource expression. 75 static XlaExpression Resource(XlaResource* resource); 76 77 // Builds a resource whose value is known at a compile time. 78 static XlaExpression ConstantResource(Tensor value, XlaResource* resource); 79 kind()80 Kind kind() const { return kind_; } 81 dtype()82 DataType dtype() const { return dtype_; } 83 84 // handle() returns the XlaOp that backs a kXlaOp expression. handle()85 const xla::XlaOp& handle() const { return handle_; } 86 87 // Return a constant value associated with this expression. Always set for 88 // constants, might be set for resources. constant_value()89 absl::optional<Tensor> constant_value() const { 90 if (kind_ == Kind::kResource && resource_->IsOverwritten()) { 91 // The constant is no longer available if the value was overwritten. 92 return absl::nullopt; 93 } 94 return constant_value_; 95 } 96 97 // Set the bound of the expression. 
set_value_bound(Tensor tensor)98 void set_value_bound(Tensor tensor) { 99 value_bound_.emplace(std::move(tensor)); 100 } 101 102 // Return the bound of the expression, if available. value_bound()103 absl::optional<Tensor> value_bound() const { return value_bound_; } resource()104 XlaResource* resource() const { return resource_; } 105 106 // Returns a human-readable summary of the expression. 107 string HumanString() const; 108 109 // Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns 110 // an erroneous XlaOp if the expression is not a constant or an expression. 111 xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const; 112 113 // If a kXlaOp or kConstant expression can be resolved to a compile-time 114 // constant, returns the value as a host-memory Tensor. Returns an empty 115 // optional if it cannot be resolved. Returns an error if passed a resource 116 // expression. 117 xla::StatusOr<absl::optional<Tensor>> ResolveConstant( 118 xla::Client* client, bool dynamic_dimension_is_minus_one = false) const; 119 120 // ResolveDynamism computes where a value inside this op is dynamic or can be 121 // inferred at compile time. 122 xla::StatusOr<Tensor> ResolveDynamism(xla::Client* client) const; 123 124 // Returns the shape of the tensor. 125 // The shape of a resource is the shape of a resource handle (i.e., a scalar), 126 // not the shape of the resource's value. 127 xla::StatusOr<TensorShape> GetShape() const; 128 129 // Retrieves an XlaExpression that was allocated by a previous Op. 130 static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor); 131 132 // Assigns an XlaExpression to a tensor on an XLA compilation device. 133 static void AssignExpressionToTensor(const XlaExpression& value, 134 Tensor* tensor); 135 136 private: 137 Kind kind_ = Kind::kInvalid; 138 139 DataType dtype_ = DT_INVALID; 140 141 // The XLA handle of the expression's computation, if kind_ == kXlaOp or 142 // a tuple expression if kind_ == kTensorList. 
143 xla::XlaOp handle_; 144 145 // The value of the constant, if available. 146 absl::optional<Tensor> constant_value_; 147 148 // The bound of the expression, if available. 149 absl::optional<Tensor> value_bound_; 150 151 // The resource, if kind_ == kResource. Not owned. 152 XlaResource* resource_ = nullptr; 153 }; 154 155 } // namespace tensorflow 156 157 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ 158