Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ServiceConnection Connection Detection #69

Merged
merged 3 commits into from
Jan 8, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 65 additions & 59 deletions Netimobiledevice/Lockdown/ServiceConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using Netimobiledevice.Plist;
using Netimobiledevice.Usbmuxd;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
Expand All @@ -24,33 +23,38 @@ namespace Netimobiledevice.Lockdown
/// </summary>
public class ServiceConnection : IDisposable
{
private const int MAX_READ_SIZE = 4096;
private const int MAX_READ_SIZE = 32768;

/// <summary>
/// The internal logger
/// </summary>
private readonly ILogger logger;
private Stream networkStream;
private readonly byte[] receiveBuffer = new byte[MAX_READ_SIZE];
private readonly ILogger _logger;
/// <summary>
/// The initial stream used for the ServiceConnection until the SSL stream starts, unless you specifically need to use this stream
/// you should use the Stream property instead
/// </summary>
private readonly NetworkStream _networkStream;
/// <summary>
/// The main stream once SSL is established, unless you specifically need to use this stream you should use the Stream
/// property instead
/// </summary>
private SslStream? _sslStream;

public UsbmuxdDevice? MuxDevice { get; private set; }

public bool IsConnected {
get {
if (networkStream is NetworkStream ns) {
return ns.Socket.Connected;
}
return false;
return _networkStream.Socket.Connected;
}
}

public Stream Stream => networkStream;
public Stream Stream => _sslStream != null ? _sslStream : _networkStream;

private ServiceConnection(Socket sock, ILogger logger, UsbmuxdDevice? muxDevice = null)
{
this.logger = logger;
_logger = logger;
_networkStream = new NetworkStream(sock, true);

networkStream = new NetworkStream(sock, true);
// Usbmux connections contain additional information associated with the current connection
MuxDevice = muxDevice;
}
Expand All @@ -73,7 +77,7 @@ internal static ServiceConnection CreateUsingUsbmux(string udid, ushort port, Us
throw new NoDeviceConnectedException();
}
Socket sock = targetDevice.Connect(port, usbmuxAddress: usbmuxAddress, logger);
return new ServiceConnection(sock, logger, targetDevice);
return new ServiceConnection(sock, logger ?? NullLogger.Instance, targetDevice);
}

private bool UserCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors)
Expand All @@ -83,22 +87,22 @@ private bool UserCertificateValidationCallback(object sender, X509Certificate? c

public void Close()
{
networkStream.Close();
Stream.Close();
}

public void Dispose()
{
Close();
networkStream.Dispose();
Stream.Dispose();
GC.SuppressFinalize(this);
}

public byte[] Receive(int length = 4096)
{
if (length <= 0) {
return Array.Empty<byte>();
return [];
}
List<byte> buffer = new List<byte>();
byte[] buffer = new byte[length];

int totalBytesRead = 0;
while (totalBytesRead < length) {
Expand All @@ -108,24 +112,26 @@ public byte[] Receive(int length = 4096)
readSize = MAX_READ_SIZE;
}

int bytesRead = networkStream.Read(receiveBuffer, 0, readSize);
if (bytesRead == 0) { // If we don't get any bytes, the network connection was broken
int bytesRead = Stream.Read(buffer, totalBytesRead, readSize);
if (bytesRead == 0) {
_logger.LogError("Read zero bytes so the connection has been broken");
break;
}
totalBytesRead += bytesRead;

buffer.AddRange(receiveBuffer.Take(bytesRead));
}

return buffer.ToArray();
if (totalBytesRead < buffer.Length) {
return buffer.Take(totalBytesRead).ToArray();
}
return buffer;
}

public async Task<byte[]> ReceiveAsync(int length, CancellationToken cancellationToken)
{
if (length <= 0) {
return Array.Empty<byte>();
return [];
}
List<byte> buffer = new List<byte>();
byte[] buffer = new byte[length];

int totalBytesRead = 0;
while (totalBytesRead < length) {
Expand All @@ -136,34 +142,37 @@ public async Task<byte[]> ReceiveAsync(int length, CancellationToken cancellatio
}

int bytesRead;
if (networkStream.ReadTimeout != -1) {
if (Stream.ReadTimeout != -1) {
CancellationTokenSource localTaskComplete = new CancellationTokenSource();

Task<int> result = networkStream.ReadAsync(receiveBuffer, 0, readSize, localTaskComplete.Token);
Task delay = Task.Delay(networkStream.ReadTimeout, localTaskComplete.Token);
Task<int> result = Stream.ReadAsync(buffer, totalBytesRead, readSize, localTaskComplete.Token);
Task delay = Task.Delay(Stream.ReadTimeout, localTaskComplete.Token);

await Task.WhenAny(result, delay).WaitAsync(cancellationToken);
await Task.WhenAny(result, delay).WaitAsync(cancellationToken).ConfigureAwait(false);
if (cancellationToken.IsCancellationRequested) {
localTaskComplete.Cancel();
}
else if (!result.IsCompleted) {
localTaskComplete.Cancel();
throw new TimeoutException("Timeout waiting for message from service");
}
bytesRead = await result;
if (bytesRead == 0) { // If we don't get any bytes, the network connection was broken
bytesRead = await result.ConfigureAwait(false);
if (bytesRead == 0) {
_logger.LogError("Read zero bytes so the connection has been broken");
break;
}
}
else {
bytesRead = await networkStream.ReadAsync(receiveBuffer.AsMemory(0, readSize), cancellationToken);
bytesRead = await Stream.ReadAsync(buffer.AsMemory(totalBytesRead, readSize), cancellationToken).ConfigureAwait(false);
}

totalBytesRead += bytesRead;
buffer.AddRange(receiveBuffer.Take(bytesRead));
}

return buffer.ToArray();
if (totalBytesRead < buffer.Length) {
return buffer.Take(totalBytesRead).ToArray();
}
return buffer;
}

public PropertyNode? ReceivePlist()
Expand All @@ -177,11 +186,11 @@ public async Task<byte[]> ReceiveAsync(int length, CancellationToken cancellatio

public async Task<PropertyNode?> ReceivePlistAsync(CancellationToken cancellationToken)
{
byte[] plistBytes = await ReceivePrefixedAsync(cancellationToken);
byte[] plistBytes = await ReceivePrefixedAsync(cancellationToken).ConfigureAwait(false);
if (plistBytes.Length == 0) {
return null;
}
return await PropertyList.LoadFromByteArrayAsync(plistBytes);
return await PropertyList.LoadFromByteArrayAsync(plistBytes).ConfigureAwait(false);
}

/// <summary>
Expand All @@ -192,7 +201,7 @@ public byte[] ReceivePrefixed()
{
byte[] sizeBytes = Receive(4);
if (sizeBytes.Length != 4) {
return Array.Empty<byte>();
return [];
}

int size = EndianBitConverter.BigEndian.ToInt32(sizeBytes, 0);
Expand All @@ -205,45 +214,41 @@ public byte[] ReceivePrefixed()
/// <returns>The data without the u32 field length as a byte array</returns>
public async Task<byte[]> ReceivePrefixedAsync(CancellationToken cancellationToken = default)
{
byte[] sizeBytes = await ReceiveAsync(4, cancellationToken);
byte[] sizeBytes = await ReceiveAsync(4, cancellationToken).ConfigureAwait(false);
if (sizeBytes.Length != 4) {
return Array.Empty<byte>();
return [];
}

int size = EndianBitConverter.BigEndian.ToInt32(sizeBytes, 0);
return await ReceiveAsync(size, cancellationToken);
return await ReceiveAsync(size, cancellationToken).ConfigureAwait(false);
}

public void Send(byte[] data)
{
networkStream.Write(data);
Stream.Write(data);
}

public async Task SendAsync(byte[] data, CancellationToken cancellationToken)
{
await networkStream.WriteAsync(data, cancellationToken);
await Stream.WriteAsync(data, cancellationToken).ConfigureAwait(false);
}

public void SendPlist(PropertyNode data, PlistFormat format = PlistFormat.Xml)
{
byte[] plistBytes = PropertyList.SaveAsByteArray(data, format);
byte[] lengthBytes = BitConverter.GetBytes(EndianBitConverter.BigEndian.ToInt32(BitConverter.GetBytes(plistBytes.Length), 0));

List<byte> payload = new List<byte>();
payload.AddRange(lengthBytes);
payload.AddRange(plistBytes);
Send(payload.ToArray());
Send(lengthBytes);
Send(plistBytes);
}

public async Task SendPlistAsync(PropertyNode data, PlistFormat format = PlistFormat.Xml, CancellationToken cancellationToken = default)
{
byte[] plistBytes = PropertyList.SaveAsByteArray(data, format);
byte[] lengthBytes = BitConverter.GetBytes(EndianBitConverter.BigEndian.ToInt32(BitConverter.GetBytes(plistBytes.Length), 0));

List<byte> payload = new List<byte>();
payload.AddRange(lengthBytes);
payload.AddRange(plistBytes);
await SendAsync(payload.ToArray(), cancellationToken);
await SendAsync(lengthBytes, cancellationToken).ConfigureAwait(false);
await SendAsync(plistBytes, cancellationToken).ConfigureAwait(false);
}

public PropertyNode? SendReceivePlist(PropertyNode data)
Expand All @@ -254,8 +259,8 @@ public async Task SendPlistAsync(PropertyNode data, PlistFormat format = PlistFo

public async Task<PropertyNode?> SendReceivePlistAsync(PropertyNode data, CancellationToken cancellationToken)
{
await SendPlistAsync(data, cancellationToken: cancellationToken);
return await ReceivePlistAsync(cancellationToken);
await SendPlistAsync(data, cancellationToken: cancellationToken).ConfigureAwait(false);
return await ReceivePlistAsync(cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand All @@ -264,8 +269,8 @@ public async Task SendPlistAsync(PropertyNode data, PlistFormat format = PlistFo
/// <param name="timeout">A value in milliseconds that detemines how long the service connection will wait before timing out</param>
public void SetTimeout(int timeout = -1)
{
networkStream.ReadTimeout = timeout;
networkStream.WriteTimeout = timeout;
Stream.ReadTimeout = timeout;
Stream.WriteTimeout = timeout;
}

public void StartSSL(byte[] certData, byte[] privateKeyData)
Expand All @@ -274,19 +279,20 @@ public void StartSSL(byte[] certData, byte[] privateKeyData)
string privateKeyText = Encoding.UTF8.GetString(privateKeyData);
X509Certificate2 cert = X509Certificate2.CreateFromPem(certText, privateKeyText);

networkStream.Flush();
if (_networkStream == null) {
throw new InvalidOperationException("Network stream is null");
}
_networkStream.Flush();

SslStream sslStream = new SslStream(networkStream, true, UserCertificateValidationCallback, null, EncryptionPolicy.RequireEncryption);
_sslStream = new SslStream(_networkStream, true, UserCertificateValidationCallback, null, EncryptionPolicy.RequireEncryption);
try {
// NOTE: For some reason we need to re-export and then import the cert again ¯\_(ツ)_/¯
// see this for more details: https://github.com/dotnet/runtime/issues/45680
sslStream.AuthenticateAsClient(string.Empty, [new X509Certificate2(cert.Export(X509ContentType.Pkcs12))], SslProtocols.None, false);
_sslStream.AuthenticateAsClient(string.Empty, [new X509Certificate2(cert.Export(X509ContentType.Pkcs12))], SslProtocols.None, false);
}
catch (AuthenticationException ex) {
logger.LogError(ex, "SSL authentication failed");
_logger.LogError(ex, "SSL authentication failed");
}

networkStream = sslStream;
}
}
}
Loading