• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 
16 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_
17 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_
18 #include <vector>
19 
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/platform/types.h"
22 
23 namespace tensorflow {
24 namespace tensorforest {
25 
26 // Returns the probability that the point falls to the left.
27 float LeftProbability(const Tensor& point, const Tensor& weight, float bias,
28                       int num_features);
29 
30 float LeftProbabilityK(const Tensor& point, std::vector<int32> feature_set,
31                        const Tensor& weight, float bias, int num_features,
32                        int k);
33 
34 // Returns a random set of num_features_to_pick features in the
35 // range [0, num_features).  Must return the same set of
36 // features for subsequent calls with the same tree_num, node_num, and
37 // random_seed.  This allows us to calculate feature sets between calls to ops
38 // without having to store their values.
39 void GetFeatureSet(int32 tree_num, int32 node_num, int32 random_seed,
40                    int32 num_features, int32 num_features_to_pick,
41                    std::vector<int32>* features);
42 
43 }  // namespace tensorforest
44 }  // namespace tensorflow
45 
46 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_
47