• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 import ImageClassification
10 import MobileNetClassifier
11 import SwiftUI
12 
/// ExecuTorch backend used to run the classifier.
///
/// The raw value is the human-readable label (e.g. for a picker). Declaration
/// order defines `allCases` order, so it is kept stable.
enum Mode: String, CaseIterable {
  case xnnpack = "XNNPACK", coreML = "Core ML", mps = "MPS"
}
18 
/// Runs image classification off the calling thread using the backend chosen
/// by `mode`, and publishes the results for SwiftUI views.
///
/// Only one classification runs at a time; frames submitted while one is in
/// flight are dropped.
class ClassificationController: ObservableObject {
  /// Selected backend, persisted via `@AppStorage` under the key "mode".
  @AppStorage("mode") var mode: Mode = .xnnpack
  /// Results of the most recently completed classification (empty on failure).
  @Published var classifications: [Classification] = []
  /// Duration of the last model invocation in milliseconds (note: not seconds,
  /// despite the `TimeInterval` type), or -1 if the last run failed.
  @Published var elapsedTime: TimeInterval = 0.0
  /// `true` while a classification is in flight.
  @Published var isRunning = false

  /// Raised when a resource needed to build the classifier is not bundled.
  private enum ClassifierError: Error, CustomStringConvertible {
    case missingResource(name: String, type: String)

    var description: String {
      switch self {
      case let .missingResource(name, type):
        return "missing bundled resource \(name).\(type)"
      }
    }
  }

  private let queue = DispatchQueue(label: "org.pytorch.executorch.demo", qos: .userInitiated)
  // Cached classifier for `currentMode`. It is only mutated in `classify`
  // while `isRunning` is false, or on `queue` while a run is in flight, so the
  // `isRunning` gate keeps those accesses from overlapping.
  private var classifier: ImageClassification?
  private var currentMode: Mode = .xnnpack

  /// Classifies `image` asynchronously and publishes the outcome on the main queue.
  ///
  /// If a previous call is still running, the frame is dropped. When `mode`
  /// has changed since the last call, the cached classifier is discarded and
  /// rebuilt lazily for the new backend. On any failure (including missing
  /// model or label files) the error is logged, `classifications` becomes
  /// empty and `elapsedTime` is set to -1.
  /// - Parameter image: The image to classify.
  func classify(_ image: UIImage) {
    guard !isRunning else {
      print("Dropping frame")
      return
    }
    isRunning = true

    if currentMode != mode {
      currentMode = mode
      classifier = nil
    }
    // Capture the mode here so the background work does not read `currentMode`.
    let requestedMode = currentMode
    queue.async {
      var classifications: [Classification] = []
      var elapsedTime: TimeInterval = -1
      do {
        // Reuse the cached classifier, or build (and cache) one for this mode.
        let classifier = try self.classifier ?? self.createClassifier(for: requestedMode)
        self.classifier = classifier
        let startTime = CFAbsoluteTimeGetCurrent()
        classifications = try classifier.classify(image: image)
        // CFAbsoluteTime is in seconds; report milliseconds.
        elapsedTime = (CFAbsoluteTimeGetCurrent() - startTime) * 1000
      } catch {
        print("Error classifying image: \(error)")
      }
      DispatchQueue.main.async {
        self.classifications = classifications
        self.elapsedTime = elapsedTime
        self.isRunning = false
      }
    }
  }

  /// Builds a `MobileNetClassifier` for the given backend from bundled resources.
  /// - Parameter mode: Backend whose exported `.pte` program should be loaded.
  /// - Returns: A ready-to-use classifier.
  /// - Throws: `ClassifierError.missingResource` if the model or
  ///   `imagenet_classes.txt` is not in the main bundle, or any error thrown by
  ///   `MobileNetClassifier`'s initializer.
  private func createClassifier(for mode: Mode) throws -> ImageClassification {
    let modelFileName: String
    switch mode {
    case .coreML:
      modelFileName = "mv3_coreml_all"
    case .mps:
      modelFileName = "mv3_mps_float16"
    case .xnnpack:
      modelFileName = "mv3_xnnpack_fp32"
    }
    guard let modelFilePath = Bundle.main.path(forResource: modelFileName, ofType: "pte") else {
      throw ClassifierError.missingResource(name: modelFileName, type: "pte")
    }
    guard let labelsFilePath = Bundle.main.path(forResource: "imagenet_classes", ofType: "txt") else {
      throw ClassifierError.missingResource(name: "imagenet_classes", type: "txt")
    }
    return try MobileNetClassifier(modelFilePath: modelFilePath, labelsFilePath: labelsFilePath)
  }
}
77