Skip to content

Commit 08ae91a

Browse files
Merge pull request #880 from aivcec/fix/subscriptions-data-races
Fixing data races in subscriptions
2 parents f0ceece + 33c3649 commit 08ae91a

6 files changed

Lines changed: 223 additions & 33 deletions

File tree

ApolloWebSocket.xcodeproj/project.pbxproj

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
9F28B6D520720F2F00144A00 /* Apollo.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6D420720F2F00144A00 /* Apollo.framework */; };
2424
9F28B6D920720FD200144A00 /* ApolloTestSupport.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6D820720FD100144A00 /* ApolloTestSupport.framework */; };
2525
9F28B6DB2072101200144A00 /* StarWarsAPI.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6DA2072101200144A00 /* StarWarsAPI.framework */; };
26+
D1ACF61D23715AF30042E200 /* Atomic.swift in Sources */ = {isa = PBXBuildFile; fileRef = D1ACF61B23715AF30042E200 /* Atomic.swift */; };
2627
/* End PBXBuildFile section */
2728

2829
/* Begin PBXContainerItemProxy section */
@@ -73,6 +74,7 @@
7374
9F28B6D420720F2F00144A00 /* Apollo.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = Apollo.framework; sourceTree = BUILT_PRODUCTS_DIR; };
7475
9F28B6D820720FD100144A00 /* ApolloTestSupport.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = ApolloTestSupport.framework; sourceTree = BUILT_PRODUCTS_DIR; };
7576
9F28B6DA2072101200144A00 /* StarWarsAPI.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = StarWarsAPI.framework; sourceTree = BUILT_PRODUCTS_DIR; };
77+
D1ACF61B23715AF30042E200 /* Atomic.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Atomic.swift; sourceTree = "<group>"; };
7678
/* End PBXFileReference section */
7779

7880
/* Begin PBXFrameworksBuildPhase section */
@@ -141,6 +143,7 @@
141143
9B1CCDE223611606007C9032 /* WebSocketTask.swift */,
142144
7270746B206D111A00C131F6 /* WebSocketTransport.swift */,
143145
7270746C206D111A00C131F6 /* Info.plist */,
146+
D1ACF61923715AF30042E200 /* Utilities */,
144147
);
145148
name = ApolloWebSocket;
146149
path = Sources/ApolloWebSocket;
@@ -181,6 +184,14 @@
181184
name = Products;
182185
sourceTree = "<group>";
183186
};
187+
D1ACF61923715AF30042E200 /* Utilities */ = {
188+
isa = PBXGroup;
189+
children = (
190+
D1ACF61B23715AF30042E200 /* Atomic.swift */,
191+
);
192+
path = Utilities;
193+
sourceTree = "<group>";
194+
};
184195
/* End PBXGroup section */
185196

186197
/* Begin PBXNativeTarget section */
@@ -310,6 +321,7 @@
310321
9B1CCDDF236110C3007C9032 /* WebSocketError.swift in Sources */,
311322
7270746D206D111A00C131F6 /* SplitNetworkTransport.swift in Sources */,
312323
9B1CCDE323611606007C9032 /* WebSocketTask.swift in Sources */,
324+
D1ACF61D23715AF30042E200 /* Atomic.swift in Sources */,
313325
7270746E206D111A00C131F6 /* WebSocketTransport.swift in Sources */,
314326
9B1CCDE123611580007C9032 /* OperationMessage.swift in Sources */,
315327
9B1CCDDB23610CDC007C9032 /* ApolloWebSocket.swift in Sources */,

Sources/ApolloWebSocket/ApolloWebSocket.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ public protocol ApolloWebSocketClient: WebSocketClient {
1414

1515
/// The URLRequest used on connection.
1616
var request: URLRequest { get set }
17+
18+
/// Queue where the callbacks are executed
19+
var callbackQueue: DispatchQueue { get set }
1720
}
1821

1922
// MARK: - WebSocket
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import Foundation
2+
3+
class Atomic<T> {
4+
private let lock = NSLock()
5+
private var _value: T
6+
7+
init(_ value: T) {
8+
_value = value
9+
}
10+
11+
var value: T {
12+
get {
13+
lock.lock()
14+
defer { lock.unlock() }
15+
16+
return _value
17+
}
18+
set {
19+
lock.lock()
20+
defer { lock.unlock() }
21+
22+
_value = newValue
23+
}
24+
}
25+
}
26+
27+
extension Atomic where T == Int {
28+
29+
func increment() -> T {
30+
lock.lock()
31+
defer { lock.unlock() }
32+
33+
_value += 1
34+
return _value
35+
}
36+
}

Sources/ApolloWebSocket/WebSocketTransport.swift

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ public class WebSocketTransport {
2525
public static var provider: ApolloWebSocketClient.Type = ApolloWebSocket.self
2626
public weak var delegate: WebSocketTransportDelegate?
2727

28-
var reconnect = false
28+
let reconnect: Atomic<Bool> = Atomic(false)
2929
var websocket: ApolloWebSocketClient
30-
var error: Error? = nil
30+
let error: Atomic<Error?> = Atomic(nil)
3131
let serializationFormat = JSONSerializationFormat.self
3232
private let requestCreator: RequestCreator
3333

@@ -40,10 +40,11 @@ public class WebSocketTransport {
4040

4141
private var subscribers = [String: (Result<JSONObject, Error>) -> Void]()
4242
private var subscriptions : [String: String] = [:]
43+
private let processingQueue = DispatchQueue(label: "com.apollographql.WebSocketTransport")
4344

4445
private let sendOperationIdentifiers: Bool
4546
private let reconnectionInterval: TimeInterval
46-
fileprivate var sequenceNumber = 0
47+
fileprivate let sequenceNumberCounter = Atomic<Int>(0)
4748
fileprivate var reconnected = false
4849

4950
/// NOTE: Setting this won't override immediately if the socket is still connected, only on reconnection.
@@ -87,6 +88,7 @@ public class WebSocketTransport {
8788
self.websocket.request.setValue(self.clientVersion, forHTTPHeaderField: WebSocketTransport.headerFieldNameClientVersion)
8889
self.websocket.delegate = self
8990
self.websocket.connect()
91+
self.websocket.callbackQueue = processingQueue
9092
}
9193

9294
public func isConnected() -> Bool {
@@ -174,7 +176,7 @@ public class WebSocketTransport {
174176
}
175177

176178
public func initServer(reconnect: Bool = true) {
177-
self.reconnect = reconnect
179+
self.reconnect.value = reconnect
178180
self.acked = false
179181

180182
if let str = OperationMessage(payload: self.connectingPayload, type: .connectionInit).rawMessage {
@@ -184,12 +186,17 @@ public class WebSocketTransport {
184186
}
185187

186188
public func closeConnection() {
187-
self.reconnect = false
188-
if let str = OperationMessage(type: .connectionTerminate).rawMessage {
189-
write(str)
189+
self.reconnect.value = false
190+
191+
let str = OperationMessage(type: .connectionTerminate).rawMessage
192+
processingQueue.async {
193+
if let str = str {
194+
self.write(str)
195+
}
196+
197+
self.queue.removeAll()
198+
self.subscriptions.removeAll()
190199
}
191-
self.queue.removeAll()
192-
self.subscriptions.removeAll()
193200
}
194201

195202
private func write(_ str: String, force forced: Bool = false, id: Int? = nil) {
@@ -213,43 +220,44 @@ public class WebSocketTransport {
213220
websocket.delegate = nil
214221
}
215222

216-
private func nextSequenceNumber() -> Int {
217-
sequenceNumber += 1
218-
return sequenceNumber
219-
}
220-
221223
func sendHelper<Operation: GraphQLOperation>(operation: Operation, resultHandler: @escaping (_ result: Result<JSONObject, Error>) -> Void) -> String? {
222224
let body = requestCreator.requestBody(for: operation, sendOperationIdentifiers: self.sendOperationIdentifiers)
223-
let sequenceNumber = "\(nextSequenceNumber())"
225+
let sequenceNumber = "\(sequenceNumberCounter.increment())"
224226

225227
guard let message = OperationMessage(payload: body, id: sequenceNumber).rawMessage else {
226228
return nil
227229
}
228-
229-
write(message)
230+
231+
processingQueue.async {
232+
self.write(message)
230233

231-
subscribers[sequenceNumber] = resultHandler
232-
if operation.operationType == .subscription {
233-
subscriptions[sequenceNumber] = message
234+
self.subscribers[sequenceNumber] = resultHandler
235+
if operation.operationType == .subscription {
236+
self.subscriptions[sequenceNumber] = message
237+
}
234238
}
235239

236240
return sequenceNumber
237241
}
238242

239243
public func unsubscribe(_ subscriptionId: String) {
240-
if let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage {
241-
write(str)
244+
let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage
245+
246+
processingQueue.async {
247+
if let str = str {
248+
self.write(str)
249+
}
250+
self.subscribers.removeValue(forKey: subscriptionId)
251+
self.subscriptions.removeValue(forKey: subscriptionId)
242252
}
243-
subscribers.removeValue(forKey: subscriptionId)
244-
subscriptions.removeValue(forKey: subscriptionId)
245253
}
246254
}
247255

248256
// MARK: - HTTPNetworkTransport conformance
249257

250258
extension WebSocketTransport: NetworkTransport {
251259
public func send<Operation>(operation: Operation, completionHandler: @escaping (_ result: Result<GraphQLResponse<Operation>,Error>) -> Void) -> Cancellable {
252-
if let error = self.error {
260+
if let error = self.error.value {
253261
completionHandler(.failure(error))
254262
return EmptyCancellable()
255263
}
@@ -271,7 +279,7 @@ extension WebSocketTransport: NetworkTransport {
271279
extension WebSocketTransport: WebSocketDelegate {
272280

273281
public func websocketDidConnect(socket: WebSocketClient) {
274-
self.error = nil
282+
self.error.value = nil
275283
initServer()
276284
if reconnected {
277285
self.delegate?.webSocketTransportDidReconnect(self)
@@ -290,16 +298,16 @@ extension WebSocketTransport: WebSocketDelegate {
290298
public func websocketDidDisconnect(socket: WebSocketClient, error: Error?) {
291299
// report any error to all subscribers
292300
if let error = error {
293-
self.error = WebSocketError(payload: nil, error: error, kind: .networkError)
301+
self.error.value = WebSocketError(payload: nil, error: error, kind: .networkError)
294302
self.notifyErrorAllHandlers(error)
295303
} else {
296-
self.error = nil
304+
self.error.value = nil
297305
}
298306

299-
self.delegate?.webSocketTransport(self, didDisconnectWithError: self.error)
307+
self.delegate?.webSocketTransport(self, didDisconnectWithError: self.error.value)
300308
acked = false // need new connect and ack before sending
301309

302-
if reconnect {
310+
if reconnect.value {
303311
DispatchQueue.main.asyncAfter(deadline: .now() + reconnectionInterval) {
304312
self.websocket.connect()
305313
}

Tests/ApolloWebsocketTests/MockWebSocket.swift

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ import Starscream
22
@testable import ApolloWebSocket
33

44
class MockWebSocket: ApolloWebSocketClient {
5+
6+
var callbackQueue: DispatchQueue = DispatchQueue.main
7+
58
var pongDelegate: WebSocketPongDelegate?
69
var request: URLRequest
710

@@ -15,8 +18,16 @@ class MockWebSocket: ApolloWebSocketClient {
1518
self.request = URLRequest(url: URL(string: "http://localhost:8080")!)
1619
}
1720

21+
open func reportDidConnect() {
22+
callbackQueue.async {
23+
self.delegate?.websocketDidConnect(socket: self)
24+
}
25+
}
26+
1827
open func write(string: String, completion: (() -> ())?) {
19-
delegate?.websocketDidReceiveMessage(socket: self, text: string)
28+
callbackQueue.async {
29+
self.delegate?.websocketDidReceiveMessage(socket: self, text: string)
30+
}
2031
}
2132

2233
open func write(data: Data, completion: (() -> ())?) {

0 commit comments

Comments
 (0)