/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_CORE_SUBGRAPH_H_
#define TENSORFLOW_LITE_CORE_SUBGRAPH_H_

#include <stdarg.h>
#include <stddef.h>

#include <cstdint>
#include <cstdlib>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/macros.h"
#include "tensorflow/lite/experimental/resource/initialization_status.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/graph_info.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/util.h"

namespace tflite {

class SingleOpModel;  // Class for friend declarations.

namespace delegates {
namespace test_utils {
class TestDelegate;  // Class for friend declarations.
}  // namespace test_utils
}  // namespace delegates

class Subgraph {
 public:
  friend class Interpreter;
  friend class SingleOpModel;

  Subgraph(ErrorReporter* error_reporter,
           TfLiteExternalContext** external_contexts,
           std::vector<std::unique_ptr<Subgraph>>* subgraphs,
           resource::ResourceMap* resources,
           resource::ResourceIDMap* resource_ids,
           resource::InitializationStatusMap* initialization_status_map);

  // Subgraphs should be movable but not copyable.
  Subgraph(const Subgraph&) = delete;
  Subgraph(Subgraph&&) = default;
  Subgraph& operator=(const Subgraph&) = delete;
  virtual ~Subgraph();

  // Provide a list of tensor indexes that are inputs to the model.
  // Each index is bounds-checked; an out-of-bounds index clears the
  // consistent_ flag of the interpreter.
  TfLiteStatus SetInputs(std::vector<int> inputs);

  // Provide a list of tensor indexes that are outputs to the model.
  // Each index is bounds-checked; an out-of-bounds index clears the
  // consistent_ flag of the interpreter.
  TfLiteStatus SetOutputs(std::vector<int> outputs);

  // Provide a list of tensor indexes that are variable tensors.
  // Each index is bounds-checked; an out-of-bounds index clears the
  // consistent_ flag of the interpreter.
  TfLiteStatus SetVariables(std::vector<int> variables);

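  // A minimal sketch of wiring up I/O (illustrative only, not part of the
  // API contract; assumes `subgraph` points to a constructed Subgraph and
  // that tensors 0..2 exist, e.g. via the AddTensors() call below):
  //
  //   int first_new = 0;
  //   subgraph->AddTensors(3, &first_new);  // tensors 0, 1, 2
  //   subgraph->SetInputs({0});             // tensor 0 feeds the graph
  //   subgraph->SetOutputs({2});            // tensor 2 is produced by it
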
  // Adds a node with the given parameters and returns the index of the new
  // node in `node_index` (optionally). The Interpreter will take ownership of
  // `builtin_data` and destroy it with `free`. Ownership of `init_data`
  // remains with the caller.
  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
                                     const std::vector<int>& outputs,
                                     const std::vector<int>& intermediates,
                                     const char* init_data,
                                     size_t init_data_size, void* builtin_data,
                                     const TfLiteRegistration* registration,
                                     int* node_index = nullptr);

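  // A sketch of the ownership contract above (illustrative only; `reg` is
  // assumed to be a valid TfLiteRegistration owned by the caller, and
  // TfLiteAddParams comes from tensorflow/lite/c/builtin_op_data.h, which
  // this header does not include):
  //
  //   auto* params = static_cast<TfLiteAddParams*>(
  //       malloc(sizeof(TfLiteAddParams)));  // freed later by the subgraph
  //   params->activation = kTfLiteActNone;
  //   int node_index = -1;
  //   subgraph->AddNodeWithParameters({0, 1}, {2}, /*intermediates=*/{},
  //                                   /*init_data=*/nullptr,
  //                                   /*init_data_size=*/0, params, &reg,
  //                                   &node_index);
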
  // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  // The value pointed to by `first_new_tensor_index` will be set to the
  // index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

  // Set description of tensor `tensor_index` as a read-only tensor.
  // This variant assumes an external buffer of `bytes` bytes has been
  // allocated. The lifetime of `buffer` must be greater than or equal to
  // that of the Interpreter. Ownership of `quantization` is passed to the
  // subgraph.
  inline TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantization quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr,
      TfLiteSparsity* sparsity = nullptr) {
    return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
                                       dims.data(), quantization, buffer, bytes,
                                       allocation, sparsity);
  }
  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantization quantization, const char* buffer,
      size_t bytes, const Allocation* allocation = nullptr,
      TfLiteSparsity* sparsity = nullptr);

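  // A sketch of the read-only variant (illustrative only; `kWeights` is a
  // hypothetical static buffer that outlives the Interpreter, as the comment
  // above requires):
  //
  //   static const float kWeights[4] = {1.f, 2.f, 3.f, 4.f};
  //   TfLiteQuantization quant = {};  // kTfLiteNoQuantization
  //   subgraph->SetTensorParametersReadOnly(
  //       /*tensor_index=*/1, kTfLiteFloat32, "weights", {4}, quant,
  //       reinterpret_cast<const char*>(kWeights), sizeof(kWeights));
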
  // Set description of tensor `tensor_index` as a read/write tensor whose
  // memory is managed by the runtime. Ownership of `quantization` is passed
  // to the subgraph. An optional `dims_signature` carries the shape with
  // unknown dimensions recorded as -1.
  inline TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantization quantization,
      bool is_variable = false, const std::vector<int>& dims_signature = {}) {
    if (dims_signature.empty()) {
      return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
                                          dims.data(), quantization,
                                          is_variable);
    }
    return SetTensorParametersReadWrite(
        tensor_index, type, name, dims.size(), dims.data(), quantization,
        is_variable, dims_signature.size(), dims_signature.data());
  }
  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantization quantization,
      bool is_variable = false, const size_t rank_dims_signature = 0,
      const int* dims_signature = nullptr);

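  // A sketch of the read/write variant (illustrative only; the index and
  // shapes are made up for the example):
  //
  //   TfLiteQuantization quant = {};  // kTfLiteNoQuantization
  //   // Concrete shape {1, 4} now; the -1 in the signature marks the batch
  //   // dimension as unknown, so it can later be resized (see
  //   // ResizeInputTensorStrict below).
  //   subgraph->SetTensorParametersReadWrite(
  //       /*tensor_index=*/0, kTfLiteFloat32, "input", {1, 4}, quant,
  //       /*is_variable=*/false, /*dims_signature=*/{-1, 4});
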
  // Get a mutable tensor data structure.
  TfLiteTensor* tensor(int tensor_index) {
    if (tensor_index < 0 ||
        static_cast<size_t>(tensor_index) >= context_.tensors_size) {
      return nullptr;
    }
    return &context_.tensors[tensor_index];
  }

  // Get an immutable tensor data structure.
  const TfLiteTensor* tensor(int tensor_index) const {
    if (tensor_index < 0 ||
        static_cast<size_t>(tensor_index) >= context_.tensors_size) {
      return nullptr;
    }
    return &context_.tensors[tensor_index];
  }

  // Mutable access to the list of inputs.
  std::vector<int>& inputs() { return inputs_; }

  // Read-only access to the list of inputs.
  const std::vector<int>& inputs() const { return inputs_; }

  // Mutable access to the list of outputs.
  std::vector<int>& outputs() { return outputs_; }

  // Read-only access to the list of outputs.
  const std::vector<int>& outputs() const { return outputs_; }

  // Mutable access to the list of variable tensors.
  std::vector<int>& variables() { return variables_; }

  // Read-only access to the list of variable tensors.
  const std::vector<int>& variables() const { return variables_; }

  // WARNING: Experimental interface, subject to change.
  // TODO(ycling): Move this function to an external context interface.
  resource::ResourceMap& resources() { return *resources_; }

  // WARNING: Experimental interface, subject to change.
  // TODO(b/149099381): Move this function to an external context interface.
  resource::ResourceIDMap& resource_ids() { return *resource_ids_; }

  // WARNING: Experimental interface, subject to change.
  // TODO(b/149099381): Move this function to an external context interface.
  resource::InitializationStatusMap& initialization_status_map() {
    return *initialization_status_map_;
  }

  size_t tensors_size() const { return tensors_.size(); }

  // Return the number of ops in the model.
  size_t nodes_size() const { return nodes_and_registration_.size(); }

  // Return the vector of node indices in the order of execution.
  std::vector<int>& execution_plan() { return execution_plan_; }

  // Return a read-only vector of node indices in the order of execution.
  const std::vector<int>& execution_plan() const { return execution_plan_; }

  const std::vector<std::pair<TfLiteNode, TfLiteRegistration>>&
  nodes_and_registration() const {
    return nodes_and_registration_;
  }

  // Get a pointer to an operation and registration data structure if in
  // bounds.
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int node_index) const {
    if (node_index < 0 || static_cast<size_t>(node_index) >= nodes_size())
      return nullptr;
    return &nodes_and_registration_[node_index];
  }

  // Change the dimensionality of a given tensor. Note, this is only
  // acceptable for tensor indices that are inputs.
  // Returns status of failure or success.
  // TODO(aselle): Consider implementing ArraySlice equivalent to make this
  //   more adept at accepting data without an extra copy. Use absl::ArraySlice
  //   if our partners determine that dependency is acceptable.
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);

  // WARNING: Experimental interface, subject to change.
  // Change the dimensionality of a given tensor. This is only acceptable for
  // tensor indices that are inputs or variables. Only unknown dimensions can
  // be resized with this function. Unknown dimensions are indicated as `-1`
  // in the `dims_signature` attribute of a `TfLiteTensor`. Returns status of
  // failure or success.
  TfLiteStatus ResizeInputTensorStrict(int tensor_index,
                                       const std::vector<int>& dims);

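  // A sketch of the two resize variants (illustrative only; assumes tensor 0
  // was registered with dims_signature {-1, 4} as in the sketch above):
  //
  //   subgraph->ResizeInputTensor(0, {8, 4});        // allowed for any input
  //   subgraph->ResizeInputTensorStrict(0, {8, 4});  // OK: only dim 0 is -1
  //   // subgraph->ResizeInputTensorStrict(0, {8, 5});  // fails: dim 1 fixed
  //   subgraph->AllocateTensors();  // required before the next Invoke()
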
  // This releases memory held by non-persistent tensors. It does NOT
  // re-perform memory planning. AllocateTensors needs to be called before the
  // next invocation.
  TfLiteStatus ReleaseNonPersistentMemory();

  // Update allocations for all tensors. This will redim dependent tensors
  // using the input tensor dimensionality as given. This is relatively
  // expensive. If you know that your sizes are not changing, you need not
  // call this. Returns status of success or failure.
  TfLiteStatus AllocateTensors();

  // Invoke the subgraph (run the whole graph in dependency order).
  //
  // NOTE: It is possible that the interpreter is not in a ready state to
  // evaluate (e.g. if a ResizeTensor() has been performed without a
  // subsequent AllocateTensors()).
  // Returns status of success or failure.
  TfLiteStatus Invoke();

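  // Putting the pieces together, a minimal end-to-end run might look like
  // this (illustrative only; assumes the tensors and node from the sketches
  // above were registered successfully):
  //
  //   if (subgraph->AllocateTensors() != kTfLiteOk) return;
  //   float* in = subgraph->tensor(subgraph->inputs()[0])->data.f;
  //   for (int i = 0; i < 4; ++i) in[i] = 1.0f;
  //   if (subgraph->Invoke() != kTfLiteOk) return;
  //   const float* out = subgraph->tensor(subgraph->outputs()[0])->data.f;
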
  // Entry point for C node plugin API to report an error.
  void ReportError(const char* format, ...);

  // Return the subgraph specific context.
  TfLiteContext* context() { return &context_; }

  // Set the value of an external context.
  void SetExternalContext(TfLiteExternalContextType type,
                          TfLiteExternalContext* ctx);

  // Get the half precision flag.
  // WARNING: This is an experimental API and subject to change.
  bool GetAllowFp16PrecisionForFp32() const {
    return context_.allow_fp32_relax_to_fp16;
  }

  // Sets the cancellation function pointer in order to cancel a request in
  // the middle of a call to Invoke(). The interpreter queries this function
  // during inference, between op invocations; when it returns true, the
  // interpreter will abort execution and return `kTfLiteError`. The `data`
  // parameter contains any data used by the cancellation function, and if
  // non-null, remains owned by the caller.
  // WARNING: This is an experimental API and subject to change.
  void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*));

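  // A sketch of cooperative cancellation (illustrative only; the
  // std::atomic<bool> flag is an assumption of the example, not part of this
  // API):
  //
  //   static std::atomic<bool> cancelled{false};
  //   subgraph->SetCancellationFunction(&cancelled, [](void* data) {
  //     return static_cast<std::atomic<bool>*>(data)->load();
  //   });
  //   // From another thread: cancelled = true; Invoke() then returns
  //   // kTfLiteError at the next op boundary.
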
  // Ensure the data in `tensor.data` is readable. If a delegate is used, it
  // might be necessary to copy the data from the delegate buffer to raw
  // memory.
  // WARNING: This is an experimental API and subject to change.
  // TODO(b/119495520): make this private when refactoring complete.
  TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
    TfLiteTensor* t = &tensors_[tensor_index];
    TF_LITE_ENSURE(&context_, t != nullptr);
    if (t->data_is_stale) {
      TF_LITE_ENSURE(&context_, t->delegate != nullptr);
      TF_LITE_ENSURE(&context_, t->buffer_handle != kTfLiteNullBufferHandle);
      TF_LITE_ENSURE(&context_, t->delegate->CopyFromBufferHandle != nullptr);
      TF_LITE_ENSURE_STATUS(t->delegate->CopyFromBufferHandle(
          &context_, t->delegate, t->buffer_handle, t));
      t->data_is_stale = false;
    }
    return kTfLiteOk;
  }

  // The default capacity of the `tensors_` vector.
  static constexpr int kTensorsReservedCapacity = 128;
  // The capacity headroom of the `tensors_` vector before calling ops'
  // `prepare` and `invoke` functions. In these functions, it's guaranteed
  // that allocating up to `kTensorsCapacityHeadroom` more tensors won't
  // invalidate pointers to existing tensors.
  static constexpr int kTensorsCapacityHeadroom = 16;

  // Reset all variable tensors to the default value.
  // If a variable tensor doesn't have a buffer, reset it to zero.
  // TODO(b/115961645): Implement - If a variable tensor has a buffer, reset
  // it to the value of the buffer.
  // WARNING: This is an experimental API and subject to change.
  TfLiteStatus ResetVariableTensors();

  void SetProfiler(Profiler* profiler, int associated_subgraph_idx) {
    if (!profiler) {
      profiler_.reset(nullptr);
      context_.profiler = nullptr;
    } else {
      profiler_.reset(
          new SubgraphAwareProfiler(profiler, associated_subgraph_idx));
      context_.profiler = profiler_.get();
    }
  }

  Profiler* GetProfiler() { return profiler_.get(); }

  // Returns a pointer to the vector of subgraphs.
  // WARNING: This is an experimental API and subject to change.
  std::vector<std::unique_ptr<Subgraph>>* GetSubgraphs() { return subgraphs_; }

  // True if any tensor in the graph still has a dynamic size after calling
  // `AllocateTensors`. Before `AllocateTensors` is called, this always
  // returns true.
  bool HasDynamicTensors() { return has_dynamic_tensors_; }

  // Assigns (or reassigns) a custom memory allocation for the given tensor.
  // `flags` is a bitmask; see TfLiteCustomAllocationFlags.
  // The runtime does NOT take ownership of the underlying memory.
  //
  // NOTE: The user needs to call AllocateTensors() after this. In case of
  // input resizing, buffers will be checked for the required data size during
  // AllocateTensors().
  //
  // Parameters should satisfy the following conditions:
  // 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent
  //    In general, this is true for I/O tensors & variable tensors.
  // 2. allocation->data has the appropriate permissions for runtime access
  //    (read-only for inputs, read-write for others), and outlives the
  //    Interpreter.
  // 3. allocation->bytes >= tensor->bytes.
  //    This condition is checked again if any tensors are resized.
  // 4. allocation->data should be aligned to kDefaultTensorAlignment
  //    defined in lite/util.h. (Currently 64 bytes.)
  //    This check is skipped if kTfLiteCustomAllocationFlagsSkipAlignCheck is
  //    set through `flags`.
  // TODO(b/182215910): Expand on this documentation in a g3doc.
  //
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus SetCustomAllocationForTensor(
      int tensor_index, const TfLiteCustomAllocation& allocation,
      int64_t flags = kTfLiteCustomAllocationFlagsNone);

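  // A sketch of providing a custom allocation (illustrative only; assumes
  // C++17 aligned allocation and that tensor sizes are already known; the
  // 64-byte alignment mirrors kDefaultTensorAlignment in lite/util.h):
  //
  //   int input_idx = subgraph->inputs()[0];
  //   TfLiteTensor* t = subgraph->tensor(input_idx);
  //   TfLiteCustomAllocation alloc;
  //   alloc.data = ::operator new(t->bytes, std::align_val_t{64});
  //   alloc.bytes = t->bytes;
  //   subgraph->SetCustomAllocationForTensor(input_idx, alloc);
  //   subgraph->AllocateTensors();  // validates size & alignment
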
  void SetName(const char* name);
  const std::string& GetName() const;

  // WARNING: This is an experimental API and subject to change.
  // Dumps debugging info via the underlying memory planner.
  // Note: to keep the binary size increase caused by this debug dump minimal
  // for the TfLite library, and to allow users to plug in their own memory
  // planner debugger, weak symbols are used. By default, no debugging info is
  // dumped. However, if the TfLite-provided lite:simple_memory_arena_debug_dump
  // (i.e. the target containing the strong definition) is linked into the
  // program, calling this function will output memory usage information about
  // tensors and ops.
  void DumpMemoryPlannerDebugInfo() const;

 private:
  friend class InterpreterBuilder;
  friend class TestDelegate;
  // SubgraphAwareProfiler wraps an actual TFLite profiler, such as a
  // BufferedProfiler instance, and takes care of event profiling/tracing in a
  // certain subgraph.
  class SubgraphAwareProfiler : public Profiler {
   public:
    // The constructor should be called with a non-null profiler argument.
    SubgraphAwareProfiler(Profiler* profiler, int64_t subgraph_index)
        : profiler_(profiler), subgraph_index_(subgraph_index) {}
    ~SubgraphAwareProfiler() override {}

    uint32_t BeginEvent(const char* tag, EventType event_type,
                        int64_t event_metadata1,
                        int64_t event_metadata2) override {
      if (!profiler_) return 0;
      return profiler_->BeginEvent(tag, event_type, event_metadata1,
                                   subgraph_index_);
    }

    void EndEvent(uint32_t event_handle) override {
      if (!profiler_) return;
      profiler_->EndEvent(event_handle);
    }

    void EndEvent(uint32_t event_handle, int64_t event_metadata1,
                  int64_t event_metadata2) override {
      if (!profiler_) return;
      profiler_->EndEvent(event_handle, event_metadata1, event_metadata2);
    }

    void AddEvent(const char* tag, EventType event_type, uint64_t start,
                  uint64_t end, int64_t event_metadata1,
                  int64_t event_metadata2) override {
      if (!profiler_) return;
      profiler_->AddEvent(tag, event_type, start, end, event_metadata1,
                          subgraph_index_);
    }

   private:
    // Does not own the memory.
    Profiler* const profiler_;
    const int64_t subgraph_index_;
  };

  // Ensure the internal node storage memory allocates at least `count`
  // spots for nodes. NOTE: this doesn't actually add operators. This is an
  // efficiency optimization that is subject to change.
  // Note: Only used during initialization.
  void ReserveNodes(int count);

  // Overrides the execution plan. This bounds-checks the indices sent in.
  // Note: Only used during initialization.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);

  // Prevent 'context_' from accessing functions that are only available to
  // delegated kernels.
  void SwitchToKernelContext();

  // Add delegate-only functions to 'context_'.
  void SwitchToDelegateContext();

  // Give 'op_reg' a chance to initialize itself using the contents of
  // 'buffer'.
  void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,
               size_t length) {
    if (op_reg.init == nullptr) return nullptr;
    return op_reg.init(&context_, buffer, length);
  }

  // Let 'op_reg' release any memory it might have allocated via 'OpInit'.
  void OpFree(const TfLiteRegistration& op_reg, void* buffer) {
    if (op_reg.free == nullptr) return;
    if (buffer) {
      op_reg.free(&context_, buffer);
    }
  }

  // Prepare the given 'node' for execution.
  TfLiteStatus OpPrepare(const TfLiteRegistration& op_reg, TfLiteNode* node);

  // Invoke the operator represented by 'node'.
  TfLiteStatus OpInvoke(const TfLiteRegistration& op_reg, TfLiteNode* node) {
    if (op_reg.invoke == nullptr) return kTfLiteError;
    return op_reg.invoke(&context_, node);
  }

  // Call OpPrepare() for as many ops as possible, allocating memory for their
  // tensors. If an op containing dynamic tensors is found, preparation will be
  // postponed until this function is called again. This allows the interpreter
  // to wait until Invoke() to resolve the sizes of dynamic tensors.
  TfLiteStatus PrepareOpsAndTensors();

  // Call OpPrepare() for all ops starting at 'first_execution_plan_index'.
  // Stop when a dynamic tensor is found or all ops have been prepared. Fill
  // 'last_execution_plan_index_prepared' with the id of the op containing
  // dynamic tensors, or the last one in the graph.
  TfLiteStatus PrepareOpsStartingAt(int first_execution_plan_index,
                                    const std::vector<int>& execution_plan,
                                    int* last_execution_plan_index_prepared);

  // Tensors needed by the interpreter. Use `AddTensors` to add more blank
  // tensor entries. Note, `tensors_.data()` needs to be synchronized to the
  // `context_` whenever this std::vector is reallocated. Currently this
  // only happens in `AddTensors()`.
  std::vector<TfLiteTensor> tensors_;

  // Check that an array of tensor indices is valid with respect to the Tensor
  // array.
  // NOTE: this sets consistent_ to false if indices are out of bounds.
  TfLiteStatus CheckTensorIndices(const char* label, const int* indices,
                                  int length);

  // Check that the input indices and the output indices don't overlap.
  // This is needed because the same tensor must not be used both as input and
  // output for an operator.
  // NOTE: this sets consistent_ to false if indices are out of bounds.
  TfLiteStatus CheckInputAndOutputForOverlap(const int* input_indices,
                                             int num_inputs,
                                             const int* output_indices,
                                             int num_outputs);

  // Compute the number of bytes required to represent a tensor with dimensions
  // specified by the array dims (of length dims_size). Returns the status code
  // and bytes.
  TfLiteStatus BytesRequired(TfLiteType type, const int* dims, size_t dims_size,
                             size_t* bytes);

  // Implementation of a tensor-resize request. If the given tensor is of
  // type kTfLiteDynamic it will also be allocated new memory.
  TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size);

  // Report a detailed error string (will be printed to stderr).
  // TODO(aselle): allow user of class to provide alternative destinations.
  void ReportErrorImpl(const char* format, va_list args);

  // Entry point for C node plugin API to request that a tensor be resized.
  static TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor,
                                   TfLiteIntArray* new_size);
  // Entry point for C node plugin API to report an error.
  static void ReportErrorC(TfLiteContext* context, const char* format, ...);

  // Entry point for C node plugin API to add new tensors.
  static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add,
                                 int* first_new_tensor_index);

  // WARNING: This is an experimental API and subject to change.
  // Entry point for the C API ReplaceNodeSubsetsWithDelegateKernels.
  static TfLiteStatus ReplaceNodeSubsetsWithDelegateKernels(
      TfLiteContext* context, TfLiteRegistration registration,
      const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);

  // Update the execution graph to replace some of the nodes with stub
  // nodes. Specifically any node index that has `nodes[index]==1` will be
  // slated for replacement with a delegate kernel specified by registration.
  // Ownership of 'nodes_to_replace' and 'delegate' remains with the caller.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus ReplaceNodeSubsetsWithDelegateKernels(
      TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
      TfLiteDelegate* delegate);

  // WARNING: This is an experimental interface that is subject to change.
  // Gets the internal pointer to a TensorFlow Lite node by node_index.
  TfLiteStatus GetNodeAndRegistration(int node_index, TfLiteNode** node,
                                      TfLiteRegistration** registration);

  // WARNING: This is an experimental interface that is subject to change.
  // Entry point for C node plugin API to get a node by index.
  static TfLiteStatus GetNodeAndRegistration(struct TfLiteContext*,
                                             int node_index, TfLiteNode** node,
                                             TfLiteRegistration** registration);

  // WARNING: This is an experimental interface that is subject to change.
  // Gets a TfLiteIntArray* representing the execution plan. The interpreter
  // owns this memory and it is only guaranteed to exist during the invocation
  // of the delegate prepare.
  TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan);

  // WARNING: This is an experimental interface that is subject to change.
  // Entry point for C node plugin API to get the execution plan.
  static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context,
                                       TfLiteIntArray** execution_plan);

  // WARNING: This is an experimental interface that is subject to change.
  // Provides a preview of post-delegation partitioning. Each
  // TfLiteDelegateParams in the referenced array corresponds to one instance
  // of the delegate kernel.
  // nodes_to_replace should point to a valid array. partition_params_array &
  // num_partitions should be non-null.
  // Memory allocated by this method is automatically released with another
  // call to PreviewDelegatePartitioning, or after TfLiteDelegate::Prepare is
  // done.
  TfLiteStatus PreviewDelegatePartitioning(
      const TfLiteIntArray* nodes_to_replace,
      TfLiteDelegateParams** partition_params_array, int* num_partitions);

  // WARNING: This is an experimental interface that is subject to change.
  // Entry point for C node plugin API to preview delegation partitioning.
  static TfLiteStatus PreviewDelegatePartitioning(
      struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
      TfLiteDelegateParams** partition_params_array, int* num_partitions);

  // Retrieves named metadata from the TFLite model. Returns kTfLiteOk if
  // metadata is successfully obtained.
  // See the Metadata table in the TFLite schema.
  TfLiteStatus GetModelMetadata(const char* name, const char** ptr,
                                size_t* bytes);

  // Entry point for C node plugin API to get model metadata based on name.
  static TfLiteStatus GetModelMetadata(const struct TfLiteContext* context,
                                       const char* name, const char** ptr,
                                       size_t* bytes);

  // Used to clear partitioning_preview_cache_, in case
  // PreviewDelegatePartitioning was called.
  void FreeDelegatePartitioningData();

  // Retrieve an existing external context by type.
  TfLiteExternalContext* GetExternalContext(TfLiteExternalContextType type);
  static TfLiteExternalContext* GetExternalContext(
      struct TfLiteContext* context, TfLiteExternalContextType type);

  // Set the value of an external context.
  static void SetExternalContext(struct TfLiteContext* context,
                                 TfLiteExternalContextType type,
                                 TfLiteExternalContext* ctx);

  // WARNING: This is an experimental API and subject to change.
  // Allow a delegate to look at the graph and modify it to handle parts of
  // the graph itself. After this is called, the graph may contain new nodes
  // that replace one or more nodes.
  // NOTE: If tensors were allocated prior to delegate application, they will
  // be reallocated if the graph was modified (i.e., the caller does *not* need
  // to explicitly call |AllocateTensors()| again). If tensors were
  // unallocated, they will remain unallocated after delegate application.
  // Returns one of the following status codes:
  // 1. kTfLiteOk: Delegation succeeded.
  // 2. kTfLiteDelegateError: Delegation failed due to an error *in the
  //    delegate*, or the delegate parameter was null. The Subgraph has been
  //    restored to its pre-delegation state.
  //    NOTE: This reverts all delegates previously applied to the Subgraph.
  // 3. kTfLiteApplicationError: Delegation failed to be applied due to
  //    incompatibility with the TfLite runtime, e.g., the model graph is
  //    already immutable when applying the delegate. However, the Subgraph is
  //    still in an invokable state.
  // 4. kTfLiteError: Unexpected/runtime failure.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);

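  // A sketch of how a friend class such as Interpreter might dispatch on the
  // status codes above (illustrative only; `delegate` is assumed valid):
  //
  //   switch (subgraph->ModifyGraphWithDelegate(delegate)) {
  //     case kTfLiteOk: break;             // delegated successfully
  //     case kTfLiteDelegateError: break;  // restored, all delegates undone
  //     case kTfLiteApplicationError: break;  // not applied, still invokable
  //     default: break;                    // kTfLiteError: unexpected failure
  //   }
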
  // This un-applies all delegates that have been applied so far, but retains
  // pointers to them. The old execution plan and nodes are restored.
  TfLiteStatus UndoAllDelegates();

  // This re-applies all delegates that were undone.
  // Does nothing if UndoAllDelegates wasn't previously called.
  TfLiteStatus RedoAllDelegates();

  // This removes all delegates.
  // The old execution plan and nodes are restored. The graph is invokable
  // afterwards.
  TfLiteStatus RemoveAllDelegates();

  // Returns true if the subgraph has delegates applied.
  bool HasDelegates();

  // Cleans up data reserved for the given node. Does not remove the {node,
  // registration} pair from nodes_and_registration_.
  void CleanupNode(int node_index);

  // Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra
  // capacity. Calling this function may invalidate existing pointers to
  // tensors. After calling this function, adding `kTensorsCapacityHeadroom`
  // more tensors won't invalidate pointers to existing tensors.
  void EnsureTensorsVectorCapacity();

  // Ensures the required memory is planned and allocated.
  TfLiteStatus EnsureMemoryAllocations();

  // Returns true if the cancellation function returns true.
  bool IsCancelled();

  // Enables preserving intermediates for debugging.
  TfLiteStatus PreserveAllTensorsExperimental();

  // Returns true if 'node' could have a side effect (e.g. a stateful op).
  // Note that any node that might update tensors other than the op's outputs
  // is considered to have a side effect. Control flow ops like 'If' and
  // 'While' are therefore considered to have side effects, because their
  // condition and body subgraphs can contain ops that have side effects.
  bool OpMightHaveSideEffect(const TfLiteNode* node,
                             const TfLiteRegistration* registration) const;

  // Returns a new GraphInfo object based on the current Subgraph.
  std::unique_ptr<GraphInfo> CreateGraphInfo();

  // Store a ptr to the model metadata owned by the Interpreter.
  // Since the lifetime of the Interpreter exceeds that of the Subgraph, the
  // metadata remains valid for the latter's lifetime.
  // Also sets relevant fields on context_ based on known metadata.
  TfLiteStatus SetMetadata(const std::map<std::string, std::string>* metadata);

  // The state of the Interpreter.
  enum State {
    // The interpreter isn't ready to be invoked.
    // `AllocateTensors` needs to be called to enter an invokable state.
    kStateUninvokable = 0,
    // The interpreter is ready to be invoked.
    kStateInvokable,
    // The interpreter is ready to be invoked, and the graph can't be further
    // modified. The interpreter enters this state when
    // `ModifyGraphWithDelegate` is called with a delegate that doesn't
    // support dynamic tensors.
    kStateInvokableAndImmutable,
  };
  State state_ = kStateUninvokable;

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  TfLiteContext context_ = {};

  // A pointer to the external contexts (kTfLiteMaxExternalContexts) array
  // that sits inside the associated TFLite interpreter instance.
  TfLiteExternalContext** external_contexts_;

  // Node inputs/outputs are stored in TfLiteNode, and TfLiteRegistration
  // stores function pointers to the actual implementation.
  // Nodes should appear in the order in which they are instantiated at
  // runtime. Delegated nodes are appended after all the original ones.
  std::vector<std::pair<TfLiteNode, TfLiteRegistration>>
      nodes_and_registration_;

  // Whether the model is consistent, that is to say, whether the inputs and
  // outputs of every node and the global inputs and outputs are valid
  // indexes into the tensor array.
  bool consistent_ = true;

  // Array of indices representing the tensors that are inputs to the
  // interpreter.
  std::vector<int> inputs_;

  // Array of indices representing the tensors that are outputs of the
  // interpreter.
  std::vector<int> outputs_;

  // Array of indices representing the tensors that are variable tensors.
  std::vector<int> variables_;

  // The error reporter that tflite forwards errors to.
  ErrorReporter* error_reporter_;

  // Index of the next node to prepare.
  // During Invoke(), the Interpreter will allocate input tensors first, which
  // are known to have fixed sizes. Then it will allocate outputs from nodes
  // for as long as possible. When it encounters a node that produces a
  // dynamically sized tensor, the Interpreter will stop allocating tensors,
  // record the id of the next node to allocate, and execute that node to
  // generate the output tensor before continuing to allocate successors.
  // This process repeats until all nodes are executed.
  // NOTE: this relies on the nodes being in topological order.
  int next_execution_plan_index_to_prepare_;

  // Only used in cases where a delegate supporting dynamic tensors is
  // applied. This helps prepare the original execution plan before the
  // post-delegation one, so that tensor shapes propagate.
  int next_original_execution_plan_index_to_prepare_;

  // This is similar to `next_execution_plan_index_to_prepare_`, but it tracks
  // which nodes' allocation is planned with the arena planner.
  //
  // This is a workaround for b/127354079. It shouldn't be necessary if
  // ArenaPlanner can "rewind" to a specific point.
  // TODO(b/127354079): Improve ArenaPlanner and remove this mechanism.
  int next_execution_plan_index_to_plan_allocation_;

  // WARNING: This is an experimental interface that is subject to change.
  // This is a list of node indices (to index into nodes_and_registration).
  // This represents a valid topological sort (dependency-ordered) execution
  // plan. In particular, it is valid for this ordering to contain only a
  // subset of the node indices.
  std::vector<int> execution_plan_;

  // This is a copy of the first execution_plan_ before any delegates were
  // applied. It is empty if no delegates were applied to this Subgraph.
  std::vector<int> pre_delegation_execution_plan_;

  // Contains a list of delegates applied by the user so far, in order.
  std::vector<TfLiteDelegate*> delegates_applied_;

  // Set to true if UndoAllDelegates was called, and to false during
  // RedoAllDelegates.
  bool delegates_undone_ = false;

  // In the future, we'd like a TfLiteIntArray compatible representation.
  // TODO(aselle): replace execution_plan_ with this.
  std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> plan_cache_;

  // Used by PreviewDelegatePartitioning.
  std::vector<TfLiteDelegateParams> partitioning_preview_cache_;

  std::unique_ptr<MemoryPlanner> memory_planner_;

  // Contains <tensor idx, custom allocation> pairs for all applicable
  // tensors.
  std::vector<std::pair<int, TfLiteCustomAllocation>> custom_allocations_;

  // Tracking bit for whether a tensor was resized in the course of an op
  // invocation. This is a useful hint to ensure that dynamic tensor outputs
  // trigger downstream reallocation after op invocation.
  bool tensor_resized_since_op_invoke_ = false;

  // Profiler for this interpreter instance.
  std::unique_ptr<SubgraphAwareProfiler> profiler_;

  // A pointer to a vector of subgraphs. The vector is owned by the
  // interpreter.
  std::vector<std::unique_ptr<Subgraph>>* subgraphs_ = nullptr;

  // True if not all tensors in the graph have static sizes after
  // `PrepareOpsStartingAt` has run (it is called by the public
  // `AllocateTensors` function).
  // The value is meaningless before `PrepareOpsStartingAt` is called.
  bool has_dynamic_tensors_ = true;

  // Reference to a cancellation function that can cancel a request in the
  // middle of a call to Invoke(). When this function returns true,
  // kTfLiteError is returned by Invoke().
  bool (*check_cancelled_func_)(void*) = nullptr;

  // Reference to data used by the cancellation function in
  // `check_cancelled_func_`.
  void* cancellation_data_ = nullptr;

  // A map of resources. Owned by the interpreter and shared by multiple
  // subgraphs.
  resource::ResourceMap* resources_ = nullptr;

  // A map of resource IDs. Owned by the interpreter and shared by multiple
  // subgraphs.
  resource::ResourceIDMap* resource_ids_ = nullptr;

  // A map of initialization statuses that indicate whether the
  // initialization subgraph invocation is done or not.
  resource::InitializationStatusMap* initialization_status_map_;

  // Name of the subgraph (analogous to a function name).
  std::string name_;

  // Whether the memory planner should be instantiated to retain
  // intermediates for debugging.
  bool preserve_all_tensors_ = false;

  // Model metadata owned by the Interpreter.
  const std::map<std::string, std::string>* metadata_ = nullptr;
};

}  // namespace tflite
#endif  // TENSORFLOW_LITE_CORE_SUBGRAPH_H_