• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // =============================================================================
2 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 // =============================================================================
16 
17 #include "tensorflow/core/framework/common_shape_fns.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/shape_inference.h"
21 
22 namespace tensorflow {
23 
24 REGISTER_OP("PeriodicResample")
25     .Attr("T: numbertype")
26     .Input("values: T")
27     .Attr("shape: shape")
28     .Output("output: T")
__anon96b295990102(shape_inference::InferenceContext* c) 29     .SetShapeFn([](shape_inference::InferenceContext* c) {
30       tensorflow::PartialTensorShape desired_shape;
31       TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
32       shape_inference::ShapeHandle input_tensor_shape = c->input(0);
33       shape_inference::DimensionHandle num_input_elements =
34           c->NumElements(input_tensor_shape);
35       shape_inference::ShapeHandle result_shape_handle;
36       if (!shape_inference::InferenceContext::ValueKnown(num_input_elements)) {
37         TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
38             desired_shape, &result_shape_handle));
39       } else {
40         const int rank = c->Rank(input_tensor_shape);
41         std::vector<tensorflow::int64> target_dimensions(rank);
42         tensorflow::int64 new_sliced_size = 1;
43         int adjustable_dimension = 0;
44         for (int i = 0; i < rank; ++i) {
45           if (desired_shape.dim_size(i) < 1) {
46             adjustable_dimension = i;
47           } else {
48             target_dimensions[i] = desired_shape.dim_size(i);
49             new_sliced_size *= target_dimensions[i];
50           }
51         }
52         target_dimensions[adjustable_dimension] =
53             shape_inference::InferenceContext::Value(
54                 num_input_elements) / new_sliced_size;
55         tensorflow::TensorShape result_shape;
56         for (int i = 0; i < rank; ++i) {
57           result_shape.AddDim(target_dimensions[i]);
58         }
59         TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(
60             result_shape, &result_shape_handle));
61       }
62       c->set_output(0, result_shape_handle);
63       return Status::OK();
64     })
65     .Doc(R"doc(
66 Periodically resample elements of a tensor to conform to `shape`.
67 
68 This function implements a slightly more generic version of the subpixel
69 convolutions found in this [paper](https://arxiv.org/abs/1609.05158).
70 
71 The formula for computing the elements in the `output` tensor is as follows:
72 
73   `T` = `values` tensor of rank `R`
74 
75   `S` = desired `shape` of output tensor (vector of length `R`)
76 
77   `P` = `output` tensor of rank `R`
78 
79   \\((T_1,\\ldots,T_R)\\) = shape(`T`)
80 
81   \\([S_1,\\ldots,S_q,\\ldots,S_R]\\) = elements of vector `S`
82 
83   A single element in `S` is left unspecified (denoted \\(S_q=-1\\)).
84 
85   Let \\(f_i\\) denote the (possibly non-integer) factor that relates the original
86   dimension to the desired dimensions, \\(S_i=f_i T_i\\), for \\(i\\neq q\\) where
87   \\(f_i>0\\).
88 
89   Define the following:
90 
91   \\(g_i=\\lceil f_i\\rceil\\)
92 
93   \\(t=\\prod_i T_i\\)
94 
95   \\(s=\\prod_{i\\neq q} S_i\\)
96 
97   \\(S_q\\) can then be defined by \\(S_q=\\lfloor t/s\\rfloor\\).
98   The elements of the resulting tensor are defined as
99 
100   \\(P_{s_1,\\ldots,s_R}=T_{h_1,\\ldots,h_q,\\ldots,h_R}\\).
101 
102   The \\(h_i\\) (\\(i\\neq q\\)) are defined by \\(h_i=\\lfloor s_i/g_i\\rfloor\\).
103 
104   \\(h_q=S_q\\sum_{j\\neq q}^{q-1}G_j \\mathrm{mod}(s_j,g_j) + s_q\\), where
105   \\(G_j=\\prod_{i}^{j-1}g_i\\) (\\(G_0=1\\)).
106 
107 One drawback of this method is that whenever the output dimensions are slightly
108 less than integer multiples of the input dimensions, many of the tensor elements
109 are repeated in an inefficient way. This is resolved by specifying that all
110 desired dimensions are integer multiples of the input tensor.
111 
112 For example:
113 
114 ```prettyprint
115 `input` is [[ 0  1  2  3]
116             [ 4  5  6  7]
117             [ 8  9 10 11]]
118 
119 tf.periodic_resample(input, [6, None]) ==> [[ 0  1]
120                                             [ 2  3]
121                                             [ 4  5]
122                                             [ 6  7]
123                                             [ 8  9]
124                                             [10 11]]
125 ```
126 
127 values: The tensor of rank `R` to periodic_resample
128 shape: A 1-D tensor representing the desired shape of the output tensor.
129   Exactly one element of this tensor must have the value `None` which represents
130   that this dimension of `values` can be adjusted downward in order to
131   accommodate increases in other dimensions. The specified sizes of the
132   non-adjustable dimensions must by at least as large as in the `values` tensor.
133 output: Periodically resampled tensor that has dimensions specified as in
134   `shape` except that the dimension specified as `None` will be minimally
135   decreased as necessary.
136 
137 )doc");
138 
139 
140 REGISTER_OP("PeriodicResampleOpGrad")
141     .Attr("T: numbertype")
142     .Input("grad: T")
143     .Attr("original_shape: shape")
144     .Attr("desired_shape: shape")
145     .Output("grad_values: T")
__anon96b295990202(shape_inference::InferenceContext* c) 146     .SetShapeFn([](shape_inference::InferenceContext* c) {
147       tensorflow::TensorShape original_shape;
148       TF_RETURN_IF_ERROR(c->GetAttr("original_shape", &original_shape));
149       shape_inference::ShapeHandle s;
150       TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(original_shape, &s));
151       c->set_output(0, s);
152       return Status::OK();
153 });
154 
155 }  // namespace tensorflow
156