• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include <stdlib.h>
16 #include <time.h>
17 #include <algorithm>
18 #include <cmath>
19 #include <memory>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h"
26 #include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
27 #include "tensorflow/core/framework/op.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/lib/gtl/top_n.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/work_sharder.h"
33 
34 namespace tensorflow {
35 
36 using tensorforest::LeftProbabilityK;
37 
38 REGISTER_OP("KFeatureGradient")
39     .Attr("layer_num: int")
40     .Attr("random_seed: int")
41     .Input("input_data: float")
42     .Input("tree_parameters: float")
43     .Input("tree_biases: float")
44     .Input("routes: float")
45     .Output("routing_gradient: float")
46     .Output("data_gradient: float")
47     .Output("weight_gradient: float")
48     .Doc(R"doc(
49     Computes the derivative of the routing loss with respect to each decision
50     node.  Each decision node is constrained to make a decision based on only
51     k features.
52 
53     layer_num: The layer number of this tree.
54     random_seed: The base random seed.
55 
56     input_data: The training batch's features as a 2-d tensor;
57      `input_data[i][j]` gives the j-th feature of the i-th input.
58     tree_parameters: `tree_parameters[i]` gives the weight of
59      the logistic regression model that translates from node features to
60      probabilities.
61     tree_biases: `tree_biases[i]` gives the bias of the logistic
62      regression model that translates from node features to
63      probabilities.
64     routes: The routes computed by routing_function_op.
65 
66     routing_gradient: `routing_gradient` provides du / df, where u is the
67      routing function and f is the (vector of) decision functions.  A decision
68      function f_i computes the routing decision at node i.
69 
70     data_gradient: `data_gradient` provides df / dx, where f is the (vector
71      of) decision functions and x is a batch of data.
72 
73     weights_gradient: `weights_gradient` provides df / dw, where f is the
74      (vector of) decision functions and w is the matrix of parameters that
75      determine how instances are routed through a tree.
76 
77     f_i, the decision function at node i, is parameterized by t_i (parameters)
78     and b_i (bias) and takes data x as input.  This op is called in
79     training_ops.py to compute du / df, and we use that to compute
80 
81     du / dx = du / df * df / dx,
82     du / dt = du / df * df / dt, and
83     du / db = du / df * df / db.
84 )doc");
85 
86 class KFeatureGradient : public OpKernel {
87  public:
KFeatureGradient(OpKernelConstruction * context)88   explicit KFeatureGradient(OpKernelConstruction* context) : OpKernel(context) {
89     OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_));
90     OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_));
91   }
92 
Compute(OpKernelContext * context)93   void Compute(OpKernelContext* context) override {
94     // Gather input.
95     const Tensor& input_data_tensor = context->input(0);
96     const Tensor& tree_parameters_tensor = context->input(1);
97     const Tensor& tree_biases_tensor = context->input(2);
98     const Tensor& routing_tensor = context->input(3);
99 
100     // Extract dimensions from input tensors.
101     const int32 num_data =
102         static_cast<int32>(input_data_tensor.shape().dim_size(0));
103     const int32 num_features =
104         static_cast<int32>(input_data_tensor.shape().dim_size(1));
105     const int32 num_nodes =
106         static_cast<int32>(tree_parameters_tensor.shape().dim_size(0));
107     const int32 num_features_per_node =
108         static_cast<int32>(tree_parameters_tensor.shape().dim_size(1));
109 
110     // Construct output tensors.
111     Tensor* out_routes = nullptr;
112     TensorShape out_routes_shape;
113     out_routes_shape.AddDim(num_data);
114     out_routes_shape.AddDim(num_nodes);
115 
116     Tensor* out_data = nullptr;
117     TensorShape out_data_shape;
118     out_data_shape.AddDim(num_nodes);
119     out_data_shape.AddDim(num_features);
120 
121     Tensor* out_weights = nullptr;
122     TensorShape out_weights_shape;
123     out_weights_shape.AddDim(num_data);
124     out_weights_shape.AddDim(num_nodes);
125     out_weights_shape.AddDim(num_features_per_node);
126 
127     OP_REQUIRES_OK(context,
128                    context->allocate_output(0, out_routes_shape, &out_routes));
129     OP_REQUIRES_OK(context,
130                    context->allocate_output(1, out_data_shape, &out_data));
131     OP_REQUIRES_OK(
132         context, context->allocate_output(2, out_weights_shape, &out_weights));
133 
134     tensorforest::Initialize(*out_data, 0.0f);
135 
136     // Compute output.
137     const auto input_data = input_data_tensor.tensor<float, 2>();
138     const auto tree_parameters = tree_parameters_tensor.tensor<float, 2>();
139     const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
140     const auto routes = routing_tensor.tensor<float, 2>();
141 
142     auto routes_grad = out_routes->tensor<float, 2>();
143     auto data_grad = out_data->tensor<float, 2>();
144     auto weights_grad = out_weights->tensor<float, 3>();
145 
146     std::vector<int32> feature_set;
147     for (int i = 0; i < num_data; i++) {
148       const Tensor point = input_data_tensor.Slice(i, i + 1);
149       feature_set.clear();
150 
151       // Traverse the tree from the bottom up.
152       for (int j = num_nodes - 1; j >= 0; j--) {
153         tensorforest::GetFeatureSet(layer_num_, j, random_seed_, num_features,
154                                     num_features_per_node, &feature_set);
155 
156         // Compute routing gradient.
157         // j is a leaf node.
158         if (j >= num_nodes / 2) {
159           routes_grad(i, j) = routes(i, j);
160         } else {  // j is not a leaf node
161           int32 left_child = 2 * j + 1;
162           int32 right_child = left_child + 1;
163 
164           float left_prob = LeftProbabilityK(
165               point, feature_set, tree_parameters_tensor.Slice(j, j + 1),
166               tree_biases(j), num_features, num_features_per_node);
167 
168           float right_prob = 1.0f - left_prob;
169 
170           routes_grad(i, j) = (right_prob * routes(i, left_child) +
171                                left_prob * routes(i, right_child));
172         }
173         // Compute data and weight gradient.
174         for (int k = 0; k < num_features_per_node; k++) {
175           CHECK_LT(feature_set[k], num_features);
176           data_grad(j, feature_set[k]) = tree_parameters(j, k);
177           weights_grad(i, j, k) = input_data(i, feature_set[k]);
178         }
179       }
180     }
181   }
182 
183  private:
184   int32 layer_num_;
185   int32 random_seed_;
186 };
187 
188 REGISTER_KERNEL_BUILDER(Name("KFeatureGradient").Device(DEVICE_CPU),
189                         KFeatureGradient);
190 }  // namespace tensorflow
191