• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// A class to manage slices of a tensor. You can "register" a set of slices for
// a tensor and then "query" whether we have data for a given slice.
18 
19 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_
20 #define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_
21 
22 #include <string>  // for string
23 #include <unordered_map>
24 #include <vector>
25 
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/tensor_slice.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/status.h"       // for Status
30 #include "tensorflow/core/lib/core/stringpiece.h"  // for StringPiece
31 #include "tensorflow/core/platform/types.h"        // for int64
32 
33 namespace tensorflow {
34 
35 namespace checkpoint {
36 
37 class TensorSliceSet {
38  public:
39   TensorSliceSet(const TensorShape& shape, DataType type);
40   virtual ~TensorSliceSet();
41 
shape()42   const TensorShape& shape() const { return shape_; }
type()43   const DataType type() const { return type_; }
44 
45   // Register a new slice for the tensor. The "tag" is an arbitrary string
46   // associated with the slice (in one application it denotes the name of the
47   // file that contains the slice); the "data" points to the data of the tensor
48   // slice (it can be a nullptr).
49   Status Register(const TensorSlice& slice, const string& tag);
50 
51   // Alternative way of querying about a new slice: instead of copying the
52   // data, it returns a list of meta data about the stored slices that will
53   // supply data for the slice.
54   bool QueryMeta(
55       const TensorSlice& slice,
56       std::vector<std::pair<tensorflow::TensorSlice, string>>* results) const;
57 
58   struct SliceInfo {
59     TensorSlice slice;
60     const string tag;
61     int64_t num_floats;
62   };
63 
64   // Returns the map from slice string to SliceInfo.
Slices()65   const std::unordered_map<string, SliceInfo>& Slices() const {
66     return slices_;
67   }
68 
69  private:
70   const TensorShape shape_;
71   const DataType type_;
72   // We maintain a mapping from the slice string to the slice information.
73   std::unordered_map<string, SliceInfo> slices_;
74 
75   // Minimal slice which contains all presented slices. Used for speeding up
76   // overlap check when slices are being added consequently.
77   TensorSlice slices_hull_;
78 };
79 
80 // Registers "slice" in the TensorSliceSet stored in "tensor_slices", under key
81 // "name".  Other arguments are used for validations.  Does not modify the map
82 // or its values on non-OK.
83 // REQUIRES: tensor_slices != nullptr
84 Status RegisterTensorSlice(
85     const string& name, const TensorShape& shape, DataType type,
86     const string& tag, const TensorSlice& slice,
87     std::unordered_map<string, TensorSliceSet*>* tensor_slices);
88 
89 }  // namespace checkpoint
90 
91 }  // namespace tensorflow
92 
93 #endif  // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_
94