1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api.h"
17
18 #include <vector>
19
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/eager/c_api_internal.h"
22 #ifdef TENSORFLOW_EAGER_USE_XLA
23 #include "tensorflow/compiler/jit/xla_device.h"
24 #endif // TENSORFLOW_EAGER_USE_XLA
25
26 using tensorflow::int64;
27 using tensorflow::string;
28
29 namespace {
30
TensorShapeAsVector(TFE_TensorHandle * handle,TF_Status * status)31 std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
32 TF_Status* status) {
33 std::vector<int64> shape;
34 int rank = TFE_TensorHandleNumDims(handle, status);
35 if (TF_GetCode(status) != TF_OK) {
36 return shape;
37 }
38 shape.reserve(rank);
39 for (int i = 0; i < rank; ++i) {
40 shape.push_back(TFE_TensorHandleDim(handle, i, status));
41 if (TF_GetCode(status) != TF_OK) {
42 return shape;
43 }
44 }
45 return shape;
46 }
47
48 } // namespace
49
50 extern "C" {
51
TFE_TensorHandleTensorDebugInfo(TFE_TensorHandle * handle,TF_Status * status)52 TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
53 TFE_TensorHandle* handle, TF_Status* status) {
54 const tensorflow::Tensor* tensor;
55 status->status = handle->handle->Tensor(&tensor);
56 if (TF_GetCode(status) != TF_OK) {
57 return nullptr;
58 }
59
60 #ifdef TENSORFLOW_EAGER_USE_XLA
61 tensorflow::Device* device = handle->handle->device();
62
63 // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
64 tensorflow::XlaDevice* xla_device =
65 dynamic_cast<tensorflow::XlaDevice*>(device);
66 if (xla_device != nullptr) {
67 tensorflow::XlaDevice::PaddedShapeFn shape_fn =
68 xla_device->metadata().padded_shape_fn();
69 xla::Shape padded_shape;
70 status->status = shape_fn(*tensor, &padded_shape);
71 if (!status->status.ok()) {
72 return nullptr;
73 }
74 if (VLOG_IS_ON(3)) {
75 std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status);
76 if (!status->status.ok()) {
77 // Ignore the status here as we are simply logging.
78 status->status = tensorflow::Status::OK();
79 } else {
80 VLOG(3) << "Fully padded shape of ["
81 << tensorflow::str_util::Join(shape_to_log, ", ") << "] is "
82 << padded_shape.DebugString();
83 }
84 }
85
86 if (padded_shape.IsTuple()) {
87 if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) {
88 // Currently, the only case of XlaTensor containing a tuple shape is to
89 // represent 64 bit ints, doubles, and complex numbers (we don't support
90 // 64bit complex numbers).
91 status->status = tensorflow::errors::InvalidArgument(
92 "XlaTensors should only contain tuples of size 2. Shape: ",
93 padded_shape.DebugString());
94 return nullptr;
95 }
96
97 // shape0 is not a const& because we will assign it to padded_shape below.
98 // It is illegal to assign a part of a message to itself.
99 xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0);
100 const xla::Shape& shape1 =
101 xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
102 if (shape0.IsTuple() || shape1.IsTuple()) {
103 status->status = tensorflow::errors::InvalidArgument(
104 "XlaTensors should not contain nested tuples. Shape: ",
105 padded_shape.DebugString());
106 return nullptr;
107 }
108 if (!xla::ShapeUtil::Equal(shape0, shape1)) {
109 status->status = tensorflow::errors::InvalidArgument(
110 "Subshapes of XlaTensors should be the same. Shape: ",
111 padded_shape.DebugString());
112 return nullptr;
113 }
114
115 // Since the only case we handle here are two equal subshapes, we
116 // simply return one of them. The caller will interpret it as this
117 // shape directly storing the 64bit types. This approximation is good
118 // enough for this API's debugging use case.
119 padded_shape = shape0;
120 }
121
122 int rank = padded_shape.dimensions_size();
123 std::vector<int64> dev_dims;
124 dev_dims.reserve(rank);
125 if (rank == 1) {
126 // Rank 1 tensors might not have padded_shape.layout.minor_to_major set,
127 dev_dims.push_back(padded_shape.dimensions(0));
128 } else {
129 for (int i = rank - 1; i >= 0; --i) {
130 int64 dim_index = padded_shape.layout().minor_to_major(i);
131 dev_dims.push_back(padded_shape.dimensions(dim_index));
132 }
133 }
134 status->status = tensorflow::Status::OK();
135 return new TFE_TensorDebugInfo(dev_dims);
136 }
137 #endif // TENSORFLOW_EAGER_USE_XLA
138
139 // If the tensor is not an XLA tensor, the device shape is
140 // the same as regular tensor shape.
141 std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
142 if (TF_GetCode(status) != TF_OK) {
143 return nullptr;
144 }
145 return new TFE_TensorDebugInfo(dev_dims);
146 }
147
TFE_DeleteTensorDebugInfo(TFE_TensorDebugInfo * debug_info)148 TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
149 TFE_TensorDebugInfo* debug_info) {
150 delete debug_info;
151 }
152
TFE_TensorDebugInfoOnDeviceNumDims(TFE_TensorDebugInfo * debug_info)153 TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
154 TFE_TensorDebugInfo* debug_info) {
155 return debug_info->dev_dims.size();
156 }
157
TFE_TensorDebugInfoOnDeviceDim(TFE_TensorDebugInfo * debug_info,int dim_index)158 TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim(
159 TFE_TensorDebugInfo* debug_info, int dim_index) {
160 return debug_info->dev_dims[dim_index];
161 }
162
163 } // extern "C"
164