1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_ 17 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_ 18 19 #if (defined(__i386) || defined(_M_IX86) || defined(__x86_64__) || \ 20 defined(_M_X64)) 21 #define TFLITE_X86_PLATFORM 22 #endif 23 24 #include <memory> 25 26 #include "public/gemmlowp.h" 27 #include "ruy/context.h" // from @ruy 28 #include "tensorflow/lite/c/common.h" 29 #include "tensorflow/lite/external_cpu_backend_context.h" 30 31 namespace tflite { 32 33 class CpuBackendContext final : public TfLiteInternalBackendContext { 34 public: 35 static CpuBackendContext* GetFromContext(TfLiteContext* context); 36 37 CpuBackendContext(); 38 ~CpuBackendContext() override; 39 ruy_context()40 ruy::Context* ruy_context() const { return ruy_context_.get(); } 41 gemmlowp_context()42 gemmlowp::GemmContext* gemmlowp_context() const { 43 return gemmlowp_context_.get(); 44 } 45 46 // Sets the maximum-number-of-threads-to-use parameter, only as a means of 47 // passing around this information. 48 void SetMaxNumThreads(int max_num_threads) override; 49 max_num_threads()50 int max_num_threads() const { return max_num_threads_; } 51 52 void SetUseCaching(bool flag); 53 use_caching()54 bool use_caching() const { return use_caching_; } 55 ClearCaches()56 void ClearCaches() override { ruy_context_->ClearPrepackedCache(); } 57 58 bool HasAvxOrAbove(); 59 60 // Gemmlowp on x86 is a deprecated path but some clients may still use 61 // this path based on link time dependencies. 62 bool PreferGemmlowpOnX86(); 63 64 private: 65 // Copy the wrapper class for cpuinfo from Ruy. 66 class CpuInfo final { 67 public: CpuInfo()68 CpuInfo() {} 69 ~CpuInfo(); 70 71 // X86 features 72 bool Avx(); 73 bool Avx2Fma(); 74 bool Avx512(); 75 76 private: 77 enum class InitStatus { 78 kNotYetAttempted, 79 kInitialized, 80 kFailed, 81 }; 82 83 InitStatus init_status_ = InitStatus::kNotYetAttempted; 84 85 bool EnsureInitialized(); 86 InitStatus Initialize(); 87 CpuInfo(const CpuInfo&) = delete; 88 CpuInfo& operator=(const CpuInfo&) = delete; 89 }; 90 91 // To enable a smooth transition from the current direct usage 92 // of the underlying gemmlowp context to going through abstractions 93 // (see :cpu_backend_gemm), for now a CpuBackendContext always 94 // stores both a gemmlowp context and a ruy context. 95 // TODO(b/131416458): Once call sites all go through abstractions, 96 // elide what can be elided based on TFLITE_WITH_RUY. 97 const std::unique_ptr<ruy::Context> ruy_context_; 98 const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_; 99 CpuInfo cpuinfo_; 100 101 // The maximum of threads used for parallelizing TfLite ops. However, 102 // cpu_backend_threadpool::Execute creates as many threads as it's 103 // asked to, regardless of this. Typically a call site would query 104 // cpu_backend_context->max_num_threads() and used that to determine 105 // the number of tasks to create and to give to 106 // cpu_backend_threadpool::Execute. 107 // 108 // This value also gets propagated to back-ends, where it plays the same 109 // information-only role. 110 int max_num_threads_; 111 // For matrix muliplications with constants parameters (i.e. weights), we can 112 // sometimes provide speedups by caching the "prepacked" data, for some 113 // additional memory cost. This flag permits the user to route all 114 // CpuBackendGem operations to a library that permits such an optimization 115 // (currently the Ruy library only). 116 bool use_caching_; 117 118 CpuBackendContext(const CpuBackendContext&) = delete; 119 }; 120 121 } // namespace tflite 122 123 #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_ 124