/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api.h"

#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/jit/xla_device.h"
#endif  // TENSORFLOW_EAGER_USE_XLA

using tensorflow::int64;
using tensorflow::string;

namespace {

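// Returns the dimension sizes of `handle` as a vector. On error, sets
// `status` and returns whatever dimensions were gathered so far.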
std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
                                       TF_Status* status) {
  std::vector<int64> shape;
  int rank = TFE_TensorHandleNumDims(handle, status);
  if (TF_GetCode(status) != TF_OK) {
    return shape;
  }
  shape.reserve(rank);
  for (int i = 0; i < rank; ++i) {
    shape.push_back(TFE_TensorHandleDim(handle, i, status));
    if (TF_GetCode(status) != TF_OK) {
      return shape;
    }
  }
  return shape;
}

}  // namespace

extern "C" {

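// Returns debug information describing the on-device shape of the tensor
// backing `handle`. For tensors placed on XLA devices this is the padded
// shape in device (minor-to-major) layout; otherwise it matches the regular
// tensor shape. On failure, sets `status` and returns nullptr. The caller
// owns the result and must free it with TFE_DeleteTensorDebugInfo.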
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
    TFE_TensorHandle* handle, TF_Status* status) {
  const tensorflow::Tensor* tensor;
  status->status = handle->handle->Tensor(&tensor);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }

#ifdef TENSORFLOW_EAGER_USE_XLA
  tensorflow::Device* device = handle->handle->device();

  // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
  tensorflow::XlaDevice* xla_device =
      dynamic_cast<tensorflow::XlaDevice*>(device);
  if (xla_device != nullptr) {
    tensorflow::XlaDevice::PaddedShapeFn shape_fn =
        xla_device->metadata().padded_shape_fn();
    xla::Shape padded_shape;
    status->status = shape_fn(*tensor, &padded_shape);
    if (!status->status.ok()) {
      return nullptr;
    }
    if (VLOG_IS_ON(3)) {
      std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status);
      if (!status->status.ok()) {
        // Ignore the status here as we are simply logging.
        status->status = tensorflow::Status::OK();
      } else {
        VLOG(3) << "Fully padded shape of ["
                << tensorflow::str_util::Join(shape_to_log, ", ") << "] is "
                << padded_shape.DebugString();
      }
    }

    if (padded_shape.IsTuple()) {
      if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) {
        // Currently, the only case of XlaTensor containing a tuple shape is to
        // represent 64 bit ints, doubles, and complex numbers (we don't support
        // 64bit complex numbers).
        status->status = tensorflow::errors::InvalidArgument(
            "XlaTensors should only contain tuples of size 2. Shape: ",
            padded_shape.DebugString());
        return nullptr;
      }

      // shape0 is not a const& because we will assign it to padded_shape below.
      // It is illegal to assign a part of a message to itself.
      xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0);
      const xla::Shape& shape1 =
          xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
      if (shape0.IsTuple() || shape1.IsTuple()) {
        status->status = tensorflow::errors::InvalidArgument(
            "XlaTensors should not contain nested tuples. Shape: ",
            padded_shape.DebugString());
        return nullptr;
      }
      if (!xla::ShapeUtil::Equal(shape0, shape1)) {
        status->status = tensorflow::errors::InvalidArgument(
            "Subshapes of XlaTensors should be the same. Shape: ",
            padded_shape.DebugString());
        return nullptr;
      }

      // Since the only case we handle here is two equal subshapes, we simply
      // return one of them. The caller will interpret it as this shape
      // directly storing the 64bit types. This approximation is good enough
      // for this API's debugging use case.
      padded_shape = shape0;
    }

    int rank = padded_shape.dimensions_size();
    std::vector<int64> dev_dims;
    dev_dims.reserve(rank);
    if (rank == 1) {
      // Rank 1 tensors might not have padded_shape.layout.minor_to_major set,
      // so use the dimension directly.
      dev_dims.push_back(padded_shape.dimensions(0));
    } else {
      for (int i = rank - 1; i >= 0; --i) {
        int64 dim_index = padded_shape.layout().minor_to_major(i);
        dev_dims.push_back(padded_shape.dimensions(dim_index));
      }
    }
    status->status = tensorflow::Status::OK();
    return new TFE_TensorDebugInfo(dev_dims);
  }
#endif  // TENSORFLOW_EAGER_USE_XLA

  // If the tensor is not an XLA tensor, the device shape is the same as the
  // regular tensor shape.
  std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }
  return new TFE_TensorDebugInfo(dev_dims);
}

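// Deletes `debug_info` returned by TFE_TensorHandleTensorDebugInfo.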
TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
    TFE_TensorDebugInfo* debug_info) {
  delete debug_info;
}

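// Returns the number of on-device dimensions recorded in `debug_info`.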
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
    TFE_TensorDebugInfo* debug_info) {
  return debug_info->dev_dims.size();
}

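// Returns the size of the on-device dimension at `dim_index`.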
TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim(
    TFE_TensorDebugInfo* debug_info, int dim_index) {
  return debug_info->dev_dims[dim_index];
}

}  // extern "C"
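// A minimal usage sketch of the debug API above (illustrative only, not part
// of this file's implementation). It assumes a valid TFE_TensorHandle* `h`
// obtained elsewhere and uses the TF_Status helpers from
// tensorflow/c/c_api.h:
//
//   TF_Status* status = TF_NewStatus();
//   TFE_TensorDebugInfo* info = TFE_TensorHandleTensorDebugInfo(h, status);
//   if (TF_GetCode(status) == TF_OK) {
//     int num_dims = TFE_TensorDebugInfoOnDeviceNumDims(info);
//     for (int i = 0; i < num_dims; ++i) {
//       // On XLA devices these are the padded, device-layout dimension sizes.
//       int64_t dim = TFE_TensorDebugInfoOnDeviceDim(info, i);
//       // ... inspect `dim` ...
//     }
//     TFE_DeleteTensorDebugInfo(info);
//   }
//   TF_DeleteStatus(status);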