1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ 18 19 #include "absl/types/optional.h" 20 #include "tensorflow/compiler/tf2xla/xla_resource.h" 21 #include "tensorflow/compiler/xla/client/client.h" 22 #include "tensorflow/compiler/xla/client/xla_builder.h" 23 #include "tensorflow/compiler/xla/statusor.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/lib/core/status.h" 26 27 namespace tensorflow { 28 29 // A XlaExpression represents a symbolic TensorFlow value in a TF->XLA 30 // compilation. 31 // An expression is one of: 32 // * a constant tensor. 33 // * an xla::XlaOp, representing a symbolic XLA value. 34 // * a resource, e.g., a variable, represented as an XlaResource pointer. 35 // * a tensor list, represented by a tuple of tensors and the list length. 36 // 37 // Constant tensors are mostly an optimization to avoid passing large constants 38 // to XLA, but are also sometimes used to represent tensors that have no XLA 39 // representation, for example, DT_STRING tensors. A canonical use case might be 40 // an error message string. 41 // 42 // Tensor lists are very similar to xla::XlaOp, however they require some 43 // specific logic around shape management since the tuples are not supported by 44 // TensorFlow. 45 class XlaExpression { 46 public: 47 enum class Kind { 48 kInvalid, 49 kConstant, 50 kXlaOp, 51 kResource, 52 kTensorList, 53 }; 54 55 XlaExpression(); 56 XlaExpression(const XlaExpression&) = default; 57 XlaExpression& operator=(const XlaExpression&) = default; 58 59 // Builds an invalid expression. (Same as the default constructor, but makes 60 // the intent clearer.) 61 static XlaExpression Invalid(); 62 63 // Builds a constant XLA expression. 64 static XlaExpression Constant(Tensor value); 65 66 // Builds a XlaOp expression. Since the mapping from TF data types to XLA 67 // types is not 1-1, the TF type must also be provided; in general it cannot 68 // be derived from the XLA type. 69 static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); 70 71 // Builds a tensor list expression. 72 static XlaExpression TensorList(xla::XlaOp tensor_list); 73 74 // Builds a resource expression. 75 static XlaExpression Resource(XlaResource* resource); 76 kind()77 Kind kind() const { return kind_; } 78 dtype()79 DataType dtype() const { return dtype_; } 80 81 // handle() returns the XlaOp that backs a kXlaOp expression. handle()82 const xla::XlaOp& handle() const { return handle_; } 83 constant_value()84 const Tensor& constant_value() const { return constant_value_; } 85 resource()86 XlaResource* resource() const { return resource_; } 87 88 // Returns a human-readable summary of the expression. 89 string HumanString() const; 90 91 // Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns 92 // an erroneous XlaOp if the expression is not a constant or an expression. 93 xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const; 94 95 // If a kXlaOp or kConstant expression can be resolved to a compile-time 96 // constant, returns the value as a host-memory Tensor. Returns an empty 97 // optional if it cannot be resolved. Returns an error if passed a resource 98 // expression. 99 xla::StatusOr<absl::optional<Tensor>> ResolveConstant( 100 xla::Client* client) const; 101 102 // Returns the shape of the tensor. 103 // The shape of a resource is the shape of a resource handle (i.e., a scalar), 104 // not the shape of the resource's value. 105 xla::StatusOr<TensorShape> GetShape() const; 106 107 private: 108 Kind kind_ = Kind::kInvalid; 109 110 DataType dtype_ = DT_INVALID; 111 112 // The XLA handle of the expression's computation, if kind_ == kXlaOp or 113 // a tuple expression if kind_ == kTensorList. 114 xla::XlaOp handle_; 115 116 // The value of the constant, if kind_ == kConstant. 117 Tensor constant_value_; 118 119 // The resource, if kind_ == kResource. Not owned. 120 XlaResource* resource_ = nullptr; 121 }; 122 123 } // namespace tensorflow 124 125 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ 126