• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_H_
16 #define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_H_
17 
#include <sys/types.h>

#include <cstdint>
#include <cstring>
#include <limits>

#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/cc/experimental/libtf/object.h"
#include "tensorflow/cc/experimental/libtf/value.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
30 
31 namespace tf {
32 namespace libtf {
33 namespace runtime {
34 
35 /// @brief A runtime object capable of loading modules and executing functions.
36 ///
37 /// It is the responsibility of the owner of the Runtime to keep it alive longer
38 /// than all imported modules.
39 class Runtime : public Object {
40  public:
41   // TODO(b/191264214): Remove need for AbstractContext
42   explicit Runtime(tensorflow::AbstractContext* ctx);
43   /// @brief Loads the module indicated by `name` and returns it.
44   ///
45   /// @param name The name of the module / file path to load
46   /// @return An `Object` representing the module, if successful.  Otherwise, a
47   /// non-ok `absl::Status`.
48   tensorflow::StatusOr<Object> Load(const String& name);
49   // TODO(b/186787000): Loading a module with identically-named functions as
50   // a previously loaded module results in undefined behavior. This
51   // functionality will be supported in the future.
52 
53   // Create a host tensor and copy data into it.
54   //
55   // Raises an error if shape or dtype are incompatible with T.
56   // TODO(b/189458441): Update this when we decide on the representation of
57   // shape and dtype in this API.
58   // Disclaimer: This API is subject to change as we add support for creating
59   // device tensors b/187222691 and enable buffer re-use b/187223179.
60   // TODO(b/190715501): Make this available via a soft API as well.
61   template <class T>
62   tensorflow::StatusOr<Tensor> CreateHostTensor(absl::Span<const int64_t> shape,
63                                                 int dtype,
64                                                 absl::Span<const T> data);
65 };
66 
67 template <class T>
CreateHostTensor(absl::Span<const int64_t> shape,int dtype,absl::Span<const T> data)68 tensorflow::StatusOr<Tensor> Runtime::CreateHostTensor(
69     absl::Span<const int64_t> shape, int dtype, absl::Span<const T> data) {
70   size_t num_elements = 1;
71   for (int dim = 0; dim < shape.size(); dim++) {
72     if (shape[dim] < 0) {
73       return tensorflow::errors::InvalidArgument(absl::StrCat(
74           "Shape must be fully-defined, got: shape[", dim, "] = ", shape[dim]));
75     }
76     num_elements *= shape[dim];
77   }
78   if (data.size() != num_elements) {
79     return tensorflow::errors::InvalidArgument(absl::StrCat(
80         "Mismatched shape and data size: \n", "Shape num_elements: ",
81         num_elements, "\n", "Data size: ", data.size(), "\n"));
82   }
83   auto maybe_capsule = Get<internal::Capsule>(String("ctx"));
84   if (!maybe_capsule.status().ok()) {
85     return maybe_capsule.status();
86   }
87   auto capsule = maybe_capsule.ValueOrDie();
88   auto ctx = capsule.cast<tensorflow::ImmediateExecutionContext*>();
89   tensorflow::AbstractTensorPtr t(
90       ctx->CreateTensor(static_cast<tensorflow::DataType>(dtype), shape));
91   // TODO(srbs): This is still a weak check. Check that dtype and T are
92   // compatible.
93   if (t->ByteSize() != sizeof(T) * data.size()) {
94     return tensorflow::errors::InvalidArgument(absl::StrCat(
95         "Invalid number of bytes in data buffer\n", "Expected bytes: ",
96         t->ByteSize(), "\n", "Actual bytes: ", sizeof(T) * data.size()));
97   }
98   memcpy(t->Data(), data.data(), t->ByteSize());
99   return Tensor(Convert(TaggedValue(
100       impl::TaggedValueTensor(ctx->CreateLocalHandle(t.get()), false))));
101 }
102 
103 }  // namespace runtime
104 }  // namespace libtf
105 }  // namespace tf
106 
107 #endif  // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_H_
108