Skip to content
73 changes: 56 additions & 17 deletions Sources/SwiftyXPC/XPCConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class XPCConnection: @unchecked Sendable {
/// An XPC message was missing its request and/or response body.
case missingMessageBody
/// Received an unhandled XPC message.
case unexpectedMessage
case unexpectedMessage(String)
/// A message contained data of the wrong type.
case typeMismatch(expected: XPCType, actual: XPCType)
/// Only used on macOS versions prior to 12.0.
Expand All @@ -30,6 +30,11 @@ public class XPCConnection: @unchecked Sendable {
static let error = "com.charlessoft.SwiftyXPC.XPCEventHandler.Error"
}

private struct StateMessages {
static let activate = "com.charlessoft.SwiftyXPC.XPCConnectionState.Activate"
static let deactivate = "com.charlessoft.SwiftyXPC.XPCConnectionState.Deactivate"
}

/// Represents the various types of connection that can be created.
public enum ConnectionType {
/// Connect to an embedded XPC service inside the current application’s bundle. Pass the XPC service’s bundle ID as the `bundleID` parameter.
Expand All @@ -41,7 +46,10 @@ public class XPCConnection: @unchecked Sendable {
}

/// A handler that will be called if a communication error occurs.
public typealias ErrorHandler = (XPCConnection, Swift.Error) -> Void
public typealias ErrorHandler = (XPCConnection, Swift.Error) async -> Void

/// A handler that will be called when the connection is cancelled.
public typealias CancelHandler = () async -> Void

internal class MessageHandler {
typealias RawHandler = ((XPCConnection, xpc_object_t) async throws -> xpc_object_t)
Expand Down Expand Up @@ -126,10 +134,21 @@ public class XPCConnection: @unchecked Sendable {
/// A handler that will be called if a communication error occurs.
public var errorHandler: ErrorHandler? = nil

internal var customEventHandler: xpc_handler_t? = nil
/// A handler that will be called when the connection is cancelled.
public var cancelHandler: CancelHandler? = nil

internal var customEventHandler: ((xpc_object_t) async -> Void)? = nil

internal func getMessageHandler(forName name: String) -> MessageHandler.RawHandler? {
self.messageHandlers[name]?.closure
switch name {
case StateMessages.activate: return { _, _ in try XPCEncoder().encode(XPCNull.shared) }
case StateMessages.deactivate:
return { [cancelHandler] _, _ in
await cancelHandler?()
return try XPCEncoder().encode(XPCNull.shared)
}
default: return self.messageHandlers[name]?.closure
}
}

/// Set a message handler for an incoming message, identified by the `name` parameter, without taking any arguments or returning any value.
Expand Down Expand Up @@ -232,6 +251,12 @@ public class XPCConnection: @unchecked Sendable {
/// Activate the connection.
///
/// Connections start in an inactive state, so you must call `activate()` on a connection before it will send or receive any messages.

public func activate() async throws {
xpc_connection_activate(self.connection)
try await sendMessage(name: StateMessages.activate)
}

public func activate() {
xpc_connection_activate(self.connection)
}
Expand Down Expand Up @@ -263,6 +288,11 @@ public class XPCConnection: @unchecked Sendable {
xpc_connection_cancel(self.connection)
}

public func cancel() async throws {
try await sendMessage(name: StateMessages.deactivate)
xpc_connection_cancel(self.connection)
}

internal func makeEndpoint() -> XPCEndpoint {
XPCEndpoint(connection: self.connection)
}
Expand Down Expand Up @@ -295,7 +325,7 @@ public class XPCConnection: @unchecked Sendable {
/// - Returns: The value returned by the receiving connection's helper function.
///
/// - Throws: Throws an error if the receiving connection throws an error in its handler, or if a communication error occurs.
public func sendMessage<Response: Codable>(name: String) async throws -> Response {
public func sendMessage<Response: Codable & Sendable>(name: String) async throws -> Response {
try await self.sendMessage(name: name, request: XPCNull.shared)
}

Expand All @@ -309,7 +339,7 @@ public class XPCConnection: @unchecked Sendable {
///
/// - Throws: Throws an error the `request` parameter does not match the type specified by the receiving connection’s handler function,
/// if the receiving connection throws an error in its handler, or if a communication error occurs.
public func sendMessage<Request: Codable, Response: Codable>(name: String, request: Request) async throws -> Response {
public func sendMessage<Response: Codable & Sendable>(name: String, request: some Codable) async throws -> Response {
let body = try XPCEncoder().encode(request)

return try await withCheckedThrowingContinuation { continuation in
Expand All @@ -334,11 +364,14 @@ public class XPCConnection: @unchecked Sendable {
throw Error.missingMessageBody
}

if Response.self == XPCNull.self {
continuation.resume(returning: XPCNull() as! Response)
} else {
continuation.resume(returning: try XPCDecoder().decode(type: Response.self, from: body))
}
let response: Response =
if Response.self == XPCNull.self {
XPCNull() as! Response
} else {
try XPCDecoder().decode(type: Response.self, from: body)
}

continuation.resume(returning: response)
} catch {
continuation.resume(throwing: error)
}
Expand Down Expand Up @@ -426,13 +459,17 @@ public class XPCConnection: @unchecked Sendable {
do {
try self.checkCallerCredentials(event: event)
} catch {
self.errorHandler?(self, error)
Task {
await self.errorHandler?(self, error)
}
return
}
}

if let customEventHandler = self.customEventHandler {
customEventHandler(event)
Task {
await customEventHandler(event)
}
return
}

Expand All @@ -450,7 +487,9 @@ public class XPCConnection: @unchecked Sendable {
throw Error.typeMismatch(expected: .dictionary, actual: event.type)
}
} catch {
self.errorHandler?(self, error)
Task {
await self.errorHandler?(self, error)
}
return
}
}
Expand Down Expand Up @@ -519,12 +558,12 @@ public class XPCConnection: @unchecked Sendable {
}

guard let _messageHandler = self.getMessageHandler(forName: name) else {
throw Error.unexpectedMessage
throw Error.unexpectedMessage(name)
}

messageHandler = _messageHandler
} catch {
self.errorHandler?(self, error)
await self.errorHandler?(self, error)
return
}

Expand All @@ -540,7 +579,7 @@ public class XPCConnection: @unchecked Sendable {
do {
try self.sendOnewayRawMessage(name: nil, body: response, key: MessageKeys.body, asReplyTo: event)
} catch {
self.errorHandler?(self, error)
await self.errorHandler?(self, error)
}
}
}
Expand Down
55 changes: 47 additions & 8 deletions Sources/SwiftyXPC/XPCErrorRegistry.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// Created by Charles Srstka on 12/19/21.
//

import Synchronization
import XPC

/// A registry which facilitates decoding error types that are sent over an XPC connection.
Expand Down Expand Up @@ -42,26 +43,60 @@ import XPC
/// } catch {
/// print("got some other error")
/// }
public class XPCErrorRegistry {
public final class XPCErrorRegistry: Sendable {
/// The shared `XPCErrorRegistry` instance.
public static let shared = XPCErrorRegistry()

private var errorDomainMap: [String: (Error & Codable).Type] = [
String(reflecting: XPCError.self): XPCError.self,
String(reflecting: XPCConnection.Error.self): XPCConnection.Error.self,
]
@available(macOS 15.0, macCatalyst 18.0, *)
private final class MutexWrapper: Sendable {
let mutex: Mutex<[String: (Error & Codable).Type]>
init(dict: [String: (Error & Codable).Type]) { self.mutex = Mutex(dict) }
}

private final class LegacyWrapper: @unchecked Sendable {
let sema = DispatchSemaphore(value: 1)
var dict: [String: (Error & Codable).Type]
init(dict: [String: (Error & Codable).Type]) { self.dict = dict }
}

private let errorDomainMapWrapper: any Sendable = {
let errorDomainMap: [String: (Error & Codable).Type] = [
String(reflecting: XPCError.self): XPCError.self,
String(reflecting: XPCConnection.Error.self): XPCConnection.Error.self,
]

if #available(macOS 15.0, macCatalyst 18.0, *) {
return MutexWrapper(dict: errorDomainMap)
} else {
return LegacyWrapper(dict: errorDomainMap)
}
}()

private func withLock<T>(closure: (inout [String: (Error & Codable).Type]) throws -> T) rethrows -> T {
if #available(macOS 15.0, macCatalyst 18.0, *) {
return try (self.errorDomainMapWrapper as! MutexWrapper).mutex.withLock { try closure(&$0) }
} else {
let wrapper = self.errorDomainMapWrapper as! LegacyWrapper
wrapper.sema.wait()
defer { wrapper.sema.signal() }

return try closure(&wrapper.dict)
}
}

/// Register an error type.
///
/// - Parameters:
/// - domain: An `NSError`-style domain string to associate with this error type. In most cases, you will just pass `nil` for this parameter, in which case the default value of `String(reflecting: errorType)` will be used instead.
/// - errorType: An error type to register. This type must conform to `Codable`.
public func registerDomain(_ domain: String? = nil, forErrorType errorType: (Error & Codable).Type) {
errorDomainMap[domain ?? String(reflecting: errorType)] = errorType
self.withLock { $0[domain ?? String(reflecting: errorType)] = errorType }
}

internal func encodeError(_ error: Error, domain: String? = nil) throws -> xpc_object_t {
try XPCEncoder().encode(BoxedError(error: error, domain: domain))
try self.withLock { _ in
try XPCEncoder().encode(BoxedError(error: error, domain: domain))
}
}

internal func decodeError(_ error: xpc_object_t) throws -> Error {
Expand All @@ -70,6 +105,10 @@ public class XPCErrorRegistry {
return boxedError.encodedError ?? boxedError
}

internal func errorType(forDomain domain: String) -> (any (Error & Codable).Type)? {
self.withLock { $0[domain] }
}

/// An error type representing errors for which we have an `NSError`-style domain and code, but do not know the exact error class.
///
/// To avoid requiring Foundation, this type does not formally adopt the `CustomNSError` protocol, but implements methods which
Expand Down Expand Up @@ -155,7 +194,7 @@ public class XPCErrorRegistry {
self.errorDomain = try container.decode(String.self, forKey: .domain)
let code = try container.decode(Int.self, forKey: .code)

if let codableType = XPCErrorRegistry.shared.errorDomainMap[self.errorDomain],
if let codableType = XPCErrorRegistry.shared.errorType(forDomain: self.errorDomain),
let codableError = try codableType.decodeIfPresent(from: container, key: .encodedError)
{
self.storage = .codable(codableError)
Expand Down
23 changes: 18 additions & 5 deletions Sources/SwiftyXPC/XPCListener.swift
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ public final class XPCListener {
}
}

/// A handler that will be called when a new connection activates.
public typealias ActivatedConnectionHandler = (XPCConnection) async -> Void
public var activatedConnectionHandler: ActivatedConnectionHandler? = nil

/// A handler that will be called when a connection cancels.
public typealias CanceledConnectionHandler = (XPCConnection) async -> Void
public var canceledConnectionHandler: CanceledConnectionHandler? = nil

/// Create a new `XPCListener`.
///
/// - Parameters:
Expand Down Expand Up @@ -219,10 +227,15 @@ public final class XPCListener {

newConnection.messageHandlers = self?.messageHandlers ?? [:]
newConnection.errorHandler = self?.errorHandler

newConnection.activate()
newConnection.cancelHandler = {
await self?.canceledConnectionHandler?(newConnection)
}
await self?.activatedConnectionHandler?(newConnection)
try await newConnection.activate()
} catch {
self?.errorHandler?(connection, error)
Task {
await self?.errorHandler?(connection, error)
}
}
}
}
Expand All @@ -232,12 +245,12 @@ public final class XPCListener {
///
/// After this call, any messages that have not yet been sent will be discarded, and the connection will be unwound.
/// If there are messages that are awaiting replies, they will receive the `XPCError.connectionInvalid` error.
public func cancel() {
public func cancel() async throws {
switch self.backing {
case .xpcMain:
fatalError("XPC service listener cannot be cancelled")
case .connection(let connection, _):
connection.cancel()
try await connection.cancel()
}
}

Expand Down
18 changes: 15 additions & 3 deletions Sources/TestHelper/TestHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,31 @@ import TestShared

@main
@available(macOS 13.0, *)
class XPCService {
final class XPCService: Sendable {
static func main() {
do {
let xpcService = XPCService()

let listener = try XPCListener(type: .machService(name: helperID), codeSigningRequirement: nil)

var connectionCount: Int = 0
listener.activatedConnectionHandler = { _ in
connectionCount += 1
}
listener.canceledConnectionHandler = { _ in
connectionCount -= 1
}
listener.setMessageHandler(name: CommandSet.reportIDs, handler: xpcService.reportIDs)
listener.setMessageHandler(name: CommandSet.capitalizeString, handler: xpcService.capitalizeString)
listener.setMessageHandler(name: CommandSet.multiplyBy5, handler: xpcService.multiplyBy5)
listener.setMessageHandler(name: CommandSet.transportData, handler: xpcService.transportData)
listener.setMessageHandler(name: CommandSet.tellAJoke, handler: xpcService.tellAJoke)
listener.setMessageHandler(name: CommandSet.pauseOneSecond, handler: xpcService.pauseOneSecond)
listener.setMessageHandler(
name: CommandSet.countConnections,
handler: { _ in
connectionCount
})

listener.activate()
dispatchMain()
Expand Down Expand Up @@ -59,7 +71,7 @@ class XPCService {
"Noonien Soong".data(using: .utf8)!,
"Arik Soong".data(using: .utf8)!,
"Altan Soong".data(using: .utf8)!,
"Adam Soong".data(using: .utf8)!
"Adam Soong".data(using: .utf8)!,
]
)
}
Expand All @@ -70,7 +82,7 @@ class XPCService {
codeSigningRequirement: nil
)

remoteConnection.activate()
try await remoteConnection.activate()

let opening: String = try await remoteConnection.sendMessage(name: JokeMessage.askForJoke, request: "Tell me a joke")

Expand Down
1 change: 1 addition & 0 deletions Sources/TestShared/CommandSet.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ public struct CommandSet {
public static let transportData = "com.charlessoft.SwiftyXPC.Tests.TransportData"
public static let tellAJoke = "com.charlessoft.SwiftyXPC.Tests.TellAJoke"
public static let pauseOneSecond = "com.charlessoft.SwiftyXPC.Tests.PauseOneSecond"
public static let countConnections = "com.charlessoft.SwiftyXPC.Tests.CountConnections"
}
2 changes: 1 addition & 1 deletion Sources/TestShared/DataInfo.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import Foundation

public struct DataInfo: Codable {
public struct DataInfo: Codable, Sendable {
public struct DataError: LocalizedError, Codable {
public let failureReason: String?
public init(failureReason: String) { self.failureReason = failureReason }
Expand Down
2 changes: 1 addition & 1 deletion Sources/TestShared/ProcessIDs.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import SwiftyXPC
import System

// swift-format-ignore: AllPublicDeclarationsHaveDocumentation
public struct ProcessIDs: Codable {
public struct ProcessIDs: Codable, Sendable {
public let pid: pid_t
public let effectiveUID: uid_t
public let effectiveGID: gid_t
Expand Down
Loading