diff --git a/Sources/SwiftyXPC/XPCConnection.swift b/Sources/SwiftyXPC/XPCConnection.swift index 11b5878..c1bf1ed 100644 --- a/Sources/SwiftyXPC/XPCConnection.swift +++ b/Sources/SwiftyXPC/XPCConnection.swift @@ -17,7 +17,7 @@ public class XPCConnection: @unchecked Sendable { /// An XPC message was missing its request and/or response body. case missingMessageBody /// Received an unhandled XPC message. - case unexpectedMessage + case unexpectedMessage(String) /// A message contained data of the wrong type. case typeMismatch(expected: XPCType, actual: XPCType) /// Only used on macOS versions prior to 12.0. @@ -30,6 +30,11 @@ public class XPCConnection: @unchecked Sendable { static let error = "com.charlessoft.SwiftyXPC.XPCEventHandler.Error" } + private struct StateMessages { + static let activate = "com.charlessoft.SwiftyXPC.XPCConnectionState.Activate" + static let deactivate = "com.charlessoft.SwiftyXPC.XPCConnectionState.Deactivate" + } + /// Represents the various types of connection that can be created. public enum ConnectionType { /// Connect to an embedded XPC service inside the current application’s bundle. Pass the XPC service’s bundle ID as the `bundleID` parameter. @@ -41,7 +46,10 @@ public class XPCConnection: @unchecked Sendable { } /// A handler that will be called if a communication error occurs. - public typealias ErrorHandler = (XPCConnection, Swift.Error) -> Void + public typealias ErrorHandler = (XPCConnection, Swift.Error) async -> Void + + /// A handler that will be called when the connection is cancelled. + public typealias CancelHandler = () async -> Void internal class MessageHandler { typealias RawHandler = ((XPCConnection, xpc_object_t) async throws -> xpc_object_t) @@ -126,10 +134,21 @@ public class XPCConnection: @unchecked Sendable { /// A handler that will be called if a communication error occurs. public var errorHandler: ErrorHandler? = nil - internal var customEventHandler: xpc_handler_t? = nil + /// A handler that will be called when the connection is cancelled. + public var cancelHandler: CancelHandler? = nil + + internal var customEventHandler: ((xpc_object_t) async -> Void)? = nil internal func getMessageHandler(forName name: String) -> MessageHandler.RawHandler? { - self.messageHandlers[name]?.closure + switch name { + case StateMessages.activate: return { _, _ in try XPCEncoder().encode(XPCNull.shared) } + case StateMessages.deactivate: + return { [cancelHandler] _, _ in + await cancelHandler?() + return try XPCEncoder().encode(XPCNull.shared) + } + default: return self.messageHandlers[name]?.closure + } } /// Set a message handler for an incoming message, identified by the `name` parameter, without taking any arguments or returning any value. @@ -232,6 +251,12 @@ public class XPCConnection: @unchecked Sendable { /// Activate the connection. /// /// Connections start in an inactive state, so you must call `activate()` on a connection before it will send or receive any messages. + + public func activate() async throws { + xpc_connection_activate(self.connection) + try await sendMessage(name: StateMessages.activate) + } + public func activate() { xpc_connection_activate(self.connection) } @@ -263,6 +288,11 @@ public class XPCConnection: @unchecked Sendable { xpc_connection_cancel(self.connection) } + public func cancel() async throws { + try await sendMessage(name: StateMessages.deactivate) + xpc_connection_cancel(self.connection) + } + internal func makeEndpoint() -> XPCEndpoint { XPCEndpoint(connection: self.connection) } @@ -295,7 +325,7 @@ public class XPCConnection: @unchecked Sendable { /// - Returns: The value returned by the receiving connection's helper function. /// /// - Throws: Throws an error if the receiving connection throws an error in its handler, or if a communication error occurs. - public func sendMessage(name: String) async throws -> Response { + public func sendMessage(name: String) async throws -> Response { try await self.sendMessage(name: name, request: XPCNull.shared) } @@ -309,7 +339,7 @@ public class XPCConnection: @unchecked Sendable { /// /// - Throws: Throws an error the `request` parameter does not match the type specified by the receiving connection’s handler function, /// if the receiving connection throws an error in its handler, or if a communication error occurs. - public func sendMessage(name: String, request: Request) async throws -> Response { + public func sendMessage(name: String, request: some Codable) async throws -> Response { let body = try XPCEncoder().encode(request) return try await withCheckedThrowingContinuation { continuation in @@ -334,11 +364,14 @@ public class XPCConnection: @unchecked Sendable { throw Error.missingMessageBody } - if Response.self == XPCNull.self { - continuation.resume(returning: XPCNull() as! Response) - } else { - continuation.resume(returning: try XPCDecoder().decode(type: Response.self, from: body)) - } + let response: Response = + if Response.self == XPCNull.self { + XPCNull() as! Response + } else { + try XPCDecoder().decode(type: Response.self, from: body) + } + + continuation.resume(returning: response) } catch { continuation.resume(throwing: error) } @@ -426,13 +459,17 @@ public class XPCConnection: @unchecked Sendable { do { try self.checkCallerCredentials(event: event) } catch { - self.errorHandler?(self, error) + Task { + await self.errorHandler?(self, error) + } return } } if let customEventHandler = self.customEventHandler { - customEventHandler(event) + Task { + await customEventHandler(event) + } return } @@ -450,7 +487,9 @@ public class XPCConnection: @unchecked Sendable { throw Error.typeMismatch(expected: .dictionary, actual: event.type) } } catch { - self.errorHandler?(self, error) + Task { + await self.errorHandler?(self, error) + } return } } @@ -519,12 +558,12 @@ public class XPCConnection: @unchecked Sendable { } guard let _messageHandler = self.getMessageHandler(forName: name) else { - throw Error.unexpectedMessage + throw Error.unexpectedMessage(name) } messageHandler = _messageHandler } catch { - self.errorHandler?(self, error) + await self.errorHandler?(self, error) return } @@ -540,7 +579,7 @@ public class XPCConnection: @unchecked Sendable { do { try self.sendOnewayRawMessage(name: nil, body: response, key: MessageKeys.body, asReplyTo: event) } catch { - self.errorHandler?(self, error) + await self.errorHandler?(self, error) } } } diff --git a/Sources/SwiftyXPC/XPCErrorRegistry.swift b/Sources/SwiftyXPC/XPCErrorRegistry.swift index b1219b7..721ce7a 100644 --- a/Sources/SwiftyXPC/XPCErrorRegistry.swift +++ b/Sources/SwiftyXPC/XPCErrorRegistry.swift @@ -5,6 +5,7 @@ // Created by Charles Srstka on 12/19/21. // +import Synchronization import XPC /// A registry which facilitates decoding error types that are sent over an XPC connection. @@ -42,14 +43,46 @@ import XPC /// } catch { /// print("got some other error") /// } -public class XPCErrorRegistry { +public final class XPCErrorRegistry: Sendable { /// The shared `XPCErrorRegistry` instance. public static let shared = XPCErrorRegistry() - private var errorDomainMap: [String: (Error & Codable).Type] = [ - String(reflecting: XPCError.self): XPCError.self, - String(reflecting: XPCConnection.Error.self): XPCConnection.Error.self, - ] + @available(macOS 15.0, macCatalyst 18.0, *) + private final class MutexWrapper: Sendable { + let mutex: Mutex<[String: (Error & Codable).Type]> + init(dict: [String: (Error & Codable).Type]) { self.mutex = Mutex(dict) } + } + + private final class LegacyWrapper: @unchecked Sendable { + let sema = DispatchSemaphore(value: 1) + var dict: [String: (Error & Codable).Type] + init(dict: [String: (Error & Codable).Type]) { self.dict = dict } + } + + private let errorDomainMapWrapper: any Sendable = { + let errorDomainMap: [String: (Error & Codable).Type] = [ + String(reflecting: XPCError.self): XPCError.self, + String(reflecting: XPCConnection.Error.self): XPCConnection.Error.self, + ] + + if #available(macOS 15.0, macCatalyst 18.0, *) { + return MutexWrapper(dict: errorDomainMap) + } else { + return LegacyWrapper(dict: errorDomainMap) + } + }() + + private func withLock(closure: (inout [String: (Error & Codable).Type]) throws -> T) rethrows -> T { + if #available(macOS 15.0, macCatalyst 18.0, *) { + return try (self.errorDomainMapWrapper as! MutexWrapper).mutex.withLock { try closure(&$0) } + } else { + let wrapper = self.errorDomainMapWrapper as! LegacyWrapper + wrapper.sema.wait() + defer { wrapper.sema.signal() } + + return try closure(&wrapper.dict) + } + } /// Register an error type. /// @@ -57,11 +90,13 @@ public class XPCErrorRegistry { /// - domain: An `NSError`-style domain string to associate with this error type. In most cases, you will just pass `nil` for this parameter, in which case the default value of `String(reflecting: errorType)` will be used instead. /// - errorType: An error type to register. This type must conform to `Codable`. public func registerDomain(_ domain: String? = nil, forErrorType errorType: (Error & Codable).Type) { - errorDomainMap[domain ?? String(reflecting: errorType)] = errorType + self.withLock { $0[domain ?? String(reflecting: errorType)] = errorType } } internal func encodeError(_ error: Error, domain: String? = nil) throws -> xpc_object_t { - try XPCEncoder().encode(BoxedError(error: error, domain: domain)) + try self.withLock { _ in + try XPCEncoder().encode(BoxedError(error: error, domain: domain)) + } } internal func decodeError(_ error: xpc_object_t) throws -> Error { @@ -70,6 +105,10 @@ public class XPCErrorRegistry { return boxedError.encodedError ?? boxedError } + internal func errorType(forDomain domain: String) -> (any (Error & Codable).Type)? { + self.withLock { $0[domain] } + } + /// An error type representing errors for which we have an `NSError`-style domain and code, but do not know the exact error class. /// /// To avoid requiring Foundation, this type does not formally adopt the `CustomNSError` protocol, but implements methods which @@ -155,7 +194,7 @@ public class XPCErrorRegistry { self.errorDomain = try container.decode(String.self, forKey: .domain) let code = try container.decode(Int.self, forKey: .code) - if let codableType = XPCErrorRegistry.shared.errorDomainMap[self.errorDomain], + if let codableType = XPCErrorRegistry.shared.errorType(forDomain: self.errorDomain), let codableError = try codableType.decodeIfPresent(from: container, key: .encodedError) { self.storage = .codable(codableError) diff --git a/Sources/SwiftyXPC/XPCListener.swift b/Sources/SwiftyXPC/XPCListener.swift index 7f106f1..6fbed7b 100644 --- a/Sources/SwiftyXPC/XPCListener.swift +++ b/Sources/SwiftyXPC/XPCListener.swift @@ -170,6 +170,14 @@ public final class XPCListener { } } + /// A handler that will be called when a new connection activates. + public typealias ActivatedConnectionHandler = (XPCConnection) async -> Void + public var activatedConnectionHandler: ActivatedConnectionHandler? = nil + + /// A handler that will be called when a connection cancels. + public typealias CanceledConnectionHandler = (XPCConnection) async -> Void + public var canceledConnectionHandler: CanceledConnectionHandler? = nil + /// Create a new `XPCListener`. /// /// - Parameters: @@ -219,10 +227,15 @@ public final class XPCListener { newConnection.messageHandlers = self?.messageHandlers ?? [:] newConnection.errorHandler = self?.errorHandler - - newConnection.activate() + newConnection.cancelHandler = { + await self?.canceledConnectionHandler?(newConnection) + } + await self?.activatedConnectionHandler?(newConnection) + try await newConnection.activate() } catch { - self?.errorHandler?(connection, error) + Task { + await self?.errorHandler?(connection, error) + } } } } @@ -232,12 +245,12 @@ public final class XPCListener { /// /// After this call, any messages that have not yet been sent will be discarded, and the connection will be unwound. /// If there are messages that are awaiting replies, they will receive the `XPCError.connectionInvalid` error. - public func cancel() { + public func cancel() async throws { switch self.backing { case .xpcMain: fatalError("XPC service listener cannot be cancelled") case .connection(let connection, _): - connection.cancel() + try await connection.cancel() } } diff --git a/Sources/TestHelper/TestHelper.swift b/Sources/TestHelper/TestHelper.swift index 064909f..6343f78 100644 --- a/Sources/TestHelper/TestHelper.swift +++ b/Sources/TestHelper/TestHelper.swift @@ -12,19 +12,31 @@ import TestShared @main @available(macOS 13.0, *) -class XPCService { +final class XPCService: Sendable { static func main() { do { let xpcService = XPCService() let listener = try XPCListener(type: .machService(name: helperID), codeSigningRequirement: nil) + var connectionCount: Int = 0 + listener.activatedConnectionHandler = { _ in + connectionCount += 1 + } + listener.canceledConnectionHandler = { _ in + connectionCount -= 1 + } listener.setMessageHandler(name: CommandSet.reportIDs, handler: xpcService.reportIDs) listener.setMessageHandler(name: CommandSet.capitalizeString, handler: xpcService.capitalizeString) listener.setMessageHandler(name: CommandSet.multiplyBy5, handler: xpcService.multiplyBy5) listener.setMessageHandler(name: CommandSet.transportData, handler: xpcService.transportData) listener.setMessageHandler(name: CommandSet.tellAJoke, handler: xpcService.tellAJoke) listener.setMessageHandler(name: CommandSet.pauseOneSecond, handler: xpcService.pauseOneSecond) + listener.setMessageHandler( + name: CommandSet.countConnections, + handler: { _ in + connectionCount + }) listener.activate() dispatchMain() @@ -59,7 +71,7 @@ class XPCService { "Noonien Soong".data(using: .utf8)!, "Arik Soong".data(using: .utf8)!, "Altan Soong".data(using: .utf8)!, - "Adam Soong".data(using: .utf8)! + "Adam Soong".data(using: .utf8)!, ] ) } @@ -70,7 +82,7 @@ class XPCService { codeSigningRequirement: nil ) - remoteConnection.activate() + try await remoteConnection.activate() let opening: String = try await remoteConnection.sendMessage(name: JokeMessage.askForJoke, request: "Tell me a joke") diff --git a/Sources/TestShared/CommandSet.swift b/Sources/TestShared/CommandSet.swift index 164e9b9..2593b29 100644 --- a/Sources/TestShared/CommandSet.swift +++ b/Sources/TestShared/CommandSet.swift @@ -13,4 +13,5 @@ public struct CommandSet { public static let transportData = "com.charlessoft.SwiftyXPC.Tests.TransportData" public static let tellAJoke = "com.charlessoft.SwiftyXPC.Tests.TellAJoke" public static let pauseOneSecond = "com.charlessoft.SwiftyXPC.Tests.PauseOneSecond" + public static let countConnections = "com.charlessoft.SwiftyXPC.Tests.CountConnections" } diff --git a/Sources/TestShared/DataInfo.swift b/Sources/TestShared/DataInfo.swift index d616fbd..c49c453 100644 --- a/Sources/TestShared/DataInfo.swift +++ b/Sources/TestShared/DataInfo.swift @@ -7,7 +7,7 @@ import Foundation -public struct DataInfo: Codable { +public struct DataInfo: Codable, Sendable { public struct DataError: LocalizedError, Codable { public let failureReason: String? public init(failureReason: String) { self.failureReason = failureReason } diff --git a/Sources/TestShared/ProcessIDs.swift b/Sources/TestShared/ProcessIDs.swift index 42b866a..9d0c9a0 100644 --- a/Sources/TestShared/ProcessIDs.swift +++ b/Sources/TestShared/ProcessIDs.swift @@ -10,7 +10,7 @@ import SwiftyXPC import System // swift-format-ignore: AllPublicDeclarationsHaveDocumentation -public struct ProcessIDs: Codable { +public struct ProcessIDs: Codable, Sendable { public let pid: pid_t public let effectiveUID: uid_t public let effectiveGID: gid_t diff --git a/Tests/SwiftyXPCTests/SwiftyXPCTests.swift b/Tests/SwiftyXPCTests/SwiftyXPCTests.swift index 8ffa6b4..b113372 100644 --- a/Tests/SwiftyXPCTests/SwiftyXPCTests.swift +++ b/Tests/SwiftyXPCTests/SwiftyXPCTests.swift @@ -18,7 +18,7 @@ final class SwiftyXPCTests: XCTestCase { } func testProcessIDs() async throws { - let conn = try self.openConnection() + let conn = try await self.openConnection() let ids: ProcessIDs = try await conn.sendMessage(name: CommandSet.reportIDs) @@ -29,17 +29,17 @@ final class SwiftyXPCTests: XCTestCase { } func testCodeSignatureVerification() async throws { - let goodConn = try self.openConnection(codeSigningRequirement: self.helperLauncher!.codeSigningRequirement) + let goodConn = try await self.openConnection(codeSigningRequirement: self.helperLauncher!.codeSigningRequirement) let response: String = try await goodConn.sendMessage(name: CommandSet.capitalizeString, request: "Testing 1 2 3") XCTAssertEqual(response, "TESTING 1 2 3") - let badConn = try self.openConnection(codeSigningRequirement: "identifier \"com.apple.true\" and anchor apple") let failsSignatureVerification = self.expectation( description: "Fails to send message because of code signature mismatch" ) do { + let badConn = try await self.openConnection(codeSigningRequirement: "identifier \"com.apple.true\" and anchor apple") try await badConn.sendMessage(name: CommandSet.capitalizeString, request: "Testing 1 2 3") } catch let error as XPCError { if case .unknown(let errorDesc) = error, errorDesc == "Peer Forbidden" { @@ -54,7 +54,7 @@ final class SwiftyXPCTests: XCTestCase { ) do { - _ = try self.openConnection(codeSigningRequirement: "") + _ = try await self.openConnection(codeSigningRequirement: "") } catch XPCError.invalidCodeSignatureRequirement { failsConnectionInitialization.fulfill() } @@ -63,7 +63,7 @@ final class SwiftyXPCTests: XCTestCase { } func testSimpleRequestAndResponse() async throws { - let conn = try self.openConnection() + let conn = try await self.openConnection() let stringResponse: String = try await conn.sendMessage(name: CommandSet.capitalizeString, request: "hi there") XCTAssertEqual(stringResponse, "HI THERE") @@ -73,7 +73,7 @@ final class SwiftyXPCTests: XCTestCase { } func testDataTransport() async throws { - let conn = try self.openConnection() + let conn = try await self.openConnection() let dataInfo: DataInfo = try await conn.sendMessage( name: CommandSet.transportData, @@ -101,7 +101,7 @@ final class SwiftyXPCTests: XCTestCase { } func testTwoWayCommunication() async throws { - let conn = try self.openConnection() + let conn = try await self.openConnection() let listener = try XPCListener(type: .anonymous, codeSigningRequirement: nil) @@ -147,15 +147,18 @@ final class SwiftyXPCTests: XCTestCase { } listener.activate() - - try await conn.sendMessage(name: CommandSet.tellAJoke, request: listener.endpoint) + do { + try await conn.sendMessage(name: CommandSet.tellAJoke, request: listener.endpoint) + } catch { + print("ERRRR", error) + } await self.fulfillment(of: expectations, timeout: 10.0, enforceOrder: true) } func testTwoWayCommunicationWithError() async throws { XPCErrorRegistry.shared.registerDomain(forErrorType: JokeMessage.NotAKnockKnockJoke.self) - let conn = try self.openConnection() + let conn = try await self.openConnection() let listener = try XPCListener(type: .anonymous, codeSigningRequirement: nil) @@ -190,7 +193,7 @@ final class SwiftyXPCTests: XCTestCase { } func testOnewayVsTwoWay() async throws { - let conn = try self.openConnection() + let conn = try await self.openConnection() var date = Date.now try await conn.sendMessage(name: CommandSet.pauseOneSecond) @@ -202,12 +205,12 @@ final class SwiftyXPCTests: XCTestCase { } func testCancelConnection() async throws { - let conn = try self.openConnection() + let conn = try await self.openConnection() let response: String = try await conn.sendMessage(name: CommandSet.capitalizeString, request: "will work") XCTAssertEqual(response, "WILL WORK") - conn.cancel() + try await conn.cancel() let err: Error? do { @@ -222,13 +225,30 @@ final class SwiftyXPCTests: XCTestCase { return } } - - private func openConnection(codeSigningRequirement: String? = nil) throws -> XPCConnection { + + func testSimpleConnectionCounting() async throws { + let conn = try await self.openConnection() + + var count: Int = try await conn.sendMessage(name: CommandSet.countConnections) + XCTAssertEqual(count, 1) + + let conn2 = try await self.openConnection() + + count = try await conn.sendMessage(name: CommandSet.countConnections) + XCTAssertEqual(count, 2) + + try await conn2.cancel() + + count = try await conn.sendMessage(name: CommandSet.countConnections) + XCTAssertEqual(count, 1) + } + + private func openConnection(codeSigningRequirement: String? = nil) async throws -> XPCConnection { let conn = try XPCConnection( type: .remoteMachService(serviceName: helperID, isPrivilegedHelperTool: false), codeSigningRequirement: codeSigningRequirement ?? self.helperLauncher?.codeSigningRequirement ) - conn.activate() + try await conn.activate() return conn }