• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
18 
19 #include "absl/types/optional.h"
20 #include "tensorflow/compiler/tf2xla/xla_resource.h"
21 #include "tensorflow/compiler/xla/client/client.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/lib/core/status.h"
26 
27 namespace tensorflow {
28 
29 // A XlaExpression represents a symbolic TensorFlow value in a TF->XLA
30 // compilation.
31 // An expression is one of:
32 // * a constant tensor.
33 // * an xla::XlaOp, representing a symbolic XLA value.
34 // * a resource, e.g., a variable, represented as an XlaResource pointer.
35 // * a tensor list, represented by a tuple of tensors and the list length.
36 //
37 // Constant tensors are mostly an optimization to avoid passing large constants
38 // to XLA, but are also sometimes used to represent tensors that have no XLA
39 // representation, for example, DT_STRING tensors. A canonical use case might be
40 // an error message string.
41 //
42 // Tensor lists are very similar to xla::XlaOp, however they require some
43 // specific logic around shape management since the tuples are not supported by
44 // TensorFlow.
45 class XlaExpression {
46  public:
47   enum class Kind {
48     kInvalid,
49     kConstant,
50     kXlaOp,
51     kResource,
52     kTensorList,
53   };
54 
55   XlaExpression();
56   XlaExpression(const XlaExpression&) = default;
57   XlaExpression& operator=(const XlaExpression&) = default;
58 
59   // Builds an invalid expression. (Same as the default constructor, but makes
60   // the intent clearer.)
61   static XlaExpression Invalid();
62 
63   // Builds a constant XLA expression.
64   static XlaExpression Constant(Tensor value);
65 
66   // Builds a XlaOp expression. Since the mapping from TF data types to XLA
67   // types is not 1-1, the TF type must also be provided; in general it cannot
68   // be derived from the XLA type.
69   static XlaExpression XlaOp(xla::XlaOp value, DataType dtype);
70 
71   // Builds a tensor list expression.
72   static XlaExpression TensorList(xla::XlaOp tensor_list);
73 
74   // Builds a resource expression.
75   static XlaExpression Resource(XlaResource* resource);
76 
77   // Builds a resource whose value is known at a compile time.
78   static XlaExpression ConstantResource(Tensor value, XlaResource* resource);
79 
kind()80   Kind kind() const { return kind_; }
81 
dtype()82   DataType dtype() const { return dtype_; }
83 
84   // handle() returns the XlaOp that backs a kXlaOp expression.
handle()85   const xla::XlaOp& handle() const { return handle_; }
86 
87   // Return a constant value associated with this expression. Always set for
88   // constants, might be set for resources.
constant_value()89   absl::optional<Tensor> constant_value() const {
90     if (kind_ == Kind::kResource && resource_->IsOverwritten()) {
91       // The constant is no longer available if the value was overwritten.
92       return absl::nullopt;
93     }
94     return constant_value_;
95   }
96 
97   // Set the bound of the expression.
set_value_bound(Tensor tensor)98   void set_value_bound(Tensor tensor) {
99     value_bound_.emplace(std::move(tensor));
100   }
101 
102   // Return the bound of the expression, if available.
value_bound()103   absl::optional<Tensor> value_bound() const { return value_bound_; }
resource()104   XlaResource* resource() const { return resource_; }
105 
106   // Returns a human-readable summary of the expression.
107   string HumanString() const;
108 
109   // Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns
110   // an erroneous XlaOp if the expression is not a constant or an expression.
111   xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const;
112 
113   // If a kXlaOp or kConstant expression can be resolved to a compile-time
114   // constant, returns the value as a host-memory Tensor. Returns an empty
115   // optional if it cannot be resolved. Returns an error if passed a resource
116   // expression.
117   xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
118       xla::Client* client, bool dynamic_dimension_is_minus_one = false) const;
119 
120   // ResolveDynamism computes where a value inside this op is dynamic or can be
121   // inferred at compile time.
122   xla::StatusOr<Tensor> ResolveDynamism(xla::Client* client) const;
123 
124   // Returns the shape of the tensor.
125   // The shape of a resource is the shape of a resource handle (i.e., a scalar),
126   // not the shape of the resource's value.
127   xla::StatusOr<TensorShape> GetShape() const;
128 
129   // Retrieves an XlaExpression that was allocated by a previous Op.
130   static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
131 
132   // Assigns an XlaExpression to a tensor on an XLA compilation device.
133   static void AssignExpressionToTensor(const XlaExpression& value,
134                                        Tensor* tensor);
135 
136  private:
137   Kind kind_ = Kind::kInvalid;
138 
139   DataType dtype_ = DT_INVALID;
140 
141   // The XLA handle of the expression's computation, if kind_ == kXlaOp or
142   // a tuple expression if kind_ == kTensorList.
143   xla::XlaOp handle_;
144 
145   // The value of the constant, if available.
146   absl::optional<Tensor> constant_value_;
147 
148   // The bound of the expression, if available.
149   absl::optional<Tensor> value_bound_;
150 
151   // The resource, if kind_ == kResource. Not owned.
152   XlaResource* resource_ = nullptr;
153 };
154 
155 }  // namespace tensorflow
156 
157 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
158