diff --git a/EssentialFeed/EssentialFeedTests/Shared API Infra/Helpers/URLProtocolStub.swift b/EssentialFeed/EssentialFeedTests/Shared API Infra/Helpers/URLProtocolStub.swift index f96a591f..ea80c6c5 100644 --- a/EssentialFeed/EssentialFeedTests/Shared API Infra/Helpers/URLProtocolStub.swift +++ b/EssentialFeed/EssentialFeedTests/Shared API Infra/Helpers/URLProtocolStub.swift @@ -9,6 +9,7 @@ class URLProtocolStub: URLProtocol { let data: Data? let response: URLResponse? let error: Error? + let shouldCancelTask: Bool let requestObserver: ((URLRequest) -> Void)? } @@ -21,11 +22,15 @@ class URLProtocolStub: URLProtocol { private static let queue = DispatchQueue(label: "URLProtocolStub.queue") static func stub(data: Data?, response: URLResponse?, error: Error?) { - stub = Stub(data: data, response: response, error: error, requestObserver: nil) + stub = Stub(data: data, response: response, error: error, shouldCancelTask: false, requestObserver: nil) } - static func observeRequests(observer: @escaping (URLRequest) -> Void) { - stub = Stub(data: nil, response: nil, error: nil, requestObserver: observer) + static func cancelIncomingTasks() { + stub = Stub(data: nil, response: nil, error: nil, shouldCancelTask: true, requestObserver: nil) + } + + static func observeRequests(shouldFinish: Bool = true, observer: @escaping (URLRequest) -> Void) { + stub = Stub(data: nil, response: nil, error: nil, shouldCancelTask: false, requestObserver: observer) } static func removeStub() { @@ -43,6 +48,11 @@ class URLProtocolStub: URLProtocol { override func startLoading() { guard let stub = URLProtocolStub.stub else { return } + if stub.shouldCancelTask { + task?.cancel() + return + } + if let data = stub.data { client?.urlProtocol(self, didLoad: data) } diff --git a/EssentialFeed/EssentialFeedTests/Shared API Infra/URLSessionHTTPClientTests.swift b/EssentialFeed/EssentialFeedTests/Shared API Infra/URLSessionHTTPClientTests.swift index 84bc38eb..13e08475 100644 --- a/EssentialFeed/EssentialFeedTests/Shared API Infra/URLSessionHTTPClientTests.swift +++ b/EssentialFeed/EssentialFeedTests/Shared API Infra/URLSessionHTTPClientTests.swift @@ -29,11 +29,9 @@ class URLSessionHTTPClientTests: XCTestCase { } func test_cancelGetFromURLTask_cancelsURLRequest() { - let exp = expectation(description: "Wait for request") - URLProtocolStub.observeRequests { _ in exp.fulfill() } + URLProtocolStub.cancelIncomingTasks() - let receivedError = resultErrorFor(taskHandler: { $0.cancel() }) as NSError? - wait(for: [exp], timeout: 1.0) + let receivedError = resultErrorFor() as NSError? XCTAssertEqual(receivedError?.code, URLError.cancelled.rawValue) } @@ -104,8 +102,8 @@ class URLSessionHTTPClientTests: XCTestCase { } } - private func resultErrorFor(_ values: (data: Data?, response: URLResponse?, error: Error?)? = nil, taskHandler: (HTTPClientTask) -> Void = { _ in }, file: StaticString = #filePath, line: UInt = #line) -> Error? { - let result = resultFor(values, taskHandler: taskHandler, file: file, line: line) + private func resultErrorFor(_ values: (data: Data?, response: URLResponse?, error: Error?)? = nil, file: StaticString = #filePath, line: UInt = #line) -> Error? { + let result = resultFor(values, file: file, line: line) switch result { case let .failure(error): @@ -116,17 +114,17 @@ class URLSessionHTTPClientTests: XCTestCase { } } - private func resultFor(_ values: (data: Data?, response: URLResponse?, error: Error?)?, taskHandler: (HTTPClientTask) -> Void = { _ in }, file: StaticString = #filePath, line: UInt = #line) -> HTTPClient.Result { + private func resultFor(_ values: (data: Data?, response: URLResponse?, error: Error?)?, file: StaticString = #filePath, line: UInt = #line) -> HTTPClient.Result { values.map { URLProtocolStub.stub(data: $0, response: $1, error: $2) } let sut = makeSUT(file: file, line: line) let exp = expectation(description: "Wait for completion") var receivedResult: HTTPClient.Result! - taskHandler(sut.get(from: anyURL()) { result in + sut.get(from: anyURL()) { result in receivedResult = result exp.fulfill() - }) + } wait(for: [exp], timeout: 1.0) return receivedResult