/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

import ImageClassification
import MobileNetClassifier
import SwiftUI

/// Backend used to run the classification model.
///
/// The raw value is the display label; it is also what `@AppStorage("mode")`
/// persists, so changing a raw value affects previously saved settings.
enum Mode: String, CaseIterable {
  case xnnpack = "XNNPACK"
  case coreML = "Core ML"
  case mps = "MPS"
}

/// Runs image classification off the main thread and publishes the results.
///
/// A classifier is created lazily for the selected `Mode` and is recreated
/// whenever the mode changes. The `classifier` and `currentMode` properties
/// are only accessed on `queue`.
class ClassificationController: ObservableObject {
  @AppStorage("mode") var mode: Mode = .xnnpack
  @Published var classifications: [Classification] = []
  /// Duration of the last classification in milliseconds, or -1 if it failed.
  @Published var elapsedTime: TimeInterval = 0.0
  /// True while a classification request is in flight; new frames are dropped.
  @Published var isRunning = false

  private let queue = DispatchQueue(label: "org.pytorch.executorch.demo", qos: .userInitiated)
  private var classifier: ImageClassification?
  private var currentMode: Mode = .xnnpack

  /// Classifies `image` asynchronously and publishes the results on the main queue.
  ///
  /// If a request is already running, the frame is dropped. On any failure,
  /// whether a thrown error or a missing model/labels resource, `classifications`
  /// is set to `[]` and `elapsedTime` to -1.
  ///
  /// - Parameter image: The image to classify.
  ///
  /// NOTE(review): `isRunning` is read and written here without synchronization
  /// and is reset on the main queue, so this assumes callers invoke
  /// `classify` on the main thread. Confirm at the call sites.
  func classify(_ image: UIImage) {
    guard !isRunning else {
      print("Dropping frame")
      return
    }
    isRunning = true

    // Snapshot the user's selection here; the mode switch itself happens on
    // `queue` so `currentMode`/`classifier` are confined to that serial queue.
    let requestedMode = mode
    queue.async {
      var classifications: [Classification] = []
      var elapsedTime: TimeInterval = -1
      do {
        if self.currentMode != requestedMode {
          self.currentMode = requestedMode
          self.classifier = nil
        }
        if self.classifier == nil {
          self.classifier = try self.createClassifier(for: self.currentMode)
        }
        if let classifier = self.classifier {
          let startTime = CFAbsoluteTimeGetCurrent()
          classifications = try classifier.classify(image: image)
          // Converted from seconds to milliseconds for display.
          elapsedTime = (CFAbsoluteTimeGetCurrent() - startTime) * 1000
        } else {
          // createClassifier returns nil when a bundled resource is missing;
          // report it instead of publishing an empty result with a fake timing.
          print("Error classifying image: model or labels file for \(requestedMode.rawValue) not found in bundle")
        }
      } catch {
        print("Error classifying image: \(error)")
      }
      DispatchQueue.main.async {
        self.classifications = classifications
        self.elapsedTime = elapsedTime
        self.isRunning = false
      }
    }
  }

  /// Builds a classifier for `mode` from the bundled model and label files.
  ///
  /// - Parameter mode: Backend whose exported `.pte` model should be loaded.
  /// - Returns: The classifier, or `nil` if either the model or the
  ///   `imagenet_classes.txt` labels file is not present in the main bundle.
  /// - Throws: Any error thrown by `MobileNetClassifier`'s initializer.
  private func createClassifier(for mode: Mode) throws -> ImageClassification? {
    let modelFileName: String
    switch mode {
    case .coreML:
      modelFileName = "mv3_coreml_all"
    case .mps:
      modelFileName = "mv3_mps_float16"
    case .xnnpack:
      modelFileName = "mv3_xnnpack_fp32"
    }
    guard let modelFilePath = Bundle.main.path(forResource: modelFileName, ofType: "pte"),
          let labelsFilePath = Bundle.main.path(forResource: "imagenet_classes", ofType: "txt")
    else { return nil }
    return try MobileNetClassifier(modelFilePath: modelFilePath, labelsFilePath: labelsFilePath)
  }
}