/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <sstream>

#include "tensorflow/c/kernels.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"

// BitcastOp implements a bitcast kernel, creating an output tensor that shares
// the same data buffer as the input but with a different shape and/or data
// type. Its inputs are:
//
//   * the input tensor
//   * an attribute named "T" containing the TF_DataType of the input tensor
//   * an attribute named "type" containing the TF_DataType of the output tensor
//
// Given an input tensor of shape [...], if the input DataType "T" is larger
// than the output DataType "type", then the shape changes from [...]
// to [..., sizeof(T)/sizeof(type)].
//
// If "T" is smaller than "type", the operator requires that the rightmost
// dimension be equal to sizeof(type)/sizeof(T). The shape then goes from
// [..., sizeof(type)/sizeof(T)] to [...].
//
// Bitcast is implemented as a low-level cast, so machines with different endian
// orderings will give different results.
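//
// For example, with element sizes as reported by TF_DataTypeSize: bitcasting
// an int32 tensor of shape [2, 3] to uint8 yields shape [2, 3, 4], since
// sizeof(int32) / sizeof(uint8) == 4; bitcasting a uint8 tensor of shape
// [2, 3, 4] back to int32 yields shape [2, 3]. When the two sizes are equal
// (e.g. int32 to float), the shape is unchanged.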
typedef struct BitcastOp {
  TF_DataType input_data_type;
  TF_DataType output_data_type;
  size_t in_size;
  size_t out_size;
} BitcastOp;

static void* BitcastOp_Create(TF_OpKernelConstruction* ctx) {
  auto* kernel = new BitcastOp;

  TF_Status* s = TF_NewStatus();
  TF_OpKernelConstruction_GetAttrType(ctx, "T", &kernel->input_data_type, s);

  if (TF_GetCode(s) == TF_OK) {
    TF_OpKernelConstruction_GetAttrType(ctx, "type", &kernel->output_data_type,
                                        s);
  }

  if (TF_GetCode(s) == TF_OK) {
    kernel->in_size = TF_DataTypeSize(kernel->input_data_type);
    kernel->out_size = TF_DataTypeSize(kernel->output_data_type);

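    // The larger element size must be an exact multiple of the smaller one;
    // otherwise the underlying byte buffer cannot be reinterpreted with a
    // consistent shape.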
    size_t check_size = std::max(kernel->in_size, kernel->out_size) %
                        std::min(kernel->in_size, kernel->out_size);
    if (check_size != 0) {
      std::ostringstream err;
      err << "cannot convert between datatype " << kernel->input_data_type
          << " and " << kernel->output_data_type;
      TF_SetStatus(s, TF_INVALID_ARGUMENT, err.str().c_str());
    }
  }

  if (TF_GetCode(s) != TF_OK) {
    TF_OpKernelConstruction_Failure(ctx, s);
    delete kernel;
    kernel = nullptr;
  }

  TF_DeleteStatus(s);
  return kernel;
}

static void BitcastOp_Delete(void* kernel) {
  delete static_cast<BitcastOp*>(kernel);
}

static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
  auto* k = static_cast<BitcastOp*>(kernel);
  int dim_count = 0;

  TF_Tensor* tensor;
  TF_Status* status = TF_NewStatus();
  TF_GetInput(ctx, 0, &tensor, status);
  if (TF_GetCode(status) == TF_OK) {
    dim_count = TF_NumDims(tensor);
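    // When the output type is wider than the input type, the rightmost input
    // dimension must equal sizeof(type) / sizeof(T) so that it can be folded
    // away (see the file comment above).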
    if (!(k->in_size >= k->out_size ||
          (dim_count > 0 &&
           TF_Dim(tensor, dim_count - 1) == k->out_size / k->in_size))) {
      std::ostringstream err;
      err << "Cannot bitcast from " << k->input_data_type << " to "
          << k->output_data_type;
      TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
    }
  }

  if (TF_GetCode(status) == TF_OK) {
    auto* dims = new int64_t[dim_count + 1];
    int new_dim_count = dim_count;
    for (int dim = 0; dim < dim_count; ++dim) {
      dims[dim] = TF_Dim(tensor, dim);
    }
    if (k->out_size < k->in_size) {
      dims[new_dim_count++] = static_cast<int64_t>(k->in_size / k->out_size);
    } else if (k->out_size > k->in_size) {
      --new_dim_count;
    }

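    // Allocate a minimal placeholder tensor of the output type;
    // TF_TensorBitcastFrom then points it at the input tensor's buffer with
    // the new dtype and shape, so no element data is copied.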
    TF_Tensor* output = TF_AllocateTensor(k->output_data_type, dims, 0,
                                          TF_DataTypeSize(k->output_data_type));
    TF_TensorBitcastFrom(tensor, k->output_data_type, output, dims,
                         new_dim_count, status);
    if (TF_GetCode(status) == TF_OK) {
      TF_SetOutput(ctx, 0, output, status);
    }
    delete[] dims;
    TF_DeleteTensor(output);
  }

  if (TF_GetCode(status) != TF_OK) {
    TF_OpKernelContext_Failure(ctx, status);
  }
  TF_DeleteStatus(status);
  TF_DeleteTensor(tensor);
}

static void RegisterBitcastOp() {
  TF_Status* status = TF_NewStatus();

  {
    auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU,
                                        &BitcastOp_Create, &BitcastOp_Compute,
                                        &BitcastOp_Delete);
    TF_RegisterKernelBuilder("BitcastOp", builder, status);
    CHECK_EQ(TF_OK, TF_GetCode(status))
        << "Error while registering bitcast kernel";
  }

#if GOOGLE_CUDA
  {
    auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_GPU,
                                        &BitcastOp_Create, &BitcastOp_Compute,
                                        &BitcastOp_Delete);
    TF_RegisterKernelBuilder("BitcastOp", builder, status);
    CHECK_EQ(TF_OK, TF_GetCode(status))
        << "Error while registering CUDA bitcast kernel";
  }
#endif

  TF_DeleteStatus(status);
}

// A dummy static variable initialized by a lambda whose side effect is to
// register the bitcast kernel.
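// SHOULD_REGISTER_OP_KERNEL comes from selective_registration.h; in ordinary
// builds it evaluates to true, and it only suppresses registration when
// selective registration is enabled and this kernel is not required.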
static bool BitcastOpIsRegistered = []() {
  if (SHOULD_REGISTER_OP_KERNEL("BitcastOp")) {
    RegisterBitcastOp();
  }
  return true;
}();