diff --git a/Netimobiledevice/Lockdown/ServiceConnection.cs b/Netimobiledevice/Lockdown/ServiceConnection.cs index 88736e8..1b3f72f 100644 --- a/Netimobiledevice/Lockdown/ServiceConnection.cs +++ b/Netimobiledevice/Lockdown/ServiceConnection.cs @@ -5,7 +5,6 @@ using Netimobiledevice.Plist; using Netimobiledevice.Usbmuxd; using System; -using System.Collections.Generic; using System.IO; using System.Linq; using System.Net; @@ -24,33 +23,38 @@ namespace Netimobiledevice.Lockdown /// public class ServiceConnection : IDisposable { - private const int MAX_READ_SIZE = 4096; + private const int MAX_READ_SIZE = 32768; /// /// The internal logger /// - private readonly ILogger logger; - private Stream networkStream; - private readonly byte[] receiveBuffer = new byte[MAX_READ_SIZE]; + private readonly ILogger _logger; + /// + /// The initial stream used for the ServiceConnection until the SSL stream starts, unless you specifically need to use this stream + /// you should use the Stream property instead + /// + private readonly NetworkStream _networkStream; + /// + /// The main stream once SSL is established, unless you specifically need to use this stream you should use the Stream + /// property instead + /// + private SslStream? _sslStream; public UsbmuxdDevice? MuxDevice { get; private set; } public bool IsConnected { get { - if (networkStream is NetworkStream ns) { - return ns.Socket.Connected; - } - return false; + return _networkStream.Socket.Connected; } } - public Stream Stream => networkStream; + public Stream Stream => _sslStream != null ? _sslStream : _networkStream; private ServiceConnection(Socket sock, ILogger logger, UsbmuxdDevice? muxDevice = null) { - this.logger = logger; + _logger = logger; + _networkStream = new NetworkStream(sock, true); - networkStream = new NetworkStream(sock, true); // Usbmux connections contain additional information associated with the current connection MuxDevice = muxDevice; } @@ -73,7 +77,7 @@ internal static ServiceConnection CreateUsingUsbmux(string udid, ushort port, Us throw new NoDeviceConnectedException(); } Socket sock = targetDevice.Connect(port, usbmuxAddress: usbmuxAddress, logger); - return new ServiceConnection(sock, logger, targetDevice); + return new ServiceConnection(sock, logger ?? NullLogger.Instance, targetDevice); } private bool UserCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) @@ -83,22 +87,22 @@ private bool UserCertificateValidationCallback(object sender, X509Certificate? c public void Close() { - networkStream.Close(); + Stream.Close(); } public void Dispose() { Close(); - networkStream.Dispose(); + Stream.Dispose(); GC.SuppressFinalize(this); } public byte[] Receive(int length = 4096) { if (length <= 0) { - return Array.Empty(); + return []; } - List buffer = new List(); + byte[] buffer = new byte[length]; int totalBytesRead = 0; while (totalBytesRead < length) { @@ -108,24 +112,26 @@ public byte[] Receive(int length = 4096) readSize = MAX_READ_SIZE; } - int bytesRead = networkStream.Read(receiveBuffer, 0, readSize); - if (bytesRead == 0) { // If we don't get any bytes, the network connection was broken + int bytesRead = Stream.Read(buffer, totalBytesRead, readSize); + if (bytesRead == 0) { + _logger.LogError("Read zero bytes so the connection has been broken"); break; } totalBytesRead += bytesRead; - - buffer.AddRange(receiveBuffer.Take(bytesRead)); } - return buffer.ToArray(); + if (totalBytesRead < buffer.Length) { + return buffer.Take(totalBytesRead).ToArray(); + } + return buffer; } public async Task ReceiveAsync(int length, CancellationToken cancellationToken) { if (length <= 0) { - return Array.Empty(); + return []; } - List buffer = new List(); + byte[] buffer = new byte[length]; int totalBytesRead = 0; while (totalBytesRead < length) { @@ -136,13 +142,13 @@ public async Task ReceiveAsync(int length, CancellationToken cancellatio } int bytesRead; - if (networkStream.ReadTimeout != -1) { + if (Stream.ReadTimeout != -1) { CancellationTokenSource localTaskComplete = new CancellationTokenSource(); - Task result = networkStream.ReadAsync(receiveBuffer, 0, readSize, localTaskComplete.Token); - Task delay = Task.Delay(networkStream.ReadTimeout, localTaskComplete.Token); + Task result = Stream.ReadAsync(buffer, totalBytesRead, readSize, localTaskComplete.Token); + Task delay = Task.Delay(Stream.ReadTimeout, localTaskComplete.Token); - await Task.WhenAny(result, delay).WaitAsync(cancellationToken); + await Task.WhenAny(result, delay).WaitAsync(cancellationToken).ConfigureAwait(false); if (cancellationToken.IsCancellationRequested) { localTaskComplete.Cancel(); } @@ -150,20 +156,23 @@ public async Task ReceiveAsync(int length, CancellationToken cancellatio localTaskComplete.Cancel(); throw new TimeoutException("Timeout waiting for message from service"); } - bytesRead = await result; - if (bytesRead == 0) { // If we don't get any bytes, the network connection was broken + bytesRead = await result.ConfigureAwait(false); + if (bytesRead == 0) { + _logger.LogError("Read zero bytes so the connection has been broken"); break; } } else { - bytesRead = await networkStream.ReadAsync(receiveBuffer.AsMemory(0, readSize), cancellationToken); + bytesRead = await Stream.ReadAsync(buffer.AsMemory(totalBytesRead, readSize), cancellationToken).ConfigureAwait(false); } totalBytesRead += bytesRead; - buffer.AddRange(receiveBuffer.Take(bytesRead)); } - return buffer.ToArray(); + if (totalBytesRead < buffer.Length) { + return buffer.Take(totalBytesRead).ToArray(); + } + return buffer; } public PropertyNode? ReceivePlist() @@ -177,11 +186,11 @@ public async Task ReceiveAsync(int length, CancellationToken cancellatio public async Task ReceivePlistAsync(CancellationToken cancellationToken) { - byte[] plistBytes = await ReceivePrefixedAsync(cancellationToken); + byte[] plistBytes = await ReceivePrefixedAsync(cancellationToken).ConfigureAwait(false); if (plistBytes.Length == 0) { return null; } - return await PropertyList.LoadFromByteArrayAsync(plistBytes); + return await PropertyList.LoadFromByteArrayAsync(plistBytes).ConfigureAwait(false); } /// @@ -192,7 +201,7 @@ public byte[] ReceivePrefixed() { byte[] sizeBytes = Receive(4); if (sizeBytes.Length != 4) { - return Array.Empty(); + return []; } int size = EndianBitConverter.BigEndian.ToInt32(sizeBytes, 0); @@ -205,23 +214,23 @@ public byte[] ReceivePrefixed() /// The data without the u32 field length as a byte array public async Task ReceivePrefixedAsync(CancellationToken cancellationToken = default) { - byte[] sizeBytes = await ReceiveAsync(4, cancellationToken); + byte[] sizeBytes = await ReceiveAsync(4, cancellationToken).ConfigureAwait(false); if (sizeBytes.Length != 4) { - return Array.Empty(); + return []; } int size = EndianBitConverter.BigEndian.ToInt32(sizeBytes, 0); - return await ReceiveAsync(size, cancellationToken); + return await ReceiveAsync(size, cancellationToken).ConfigureAwait(false); } public void Send(byte[] data) { - networkStream.Write(data); + Stream.Write(data); } public async Task SendAsync(byte[] data, CancellationToken cancellationToken) { - await networkStream.WriteAsync(data, cancellationToken); + await Stream.WriteAsync(data, cancellationToken).ConfigureAwait(false); } public void SendPlist(PropertyNode data, PlistFormat format = PlistFormat.Xml) @@ -229,10 +238,8 @@ public void SendPlist(PropertyNode data, PlistFormat format = PlistFormat.Xml) byte[] plistBytes = PropertyList.SaveAsByteArray(data, format); byte[] lengthBytes = BitConverter.GetBytes(EndianBitConverter.BigEndian.ToInt32(BitConverter.GetBytes(plistBytes.Length), 0)); - List payload = new List(); - payload.AddRange(lengthBytes); - payload.AddRange(plistBytes); - Send(payload.ToArray()); + Send(lengthBytes); + Send(plistBytes); } public async Task SendPlistAsync(PropertyNode data, PlistFormat format = PlistFormat.Xml, CancellationToken cancellationToken = default) @@ -240,10 +247,8 @@ public async Task SendPlistAsync(PropertyNode data, PlistFormat format = PlistFo byte[] plistBytes = PropertyList.SaveAsByteArray(data, format); byte[] lengthBytes = BitConverter.GetBytes(EndianBitConverter.BigEndian.ToInt32(BitConverter.GetBytes(plistBytes.Length), 0)); - List payload = new List(); - payload.AddRange(lengthBytes); - payload.AddRange(plistBytes); - await SendAsync(payload.ToArray(), cancellationToken); + await SendAsync(lengthBytes, cancellationToken).ConfigureAwait(false); + await SendAsync(plistBytes, cancellationToken).ConfigureAwait(false); } public PropertyNode? SendReceivePlist(PropertyNode data) @@ -254,8 +259,8 @@ public async Task SendPlistAsync(PropertyNode data, PlistFormat format = PlistFo public async Task SendReceivePlistAsync(PropertyNode data, CancellationToken cancellationToken) { - await SendPlistAsync(data, cancellationToken: cancellationToken); - return await ReceivePlistAsync(cancellationToken); + await SendPlistAsync(data, cancellationToken: cancellationToken).ConfigureAwait(false); + return await ReceivePlistAsync(cancellationToken).ConfigureAwait(false); } /// @@ -264,8 +269,8 @@ public async Task SendPlistAsync(PropertyNode data, PlistFormat format = PlistFo /// A value in milliseconds that detemines how long the service connection will wait before timing out public void SetTimeout(int timeout = -1) { - networkStream.ReadTimeout = timeout; - networkStream.WriteTimeout = timeout; + Stream.ReadTimeout = timeout; + Stream.WriteTimeout = timeout; } public void StartSSL(byte[] certData, byte[] privateKeyData) @@ -274,19 +279,20 @@ public void StartSSL(byte[] certData, byte[] privateKeyData) string privateKeyText = Encoding.UTF8.GetString(privateKeyData); X509Certificate2 cert = X509Certificate2.CreateFromPem(certText, privateKeyText); - networkStream.Flush(); + if (_networkStream == null) { + throw new InvalidOperationException("Network stream is null"); + } + _networkStream.Flush(); - SslStream sslStream = new SslStream(networkStream, true, UserCertificateValidationCallback, null, EncryptionPolicy.RequireEncryption); + _sslStream = new SslStream(_networkStream, true, UserCertificateValidationCallback, null, EncryptionPolicy.RequireEncryption); try { // NOTE: For some reason we need to re-export and then import the cert again ¯\_(ツ)_/¯ // see this for more details: https://github.com/dotnet/runtime/issues/45680 - sslStream.AuthenticateAsClient(string.Empty, [new X509Certificate2(cert.Export(X509ContentType.Pkcs12))], SslProtocols.None, false); + _sslStream.AuthenticateAsClient(string.Empty, [new X509Certificate2(cert.Export(X509ContentType.Pkcs12))], SslProtocols.None, false); } catch (AuthenticationException ex) { - logger.LogError(ex, "SSL authentication failed"); + _logger.LogError(ex, "SSL authentication failed"); } - - networkStream = sslStream; } } }