• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_SHIM_SHAPE_H_
16 #define TENSORFLOW_LITE_KERNELS_SHIM_SHAPE_H_
17 
#include <initializer_list>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
24 
25 namespace tflite {
26 namespace shim {
27 
28 // Shape of a tensor. When unset it means the rank is unknown. Individual dims
29 // can also be unknown.
30 class Shape {
31  public:
32   using ValueType = std::vector<int>;
33 
34   Shape() = default;
35   Shape(const Shape& o) = default;
36   Shape(Shape&& o) = default;
37   Shape& operator=(const Shape& o) = default;
38   Shape& operator=(Shape&& o) = default;
39 
40   // Ctors
Shape(const std::initializer_list<int> & o)41   Shape(const std::initializer_list<int>& o) : value_(o), has_value_(true) {}
42   template <typename... Args>
Shape(Args &&...args)43   explicit Shape(Args&&... args)  // forward ctor args to that of std::vector
44       : value_(std::forward<Args>(args)...), has_value_(true) {}
Shape(const absl::Span<int> value)45   explicit Shape(const absl::Span<int> value)
46       : value_(value.data(), value.data() + value.size()), has_value_(true) {}
47 
48   // Accessors
has_value()49   inline bool has_value() const { return has_value_; }
value()50   inline ValueType& value() { return value_; }
value()51   inline const ValueType& value() const { return value_; }
52   ValueType* operator->() { return &value_; }
53   const ValueType* operator->() const { return &value_; }
54   ValueType& operator*() { return value_; }
55   const ValueType& operator*() const { return value_; }
56   // Get the specified dimension if known
57   int Dim(const int idx) const;
58 
59   // Returns the rank of the shape
Rank()60   const int Rank() const { return has_value_ ? value_.size() : kUnknownRank; }
61 
62   // Whether all the dimensions of the shape are known
63   bool FullyDefined() const;
64 
65   // Pretty printer
66   std::string ToString() const;
67 
68   // Adds two dimension taking into account unknown dims.
69   static int AddDims(const int dim1, const int dim2);
70 
71   // Comparison
72 
73   // Strict equality of the shapes. Unknown dims or rank on one side will
74   // result in false
75   bool operator==(const Shape& rhs) const;
76   bool operator!=(const Shape& rhs) const;
77 
78   // Compatibility of the shapes. If there are two known and incompatible
79   // dimensions it returns false
80   bool Compatible(const Shape& rhs) const;
81 
82   // The value for unknown dimensions and rank. There are static_asserts to
83   // ensure this matches the one defined in ::tensorflow namespace
84   static constexpr int kUnknownDim = -1;
85   static constexpr int kUnknownRank = -1;
86 
87  private:
88   ValueType value_;
89   bool has_value_ = false;
90 };
91 using ShapeOr = absl::StatusOr<Shape>;
92 
93 }  // namespace shim
94 }  // namespace tflite
95 
96 #endif  // TENSORFLOW_LITE_KERNELS_SHIM_SHAPE_H_
97