Skip to content

Commit 402dd16

Browse files
fix: graphql_transport_ws protocol should send 'complete' to end subscription (#2320)
* fix: graphql_transport_ws protocol should send 'complete' to end subscription * fix: Require MockWebSocket to be initialized with a graphql subscriptions protocol * tests: Fix StarWarsSubscriptionTests to use updated MockWebSocket initializer
1 parent 51c81bd commit 402dd16

8 files changed

Lines changed: 66 additions & 21 deletions

File tree

Sources/ApolloTestSupport/MockWebSocket.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ public class MockWebSocket: WebSocketClient {
88
public var delegate: WebSocketClientDelegate? = nil
99
public var isConnected: Bool = false
1010

11-
public required init(request: URLRequest) {
11+
public required init(request: URLRequest, protocol: WebSocket.WSProtocol) {
1212
self.request = request
13+
14+
self.request.setValue(`protocol`.description, forHTTPHeaderField: WebSocket.Constants.headerWSProtocolName)
1315
}
1416

1517
open func reportDidConnect() {

Sources/ApolloWebSocket/WebSocketTransport.swift

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,14 @@ public class WebSocketTransport {
280280
autoPersistQuery: false)
281281
let identifier = operationMessageIdCreator.requestId()
282282

283-
var type: OperationMessage.Types = .start
284-
if case WebSocket.WSProtocol.graphql_transport_ws.description = websocket.request.value(forHTTPHeaderField: WebSocket.Constants.headerWSProtocolName) {
285-
type = .subscribe
283+
let messageType: OperationMessage.Types
284+
switch websocket.request.wsProtocol {
285+
case .graphql_ws: messageType = .start
286+
case .graphql_transport_ws: messageType = .subscribe
287+
default: return nil
286288
}
287289

288-
guard let message = OperationMessage(payload: body, id: identifier, type: type).rawMessage else {
290+
guard let message = OperationMessage(payload: body, id: identifier, type: messageType).rawMessage else {
289291
return nil
290292
}
291293

@@ -302,7 +304,13 @@ public class WebSocketTransport {
302304
}
303305

304306
public func unsubscribe(_ subscriptionId: String) {
305-
let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage
307+
let messageType: OperationMessage.Types
308+
switch websocket.request.wsProtocol {
309+
case .graphql_transport_ws: messageType = .complete
310+
default: messageType = .stop
311+
}
312+
313+
let str = OperationMessage(id: subscriptionId, type: messageType).rawMessage
306314

307315
processingQueue.async {
308316
if let str = str {
@@ -359,6 +367,20 @@ public class WebSocketTransport {
359367
}
360368
}
361369

370+
extension URLRequest {
371+
fileprivate var wsProtocol: WebSocket.WSProtocol? {
372+
guard let header = value(forHTTPHeaderField: WebSocket.Constants.headerWSProtocolName) else {
373+
return nil
374+
}
375+
376+
switch header {
377+
case WebSocket.WSProtocol.graphql_transport_ws.description: return .graphql_transport_ws
378+
case WebSocket.WSProtocol.graphql_ws.description: return .graphql_ws
379+
default: return nil
380+
}
381+
}
382+
}
383+
362384
// MARK: - NetworkTransport conformance
363385

364386
extension WebSocketTransport: NetworkTransport {

Tests/ApolloServerIntegrationTests/StarWarsSubscriptionTests.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,8 @@ class StarWarsSubscriptionTests: XCTestCase {
411411
func testConcurrentConnectAndCloseConnection() {
412412
let webSocketTransport = WebSocketTransport(
413413
websocket: MockWebSocket(
414-
request: URLRequest(url: TestServerURL.starWarsWebSocket.url)
414+
request: URLRequest(url: TestServerURL.starWarsWebSocket.url),
415+
protocol: .graphql_ws
415416
),
416417
store: ApolloStore()
417418
)

Tests/ApolloTests/WebSocket/GraphqlTransportWsProtocolTests.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ class GraphqlTransportWsProtocolTests: WSProtocolTestsBase {
1616
return request
1717
}
1818

19+
private func buildWebSocket() {
20+
buildWebSocket(protocol: .graphql_transport_ws)
21+
}
22+
1923
// MARK: Initializer Tests
2024

2125
func test__designatedInitializer__shouldSetRequestProtocolHeader() {
@@ -123,7 +127,7 @@ class GraphqlTransportWsProtocolTests: WSProtocolTestsBase {
123127
}
124128
}
125129

126-
func test__messaging__givenSubscriptionCancel_shouldSendStop() {
130+
func test__messaging__givenSubscriptionCancel_shouldSendComplete() {
127131
// given
128132
buildWebSocket()
129133
buildClient()
@@ -136,7 +140,7 @@ class GraphqlTransportWsProtocolTests: WSProtocolTestsBase {
136140
waitUntil { done in
137141
self.mockWebSocketDelegate.didReceiveMessage = { message in
138142
// then
139-
let expected = OperationMessage(id: "1", type: .stop).rawMessage!
143+
let expected = OperationMessage(id: "1", type: .complete).rawMessage!
140144
if message == expected {
141145
done()
142146
}

Tests/ApolloTests/WebSocket/GraphqlWsProtocolTests.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ class GraphqlWsProtocolTests: WSProtocolTestsBase {
1616
return request
1717
}
1818

19+
private func buildWebSocket() {
20+
buildWebSocket(protocol: .graphql_ws)
21+
}
22+
1923
// MARK: Initializer Tests
2024

2125
func test__designatedInitializer__shouldSetRequestProtocolHeader() {

Tests/ApolloTests/WebSocket/WSProtocolTestsBase.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ class WSProtocolTestsBase: XCTestCase {
4040
fatalError("Subclasses must override this property!")
4141
}
4242

43-
func buildWebSocket() {
43+
func buildWebSocket(protocol: WebSocket.WSProtocol) {
4444
mockWebSocketDelegate = MockWebSocketDelegate()
45-
mockWebSocket = MockWebSocket(request: urlRequest)
45+
mockWebSocket = MockWebSocket(request: urlRequest, protocol: `protocol`)
4646
websocketTransport = WebSocketTransport(websocket: mockWebSocket, store: store)
4747
}
4848

Tests/ApolloTests/WebSocket/WebSocketTests.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ class WebSocketTests: XCTestCase {
2828
super.setUp()
2929

3030
let store = ApolloStore()
31-
let websocket = MockWebSocket(request:URLRequest(url: TestURL.mockServer.url))
31+
let websocket = MockWebSocket(
32+
request:URLRequest(url: TestURL.mockServer.url),
33+
protocol: .graphql_ws
34+
)
3235
networkTransport = WebSocketTransport(websocket: websocket, store: store)
3336
client = ApolloClient(networkTransport: networkTransport!, store: store)
3437
}
@@ -133,7 +136,10 @@ class WebSocketTests: XCTestCase {
133136
let expectation = self.expectation(description: "Single Subscription with Custom Operation Message Id Creator")
134137

135138
let store = ApolloStore()
136-
let websocket = MockWebSocket(request:URLRequest(url: TestURL.mockServer.url))
139+
let websocket = MockWebSocket(
140+
request:URLRequest(url: TestURL.mockServer.url),
141+
protocol: .graphql_ws
142+
)
137143
networkTransport = WebSocketTransport(websocket: websocket, store: store, operationMessageIdCreator: CustomOperationMessageIdCreator())
138144
client = ApolloClient(networkTransport: networkTransport!, store: store)
139145

Tests/ApolloTests/WebSocket/WebSocketTransportTests.swift

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ class WebSocketTransportTests: XCTestCase {
1717
var request = URLRequest(url: TestURL.mockServer.url)
1818
request.addValue("OldToken", forHTTPHeaderField: "Authorization")
1919

20-
self.webSocketTransport = WebSocketTransport(websocket: MockWebSocket(request: request),
21-
store: ApolloStore())
20+
self.webSocketTransport = WebSocketTransport(
21+
websocket: MockWebSocket(request: request, protocol: .graphql_ws),
22+
store: ApolloStore()
23+
)
2224

2325
self.webSocketTransport.updateHeaderValues(["Authorization": "UpdatedToken"])
2426

@@ -28,9 +30,11 @@ class WebSocketTransportTests: XCTestCase {
2830
func testUpdateConnectingPayload() {
2931
let request = URLRequest(url: TestURL.mockServer.url)
3032

31-
self.webSocketTransport = WebSocketTransport(websocket: MockWebSocket(request: request),
32-
store: ApolloStore(),
33-
connectingPayload: ["Authorization": "OldToken"])
33+
self.webSocketTransport = WebSocketTransport(
34+
websocket: MockWebSocket(request: request, protocol: .graphql_ws),
35+
store: ApolloStore(),
36+
connectingPayload: ["Authorization": "OldToken"]
37+
)
3438

3539
let mockWebSocketDelegate = MockWebSocketDelegate()
3640

@@ -59,9 +63,11 @@ class WebSocketTransportTests: XCTestCase {
5963
func testCloseConnectionAndInit() {
6064
let request = URLRequest(url: TestURL.mockServer.url)
6165

62-
self.webSocketTransport = WebSocketTransport(websocket: MockWebSocket(request: request),
63-
store: ApolloStore(),
64-
connectingPayload: ["Authorization": "OldToken"])
66+
self.webSocketTransport = WebSocketTransport(
67+
websocket: MockWebSocket(request: request, protocol: .graphql_ws),
68+
store: ApolloStore(),
69+
connectingPayload: ["Authorization": "OldToken"]
70+
)
6571
self.webSocketTransport.closeConnection()
6672
self.webSocketTransport.updateConnectingPayload(["Authorization": "UpdatedToken"])
6773
self.webSocketTransport.initServer()

0 commit comments

Comments
 (0)