diff --git a/Netimobiledevice/Lockdown/ServiceConnection.cs b/Netimobiledevice/Lockdown/ServiceConnection.cs
index 88736e8..1b3f72f 100644
--- a/Netimobiledevice/Lockdown/ServiceConnection.cs
+++ b/Netimobiledevice/Lockdown/ServiceConnection.cs
@@ -5,7 +5,6 @@
using Netimobiledevice.Plist;
using Netimobiledevice.Usbmuxd;
using System;
-using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
@@ -24,33 +23,38 @@ namespace Netimobiledevice.Lockdown
///
public class ServiceConnection : IDisposable
{
- private const int MAX_READ_SIZE = 4096;
+ private const int MAX_READ_SIZE = 32768;
///
/// The internal logger
///
- private readonly ILogger logger;
- private Stream networkStream;
- private readonly byte[] receiveBuffer = new byte[MAX_READ_SIZE];
+ private readonly ILogger _logger;
+ ///
+ /// The initial stream used for the ServiceConnection until the SSL stream starts, unless you specifically need to use this stream
+ /// you should use the Stream property instead
+ ///
+ private readonly NetworkStream _networkStream;
+ ///
+ /// The main stream once SSL is established, unless you specifically need to use this stream you should use the Stream
+ /// property instead
+ ///
+ private SslStream? _sslStream;
public UsbmuxdDevice? MuxDevice { get; private set; }
public bool IsConnected {
get {
- if (networkStream is NetworkStream ns) {
- return ns.Socket.Connected;
- }
- return false;
+ return _networkStream.Socket.Connected;
}
}
- public Stream Stream => networkStream;
+ public Stream Stream => _sslStream != null ? _sslStream : _networkStream;
private ServiceConnection(Socket sock, ILogger logger, UsbmuxdDevice? muxDevice = null)
{
- this.logger = logger;
+ _logger = logger;
+ _networkStream = new NetworkStream(sock, true);
- networkStream = new NetworkStream(sock, true);
// Usbmux connections contain additional information associated with the current connection
MuxDevice = muxDevice;
}
@@ -73,7 +77,7 @@ internal static ServiceConnection CreateUsingUsbmux(string udid, ushort port, Us
throw new NoDeviceConnectedException();
}
Socket sock = targetDevice.Connect(port, usbmuxAddress: usbmuxAddress, logger);
- return new ServiceConnection(sock, logger, targetDevice);
+ return new ServiceConnection(sock, logger ?? NullLogger.Instance, targetDevice);
}
private bool UserCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors)
@@ -83,22 +87,22 @@ private bool UserCertificateValidationCallback(object sender, X509Certificate? c
public void Close()
{
- networkStream.Close();
+ Stream.Close();
}
public void Dispose()
{
Close();
- networkStream.Dispose();
+ Stream.Dispose();
GC.SuppressFinalize(this);
}
public byte[] Receive(int length = 4096)
{
if (length <= 0) {
- return Array.Empty();
+ return [];
}
- List buffer = new List();
+ byte[] buffer = new byte[length];
int totalBytesRead = 0;
while (totalBytesRead < length) {
@@ -108,24 +112,26 @@ public byte[] Receive(int length = 4096)
readSize = MAX_READ_SIZE;
}
- int bytesRead = networkStream.Read(receiveBuffer, 0, readSize);
- if (bytesRead == 0) { // If we don't get any bytes, the network connection was broken
+ int bytesRead = Stream.Read(buffer, totalBytesRead, readSize);
+ if (bytesRead == 0) {
+ _logger.LogError("Read zero bytes so the connection has been broken");
break;
}
totalBytesRead += bytesRead;
-
- buffer.AddRange(receiveBuffer.Take(bytesRead));
}
- return buffer.ToArray();
+ if (totalBytesRead < buffer.Length) {
+ return buffer.Take(totalBytesRead).ToArray();
+ }
+ return buffer;
}
public async Task ReceiveAsync(int length, CancellationToken cancellationToken)
{
if (length <= 0) {
- return Array.Empty();
+ return [];
}
- List buffer = new List();
+ byte[] buffer = new byte[length];
int totalBytesRead = 0;
while (totalBytesRead < length) {
@@ -136,13 +142,13 @@ public async Task ReceiveAsync(int length, CancellationToken cancellatio
}
int bytesRead;
- if (networkStream.ReadTimeout != -1) {
+ if (Stream.ReadTimeout != -1) {
CancellationTokenSource localTaskComplete = new CancellationTokenSource();
- Task result = networkStream.ReadAsync(receiveBuffer, 0, readSize, localTaskComplete.Token);
- Task delay = Task.Delay(networkStream.ReadTimeout, localTaskComplete.Token);
+ Task result = Stream.ReadAsync(buffer, totalBytesRead, readSize, localTaskComplete.Token);
+ Task delay = Task.Delay(Stream.ReadTimeout, localTaskComplete.Token);
- await Task.WhenAny(result, delay).WaitAsync(cancellationToken);
+ await Task.WhenAny(result, delay).WaitAsync(cancellationToken).ConfigureAwait(false);
if (cancellationToken.IsCancellationRequested) {
localTaskComplete.Cancel();
}
@@ -150,20 +156,23 @@ public async Task ReceiveAsync(int length, CancellationToken cancellatio
localTaskComplete.Cancel();
throw new TimeoutException("Timeout waiting for message from service");
}
- bytesRead = await result;
- if (bytesRead == 0) { // If we don't get any bytes, the network connection was broken
+ bytesRead = await result.ConfigureAwait(false);
+ if (bytesRead == 0) {
+ _logger.LogError("Read zero bytes so the connection has been broken");
break;
}
}
else {
- bytesRead = await networkStream.ReadAsync(receiveBuffer.AsMemory(0, readSize), cancellationToken);
+ bytesRead = await Stream.ReadAsync(buffer.AsMemory(totalBytesRead, readSize), cancellationToken).ConfigureAwait(false);
}
totalBytesRead += bytesRead;
- buffer.AddRange(receiveBuffer.Take(bytesRead));
}
- return buffer.ToArray();
+ if (totalBytesRead < buffer.Length) {
+ return buffer.Take(totalBytesRead).ToArray();
+ }
+ return buffer;
}
public PropertyNode? ReceivePlist()
@@ -177,11 +186,11 @@ public async Task ReceiveAsync(int length, CancellationToken cancellatio
public async Task ReceivePlistAsync(CancellationToken cancellationToken)
{
- byte[] plistBytes = await ReceivePrefixedAsync(cancellationToken);
+ byte[] plistBytes = await ReceivePrefixedAsync(cancellationToken).ConfigureAwait(false);
if (plistBytes.Length == 0) {
return null;
}
- return await PropertyList.LoadFromByteArrayAsync(plistBytes);
+ return await PropertyList.LoadFromByteArrayAsync(plistBytes).ConfigureAwait(false);
}
///
@@ -192,7 +201,7 @@ public byte[] ReceivePrefixed()
{
byte[] sizeBytes = Receive(4);
if (sizeBytes.Length != 4) {
- return Array.Empty();
+ return [];
}
int size = EndianBitConverter.BigEndian.ToInt32(sizeBytes, 0);
@@ -205,23 +214,23 @@ public byte[] ReceivePrefixed()
/// The data without the u32 field length as a byte array
public async Task ReceivePrefixedAsync(CancellationToken cancellationToken = default)
{
- byte[] sizeBytes = await ReceiveAsync(4, cancellationToken);
+ byte[] sizeBytes = await ReceiveAsync(4, cancellationToken).ConfigureAwait(false);
if (sizeBytes.Length != 4) {
- return Array.Empty();
+ return [];
}
int size = EndianBitConverter.BigEndian.ToInt32(sizeBytes, 0);
- return await ReceiveAsync(size, cancellationToken);
+ return await ReceiveAsync(size, cancellationToken).ConfigureAwait(false);
}
public void Send(byte[] data)
{
- networkStream.Write(data);
+ Stream.Write(data);
}
public async Task SendAsync(byte[] data, CancellationToken cancellationToken)
{
- await networkStream.WriteAsync(data, cancellationToken);
+ await Stream.WriteAsync(data, cancellationToken).ConfigureAwait(false);
}
public void SendPlist(PropertyNode data, PlistFormat format = PlistFormat.Xml)
@@ -229,10 +238,8 @@ public void SendPlist(PropertyNode data, PlistFormat format = PlistFormat.Xml)
byte[] plistBytes = PropertyList.SaveAsByteArray(data, format);
byte[] lengthBytes = BitConverter.GetBytes(EndianBitConverter.BigEndian.ToInt32(BitConverter.GetBytes(plistBytes.Length), 0));
- List payload = new List();
- payload.AddRange(lengthBytes);
- payload.AddRange(plistBytes);
- Send(payload.ToArray());
+ Send(lengthBytes);
+ Send(plistBytes);
}
public async Task SendPlistAsync(PropertyNode data, PlistFormat format = PlistFormat.Xml, CancellationToken cancellationToken = default)
@@ -240,10 +247,8 @@ public async Task SendPlistAsync(PropertyNode data, PlistFormat format = PlistFo
byte[] plistBytes = PropertyList.SaveAsByteArray(data, format);
byte[] lengthBytes = BitConverter.GetBytes(EndianBitConverter.BigEndian.ToInt32(BitConverter.GetBytes(plistBytes.Length), 0));
- List payload = new List();
- payload.AddRange(lengthBytes);
- payload.AddRange(plistBytes);
- await SendAsync(payload.ToArray(), cancellationToken);
+ await SendAsync(lengthBytes, cancellationToken).ConfigureAwait(false);
+ await SendAsync(plistBytes, cancellationToken).ConfigureAwait(false);
}
public PropertyNode? SendReceivePlist(PropertyNode data)
@@ -254,8 +259,8 @@ public async Task SendPlistAsync(PropertyNode data, PlistFormat format = PlistFo
public async Task SendReceivePlistAsync(PropertyNode data, CancellationToken cancellationToken)
{
- await SendPlistAsync(data, cancellationToken: cancellationToken);
- return await ReceivePlistAsync(cancellationToken);
+ await SendPlistAsync(data, cancellationToken: cancellationToken).ConfigureAwait(false);
+ return await ReceivePlistAsync(cancellationToken).ConfigureAwait(false);
}
///
@@ -264,8 +269,8 @@ public async Task SendPlistAsync(PropertyNode data, PlistFormat format = PlistFo
/// A value in milliseconds that detemines how long the service connection will wait before timing out
public void SetTimeout(int timeout = -1)
{
- networkStream.ReadTimeout = timeout;
- networkStream.WriteTimeout = timeout;
+ Stream.ReadTimeout = timeout;
+ Stream.WriteTimeout = timeout;
}
public void StartSSL(byte[] certData, byte[] privateKeyData)
@@ -274,19 +279,20 @@ public void StartSSL(byte[] certData, byte[] privateKeyData)
string privateKeyText = Encoding.UTF8.GetString(privateKeyData);
X509Certificate2 cert = X509Certificate2.CreateFromPem(certText, privateKeyText);
- networkStream.Flush();
+ if (_networkStream == null) {
+ throw new InvalidOperationException("Network stream is null");
+ }
+ _networkStream.Flush();
- SslStream sslStream = new SslStream(networkStream, true, UserCertificateValidationCallback, null, EncryptionPolicy.RequireEncryption);
+ _sslStream = new SslStream(_networkStream, true, UserCertificateValidationCallback, null, EncryptionPolicy.RequireEncryption);
try {
// NOTE: For some reason we need to re-export and then import the cert again ¯\_(ツ)_/¯
// see this for more details: https://github.com/dotnet/runtime/issues/45680
- sslStream.AuthenticateAsClient(string.Empty, [new X509Certificate2(cert.Export(X509ContentType.Pkcs12))], SslProtocols.None, false);
+ _sslStream.AuthenticateAsClient(string.Empty, [new X509Certificate2(cert.Export(X509ContentType.Pkcs12))], SslProtocols.None, false);
}
catch (AuthenticationException ex) {
- logger.LogError(ex, "SSL authentication failed");
+ _logger.LogError(ex, "SSL authentication failed");
}
-
- networkStream = sslStream;
}
}
}