diff --git a/Sources/ApolloTestSupport/MockWebSocket.swift b/Sources/ApolloTestSupport/MockWebSocket.swift index 20c7cedaa2..c92e0bb3ef 100644 --- a/Sources/ApolloTestSupport/MockWebSocket.swift +++ b/Sources/ApolloTestSupport/MockWebSocket.swift @@ -8,8 +8,10 @@ public class MockWebSocket: WebSocketClient { public var delegate: WebSocketClientDelegate? = nil public var isConnected: Bool = false - public required init(request: URLRequest) { + public required init(request: URLRequest, protocol: WebSocket.WSProtocol) { self.request = request + + self.request.setValue(`protocol`.description, forHTTPHeaderField: WebSocket.Constants.headerWSProtocolName) } open func reportDidConnect() { diff --git a/Sources/ApolloWebSocket/WebSocketTransport.swift b/Sources/ApolloWebSocket/WebSocketTransport.swift index 20932d84c3..d28d35303b 100644 --- a/Sources/ApolloWebSocket/WebSocketTransport.swift +++ b/Sources/ApolloWebSocket/WebSocketTransport.swift @@ -280,12 +280,14 @@ public class WebSocketTransport { autoPersistQuery: false) let identifier = operationMessageIdCreator.requestId() - var type: OperationMessage.Types = .start - if case WebSocket.WSProtocol.graphql_transport_ws.description = websocket.request.value(forHTTPHeaderField: WebSocket.Constants.headerWSProtocolName) { - type = .subscribe + let messageType: OperationMessage.Types + switch websocket.request.wsProtocol { + case .graphql_ws: messageType = .start + case .graphql_transport_ws: messageType = .subscribe + default: return nil } - guard let message = OperationMessage(payload: body, id: identifier, type: type).rawMessage else { + guard let message = OperationMessage(payload: body, id: identifier, type: messageType).rawMessage else { return nil } @@ -302,7 +304,13 @@ public class WebSocketTransport { } public func unsubscribe(_ subscriptionId: String) { - let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage + let messageType: OperationMessage.Types + switch websocket.request.wsProtocol { + case .graphql_transport_ws: messageType = .complete + default: messageType = .stop + } + + let str = OperationMessage(id: subscriptionId, type: messageType).rawMessage processingQueue.async { if let str = str { @@ -359,6 +367,20 @@ public class WebSocketTransport { } } +extension URLRequest { + fileprivate var wsProtocol: WebSocket.WSProtocol? { + guard let header = value(forHTTPHeaderField: WebSocket.Constants.headerWSProtocolName) else { + return nil + } + + switch header { + case WebSocket.WSProtocol.graphql_transport_ws.description: return .graphql_transport_ws + case WebSocket.WSProtocol.graphql_ws.description: return .graphql_ws + default: return nil + } + } +} + // MARK: - NetworkTransport conformance extension WebSocketTransport: NetworkTransport { diff --git a/Tests/ApolloServerIntegrationTests/StarWarsSubscriptionTests.swift b/Tests/ApolloServerIntegrationTests/StarWarsSubscriptionTests.swift index a10f6b33a7..05669fd8cd 100644 --- a/Tests/ApolloServerIntegrationTests/StarWarsSubscriptionTests.swift +++ b/Tests/ApolloServerIntegrationTests/StarWarsSubscriptionTests.swift @@ -411,7 +411,8 @@ class StarWarsSubscriptionTests: XCTestCase { func testConcurrentConnectAndCloseConnection() { let webSocketTransport = WebSocketTransport( websocket: MockWebSocket( - request: URLRequest(url: TestServerURL.starWarsWebSocket.url) + request: URLRequest(url: TestServerURL.starWarsWebSocket.url), + protocol: .graphql_ws ), store: ApolloStore() ) diff --git a/Tests/ApolloTests/WebSocket/GraphqlTransportWsProtocolTests.swift b/Tests/ApolloTests/WebSocket/GraphqlTransportWsProtocolTests.swift index 8e1c97e9bc..7381c5c99d 100644 --- a/Tests/ApolloTests/WebSocket/GraphqlTransportWsProtocolTests.swift +++ b/Tests/ApolloTests/WebSocket/GraphqlTransportWsProtocolTests.swift @@ -16,6 +16,10 @@ class GraphqlTransportWsProtocolTests: WSProtocolTestsBase { return request } + private func buildWebSocket() { + buildWebSocket(protocol: .graphql_transport_ws) + } + // MARK: Initializer Tests func test__designatedInitializer__shouldSetRequestProtocolHeader() { @@ -123,7 +127,7 @@ class GraphqlTransportWsProtocolTests: WSProtocolTestsBase { } } - func test__messaging__givenSubscriptionCancel_shouldSendStop() { + func test__messaging__givenSubscriptionCancel_shouldSendComplete() { // given buildWebSocket() buildClient() @@ -136,7 +140,7 @@ class GraphqlTransportWsProtocolTests: WSProtocolTestsBase { waitUntil { done in self.mockWebSocketDelegate.didReceiveMessage = { message in // then - let expected = OperationMessage(id: "1", type: .stop).rawMessage! + let expected = OperationMessage(id: "1", type: .complete).rawMessage! if message == expected { done() } diff --git a/Tests/ApolloTests/WebSocket/GraphqlWsProtocolTests.swift b/Tests/ApolloTests/WebSocket/GraphqlWsProtocolTests.swift index f71ee63b79..54bb62510e 100644 --- a/Tests/ApolloTests/WebSocket/GraphqlWsProtocolTests.swift +++ b/Tests/ApolloTests/WebSocket/GraphqlWsProtocolTests.swift @@ -16,6 +16,10 @@ class GraphqlWsProtocolTests: WSProtocolTestsBase { return request } + private func buildWebSocket() { + buildWebSocket(protocol: .graphql_ws) + } + // MARK: Initializer Tests func test__designatedInitializer__shouldSetRequestProtocolHeader() { diff --git a/Tests/ApolloTests/WebSocket/WSProtocolTestsBase.swift b/Tests/ApolloTests/WebSocket/WSProtocolTestsBase.swift index 8495457604..5e224654df 100644 --- a/Tests/ApolloTests/WebSocket/WSProtocolTestsBase.swift +++ b/Tests/ApolloTests/WebSocket/WSProtocolTestsBase.swift @@ -40,9 +40,9 @@ class WSProtocolTestsBase: XCTestCase { fatalError("Subclasses must override this property!") } - func buildWebSocket() { + func buildWebSocket(protocol: WebSocket.WSProtocol) { mockWebSocketDelegate = MockWebSocketDelegate() - mockWebSocket = MockWebSocket(request: urlRequest) + mockWebSocket = MockWebSocket(request: urlRequest, protocol: `protocol`) websocketTransport = WebSocketTransport(websocket: mockWebSocket, store: store) } diff --git a/Tests/ApolloTests/WebSocket/WebSocketTests.swift b/Tests/ApolloTests/WebSocket/WebSocketTests.swift index a6d610aa99..e7f264fa53 100644 --- a/Tests/ApolloTests/WebSocket/WebSocketTests.swift +++ b/Tests/ApolloTests/WebSocket/WebSocketTests.swift @@ -28,7 +28,10 @@ class WebSocketTests: XCTestCase { super.setUp() let store = ApolloStore() - let websocket = MockWebSocket(request:URLRequest(url: TestURL.mockServer.url)) + let websocket = MockWebSocket( + request:URLRequest(url: TestURL.mockServer.url), + protocol: .graphql_ws + ) networkTransport = WebSocketTransport(websocket: websocket, store: store) client = ApolloClient(networkTransport: networkTransport!, store: store) } @@ -133,7 +136,10 @@ class WebSocketTests: XCTestCase { let expectation = self.expectation(description: "Single Subscription with Custom Operation Message Id Creator") let store = ApolloStore() - let websocket = MockWebSocket(request:URLRequest(url: TestURL.mockServer.url)) + let websocket = MockWebSocket( + request:URLRequest(url: TestURL.mockServer.url), + protocol: .graphql_ws + ) networkTransport = WebSocketTransport(websocket: websocket, store: store, operationMessageIdCreator: CustomOperationMessageIdCreator()) client = ApolloClient(networkTransport: networkTransport!, store: store) diff --git a/Tests/ApolloTests/WebSocket/WebSocketTransportTests.swift b/Tests/ApolloTests/WebSocket/WebSocketTransportTests.swift index e062c45400..caa4d9440f 100644 --- a/Tests/ApolloTests/WebSocket/WebSocketTransportTests.swift +++ b/Tests/ApolloTests/WebSocket/WebSocketTransportTests.swift @@ -17,8 +17,10 @@ class WebSocketTransportTests: XCTestCase { var request = URLRequest(url: TestURL.mockServer.url) request.addValue("OldToken", forHTTPHeaderField: "Authorization") - self.webSocketTransport = WebSocketTransport(websocket: MockWebSocket(request: request), - store: ApolloStore()) + self.webSocketTransport = WebSocketTransport( + websocket: MockWebSocket(request: request, protocol: .graphql_ws), + store: ApolloStore() + ) self.webSocketTransport.updateHeaderValues(["Authorization": "UpdatedToken"]) @@ -28,9 +30,11 @@ class WebSocketTransportTests: XCTestCase { func testUpdateConnectingPayload() { let request = URLRequest(url: TestURL.mockServer.url) - self.webSocketTransport = WebSocketTransport(websocket: MockWebSocket(request: request), - store: ApolloStore(), - connectingPayload: ["Authorization": "OldToken"]) + self.webSocketTransport = WebSocketTransport( + websocket: MockWebSocket(request: request, protocol: .graphql_ws), + store: ApolloStore(), + connectingPayload: ["Authorization": "OldToken"] + ) let mockWebSocketDelegate = MockWebSocketDelegate() @@ -59,9 +63,11 @@ class WebSocketTransportTests: XCTestCase { func testCloseConnectionAndInit() { let request = URLRequest(url: TestURL.mockServer.url) - self.webSocketTransport = WebSocketTransport(websocket: MockWebSocket(request: request), - store: ApolloStore(), - connectingPayload: ["Authorization": "OldToken"]) + self.webSocketTransport = WebSocketTransport( + websocket: MockWebSocket(request: request, protocol: .graphql_ws), + store: ApolloStore(), + connectingPayload: ["Authorization": "OldToken"] + ) self.webSocketTransport.closeConnection() self.webSocketTransport.updateConnectingPayload(["Authorization": "UpdatedToken"]) self.webSocketTransport.initServer()