Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions App/Sources/Coordinator/MainCoordinator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ extension MainCoordinator: MainCoordinatorDelegate {
Task {
let isModelDownloaded = await dependencyContainer.isWhisperModelDownloaded()
if isModelDownloaded {
dependencyContainer.preloadWhisperKit()
let viewModel = dependencyContainer.makeRecordingViewModel()
viewModel.coordinator = self
viewModel.alertCoordinator = self
Expand Down Expand Up @@ -210,14 +211,8 @@ extension MainCoordinator: ChaGokAlertCoordinatorDelegate {
// MARK: - DownloadWhisperCoordinatorDelegate

extension MainCoordinator: DownloadOnDeviceCoordinatorDelegate {
func dismissSheet(completion: Bool) {
if completion { // 모델 다운로드 완료 후
presenter.dismiss(animated: true) { [weak self] in
self?.dependencyContainer.preloadWhisperKit()
}
} else {
presenter.dismiss(animated: true)
}
func dismissSheet() {
presenter.dismiss(animated: true)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import Foundation
import HuggingFace
import MLXHuggingFace
import MLXLMCommon

/// Hugging Face Hub에서 모델 파일을 다운로드하기 위한 Downloader 구현체.
/// 컴파일 타임 매크로(#hubDownloader) 대신 사용되는 수동 구현체입니다.
public struct MLXHubDownloader: MLXLMCommon.Downloader {
private let upstream: HuggingFace.HubClient

public init(hubClient: HuggingFace.HubClient = HuggingFace.HubClient()) {
upstream = hubClient
}

public func download(
id: String,
revision: String?,
matching patterns: [String],
useLatest: Bool,
progressHandler: @Sendable @escaping (Foundation.Progress) -> Void
) async throws -> URL {
guard let repoID = HuggingFace.Repo.ID(rawValue: id) else {
throw HuggingFaceDownloaderError.invalidRepositoryID(id)
}
let revision = revision ?? "main"

return try await upstream.downloadSnapshot(
of: repoID,
revision: revision,
matching: patterns,
progressHandler: { @MainActor progress in
progressHandler(progress)
}
Comment thread
Kimyonhae marked this conversation as resolved.
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public actor MLXModelProvider: MLXModelDataSource {
let configuration = try matchModelConfiguration(model: model)
let path = try await resolve(
configuration: configuration,
from: #hubDownloader(),
from: MLXHubDownloader(),
useLatest: false,
progressHandler: progressHandler
)
Expand Down Expand Up @@ -121,7 +121,7 @@ public actor MLXModelProvider: MLXModelDataSource {
public nonisolated func loadModel() async throws(MLXModelDataSourceError) -> ModelContext {
do {
let from: URL = try await getDownloadPath()
let context = try await LLMModelFactory.shared.load(from: from, using: #huggingFaceTokenizerLoader())
let context = try await LLMModelFactory.shared.load(from: from, using: MLXTokenizerLoader())
AppLogger.info("MLX model loaded: \(context)")
return context
} catch is CancellationError {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import Foundation
import MLXLMCommon
import Tokenizers

/// Hugging Face 기반 토크나이저를 로드하기 위한 TokenizerLoader 구현체.
/// 컴파일 타임 매크로(#huggingFaceTokenizerLoader) 대신 사용되는 수동 구현체입니다.
public struct MLXTokenizerLoader: MLXLMCommon.TokenizerLoader {
public init() {}

public func load(from directory: URL) async throws -> any MLXLMCommon.Tokenizer {
let upstream = try await Tokenizers.AutoTokenizer.from(modelFolder: directory)
return TokenizerBridge(upstream)
}
}

private struct TokenizerBridge: MLXLMCommon.Tokenizer {
private let upstream: any Tokenizers.Tokenizer

init(_ upstream: any Tokenizers.Tokenizer) {
self.upstream = upstream
}

func encode(text: String, addSpecialTokens: Bool) -> [Int] {
upstream.encode(text: text, addSpecialTokens: addSpecialTokens)
}

func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String {
upstream.decode(tokens: tokenIds, skipSpecialTokens: skipSpecialTokens)
}

func convertTokenToId(_ token: String) -> Int? {
upstream.convertTokenToId(token)
}

func convertIdToToken(_ id: Int) -> String? {
upstream.convertIdToToken(id)
}

var bosToken: String? {
upstream.bosToken
}

var eosToken: String? {
upstream.eosToken
}

var unknownToken: String? {
upstream.unknownToken
}

func applyChatTemplate(
messages: [[String: any Sendable]],
tools: [[String: any Sendable]]?,
additionalContext: [String: any Sendable]?
) throws -> [Int] {
do {
return try upstream.applyChatTemplate(
messages: messages, tools: tools, additionalContext: additionalContext
)
} catch Tokenizers.TokenizerError.missingChatTemplate {
throw MLXLMCommon.TokenizerError.missingChatTemplate
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,6 @@ public actor WhisperKitProvider: WhisperDataSource {
progressCallback: progressHandler
)

// 다운로드 복귀 직후 태스크 취소 상태 감지 (레이스 컨디션 봉쇄)
if Task.isCancelled {
AppLogger.info("WhisperKit 다운로드 완료 복귀 후 취소 상태 감지 - 즉각 강제 소거 및 에러 방출")
try? storageService.delete(fileURL: path)
throw CancellationError()
}

modelDirectory = path
AppLogger.info("WhisperKit 모델 위치 : \(modelDirectory?.path() ?? "없음")")
}
Expand Down Expand Up @@ -143,14 +136,7 @@ public actor WhisperKitProvider: WhisperDataSource {
let relativePath = "huggingface/models/argmaxinc/whisperkit-coreml/\(recommendedModel)"
let defaultPath = storageService.absoluteURL(for: relativePath)

let configPath = "\(relativePath)/config.json"
let vocabPath = "\(relativePath)/vocab.json"

// 디렉토리 존재뿐만 아니라 핵심 구성 파일(config.json, vocab.json)의 완결성 검사를 수행하여 부분 다운로드 및 비정상 종료된 찌꺼기를 필터링합니다.
if storageService.exists(relativePath: relativePath),
storageService.exists(relativePath: configPath),
storageService.exists(relativePath: vocabPath)
{
if storageService.exists(relativePath: relativePath) {
modelDirectory = defaultPath
self.recommendedModel = recommendedModel
AppLogger.info("whisper 저장 위치 (디스크 감지) : \(defaultPath)")
Comment thread
Kimyonhae marked this conversation as resolved.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ public final class DefaultAvailableModelSupportRepository: AvailableModelSupport
/// 현재 사용자의 On-Device LLM 모두 fetch 합니다.
public func fetchSupportModels() async -> [ChaGokModelState] {
let models: [ChaGokModel] = ChaGokModel.models
var whisperStatus = OnDeviceStatus(storage: .notDownloaded, runtime: .unloaded)
var mlxStatus = OnDeviceStatus(storage: .notDownloaded, runtime: .unloaded)
var whisperStatus = OnDeviceStatus(storage: .notDownloaded)
var mlxStatus = OnDeviceStatus(storage: .notDownloaded)

do {
_ = try await whisperProvider.getDownloadPath()
whisperStatus = OnDeviceStatus(storage: .downloaded, runtime: .unloaded)
whisperStatus = OnDeviceStatus(storage: .downloaded)
} catch {
AppLogger.info("Whisper 모델 다운로드 경로 없음: \(error.localizedDescription)")
}

do {
_ = try await mlxProvider.getDownloadPath()
mlxStatus = OnDeviceStatus(storage: .downloaded, runtime: .unloaded)
mlxStatus = OnDeviceStatus(storage: .downloaded)
} catch {
AppLogger.info("MLX 모델 다운로드 경로 없음: \(error.localizedDescription)")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ public final class DefaultMlxOnDeviceRepository: OnDeviceRepository {
public func delete() async throws(DeleteOnDeviceRepositoryError) -> OnDeviceStatus {
do {
try await provider.delete()
return OnDeviceStatus(storage: .notDownloaded, runtime: .unloaded)
return OnDeviceStatus(storage: .notDownloaded)
} catch {
AppLogger.error(error)
switch error {
case .cancelled:
throw .cancelled
case .notFound, .downloadFailed, .deleteFailed:
return OnDeviceStatus(storage: .notDownloaded, runtime: .unloaded)
return OnDeviceStatus(storage: .notDownloaded)
case .networkFailed:
throw .deleteMLXFailed
case .unknown(let underlying):
Expand All @@ -67,9 +67,9 @@ public final class DefaultMlxOnDeviceRepository: OnDeviceRepository {
public func checkStatus() async -> OnDeviceStatus {
do {
_ = try await provider.getDownloadPath()
return OnDeviceStatus(storage: .downloaded, runtime: .unloaded)
return OnDeviceStatus(storage: .downloaded)
} catch {
return OnDeviceStatus(storage: .notDownloaded, runtime: .unloaded)
return OnDeviceStatus(storage: .notDownloaded)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public struct DefaultWhisperOnDeviceRepository: OnDeviceRepository {
public func delete() async throws(DeleteOnDeviceRepositoryError) -> OnDeviceStatus {
do {
try await provider.delete()
return OnDeviceStatus(storage: .notDownloaded, runtime: .unloaded)
return OnDeviceStatus(storage: .notDownloaded)
} catch {
AppLogger.error(error)
throw .deleteWhisperFailed
Expand All @@ -51,9 +51,9 @@ public struct DefaultWhisperOnDeviceRepository: OnDeviceRepository {
public func checkStatus() async -> OnDeviceStatus {
do {
_ = try await provider.getDownloadPath()
return OnDeviceStatus(storage: .downloaded, runtime: .unloaded)
return OnDeviceStatus(storage: .downloaded)
} catch {
return OnDeviceStatus(storage: .notDownloaded, runtime: .unloaded)
return OnDeviceStatus(storage: .notDownloaded)
}
}
}
2 changes: 1 addition & 1 deletion Domain/Sources/Entities/ChaGokModelSupport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public struct ChaGokModelState: Hashable, Sendable {
title: String,
subTitle: String,
model: ChaGokModel,
status: OnDeviceStatus = .init(storage: .notDownloaded, runtime: .unloaded)
status: OnDeviceStatus = .init(storage: .notDownloaded)
) {
self.title = title
self.subTitle = subTitle
Expand Down
16 changes: 1 addition & 15 deletions Domain/Sources/Entities/OnDeviceStatus.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@ import Foundation
public struct OnDeviceStatus: Hashable, Sendable {
/// 디스크 또는 캐시 상의 저장 상태입니다.
public var storage: StorageState
/// 메모리 상에서 모델이 준비된 상태입니다.
public var runtime: RuntimeState

public init(
storage: StorageState = .notDownloaded,
runtime: RuntimeState = .unloaded
storage: StorageState = .notDownloaded
) {
self.storage = storage
self.runtime = runtime
}

/// 다운로드 및 삭제처럼, 모델 파일의 보관 상태를 나타냅니다.
Expand All @@ -28,14 +24,4 @@ public struct OnDeviceStatus: Hashable, Sendable {
/// 저장 단계에서 실패한 상태입니다.
case failed
}

/// 로드처럼, 메모리 적재 여부를 나타냅니다.
public enum RuntimeState: Sendable, Hashable {
/// 메모리에 올려지지 않은 상태입니다.
case unloaded
/// 메모리 로드가 진행 중인 상태입니다.
case loading
/// 메모리에 적재된 상태입니다.
case loaded
}
}
Loading
Loading