From 4e767c858df7ceec6c366b0b0d85d2f1776ba51e Mon Sep 17 00:00:00 2001 From: Tides Date: Thu, 30 Jan 2025 07:38:38 -0600 Subject: [PATCH 1/3] Port ModifiedUtf8 from sebs nbt rework --- Obsidian.Nbt/ModifiedUtf8.cs | 466 ++++++++++++++++++++++++++ Obsidian.Nbt/NbtReader.Primitives.cs | 2 +- Obsidian.Nbt/NbtWriter.Primitives.cs | 4 +- Obsidian.Nbt/Utilities/ThrowHelper.cs | 47 +++ 4 files changed, 517 insertions(+), 2 deletions(-) create mode 100644 Obsidian.Nbt/ModifiedUtf8.cs create mode 100644 Obsidian.Nbt/Utilities/ThrowHelper.cs diff --git a/Obsidian.Nbt/ModifiedUtf8.cs b/Obsidian.Nbt/ModifiedUtf8.cs new file mode 100644 index 00000000..93e80978 --- /dev/null +++ b/Obsidian.Nbt/ModifiedUtf8.cs @@ -0,0 +1,466 @@ +using Obsidian.Nbt.Utilities; +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; + +namespace Obsidian.Nbt; + +// Encoding specification: https://web.archive.org/web/20211117120323/https://docs.oracle.com/javase/8/docs/api/java/io/DataInput.html +// +-----------------------------+-------------------------------------------------------------------+ +// | Character range | Bit values | +// +-----------------------------+-------------------------------------------------------------------+ +// | \u0001 to \u007F | 0 | bits 6-0 |......................................| +// | \u0080 to \u07FF and \u0000 | 1 | 1 | 0 | bits 10-6 | 1 | 0 | bits 5-0 |..................| +// | \u0800 to \uFFFF | 1 | 1 | 1 | 0 | bits 15-12 | 1 | 0 | bits 11-6 | 1 | 0 | bits 5-0 | +// +-----------------------------+-------------------------------------------------------------------+ + +/// +/// Provides methods for working with a modification of UTF-8 encoding used by the NBT format. +/// +public static class ModifiedUtf8 +{ + /// + /// Encodes a span of characters into an array of bytes. + /// + /// The span of characters to encode. + /// An array of bytes containing the results of encoding the specified sequence of characters. + /// true if encoding the characters was successful; otherwise false + public static bool TryGetBytes(ReadOnlySpan chars, [NotNullWhen(true)] out byte[]? bytes) + { + if (!TryGetByteCount(chars, out int byteCount)) + { + bytes = null; + return false; + } + + if (byteCount == 0) + { + bytes = []; + return true; + } + + bytes = GC.AllocateUninitializedArray(byteCount); + GetBytesCommon(chars, bytes); + return true; + } + + /// + /// Encodes a span of characters into a sequence of bytes obtained from a buffer writer. + /// + /// The span of characters to encode. + /// Bytes output sink. + /// true if encoding the characters was successful; otherwise false + public static bool TryGetBytes(ReadOnlySpan chars, IBufferWriter bufferWriter) + { + if (!TryGetByteCount(chars, out int byteCount)) + return false; + + Span sink = bufferWriter.GetSpan(byteCount); + if (sink.Length < byteCount) + ThrowHelper.ThrowException_InsufficientBufferSize(); + + GetBytesCommon(chars, sink); + bufferWriter.Advance(byteCount); + return true; + } + + internal static void GetBytesCommon(ReadOnlySpan chars, Span bytes) + { + if (chars.Length == bytes.Length) + { + if (BitConverter.IsLittleEndian && Sse2.IsSupported && bytes.Length >= 32) + { + GetBytesAsciiSse2(chars, bytes); + } + else + { + GetBytesAsciiScalar(chars, bytes); + } + } + else + { + GetBytesScalar(chars, bytes); + } + } + + private static void GetBytesScalar(ReadOnlySpan chars, Span bytes) + { + ref byte destination = ref MemoryMarshal.GetReference(bytes); + for (int i = 0; i < chars.Length; i++) + { + char c = chars[i]; + if (c < 0x80 && c != 0) + { + destination = (byte)c; + } + else if (c >= 0x800) + { + destination = (byte)(0xE0 | ((c >> 12) & 0x0F)); + destination = ref Unsafe.Add(ref destination, 1); + destination = (byte)(0x80 | ((c >> 6) & 0x3F)); + destination = ref Unsafe.Add(ref destination, 1); + destination = (byte)(0x80 | ((c >> 0) & 0x3F)); + } + else + { + destination = (byte)(0xC0 | ((c >> 6) & 0x1F)); + destination = ref Unsafe.Add(ref destination, 1); + destination = (byte)(0x80 | ((c >> 0) & 0x3F)); + } + destination = ref Unsafe.Add(ref destination, 1); + } + } + + private static void GetBytesAsciiSse2(ReadOnlySpan chars, Span bytes) + { + Debug.Assert(BitConverter.IsLittleEndian); + Debug.Assert(Sse2.IsSupported); + + ref byte destination = ref MemoryMarshal.GetReference(bytes); + ref byte source = ref Unsafe.As(ref MemoryMarshal.GetReference(chars)); + ref byte sourceEnd = ref Unsafe.Add(ref source, chars.Length * sizeof(char) - 31); + + while (Unsafe.IsAddressLessThan(ref source, ref sourceEnd)) + { + Vector128 first = Unsafe.As>(ref source); + source = ref Unsafe.Add(ref source, 16); + Vector128 second = Unsafe.As>(ref source); + source = ref Unsafe.Add(ref source, 16); + + Vector128 packed = Sse2.PackUnsignedSaturate(first, second); + + Unsafe.WriteUnaligned(ref destination, packed); + destination = ref Unsafe.Add(ref destination, 16); + } + + sourceEnd = ref Unsafe.Add(ref sourceEnd, 31); + while (Unsafe.IsAddressLessThan(ref source, ref sourceEnd)) + { + destination = source; + + destination = ref Unsafe.Add(ref destination, 1); + source = ref Unsafe.Add(ref source, 2); + } + } + + private static void GetBytesAsciiScalar(ReadOnlySpan chars, Span bytes) + { + ref byte destination = ref MemoryMarshal.GetReference(bytes); + for (int i = 0; i < chars.Length; i++) + { + destination = (byte)chars[i]; + destination = ref Unsafe.Add(ref destination, 1); + } + } + + /// + /// Decodes a span of bytes into a string. + /// + /// The span of bytes to decode. + /// A containing the results of decoding the specified sequence of bytes. + /// Sequence contained incorrectly formatted bytes. + public static string GetString(ReadOnlySpan bytes) + { + if (TryGetString(bytes, out string? @string)) + { + return @string; + } + + throw new FormatException("Input data contained invalid bytes."); + } + + /// + /// Decodes a span of bytes into a string and returns a value indicating whether the conversion was successfull. + /// + /// The span of bytes to decode. + /// A containing the results of decoding the specified sequence of bytes. + /// true if decoding the bytes was successful; otherwise false + public static bool TryGetString(ReadOnlySpan bytes, [NotNullWhen(true)] out string? @string) + { + if (bytes.IsEmpty) + { + @string = string.Empty; + return true; + } + + if (TryGetCharCount(bytes, out int length)) + { + @string = new string('\0', length); + ref char stringRef = ref Unsafe.AsRef(in @string.GetPinnableReference()); + GetStringCommon(bytes, length, ref stringRef); + return true; + } + + @string = null; + return false; + } + + private static bool TryGetCharCount(ReadOnlySpan bytes, out int charCount) + { + Debug.Assert(!bytes.IsEmpty); + + charCount = 0; + + if (bytes.Length > ushort.MaxValue) + { + return false; + } + + // Make sure that the last byte(s) is not partial + int last = bytes[^1] >> 6; + if (last > 2) // First byte of a byte group + { + return false; + } + else if (last == 2) // Part of a byte group + { + if (bytes.Length < 2) + return false; + + last = bytes[^2] >> 5; + if (last == 0b100 || last == 0b101) // Three byte group + { + if (bytes.Length < 3) + return false; + + last = bytes[^3] >> 4; + if (last != 0b1110) + return false; + } + else if (last != 0b110) // Two byte group + { + return false; + } + } + + ref byte @ref = ref MemoryMarshal.GetReference(bytes); + ref byte end = ref Unsafe.Add(ref @ref, bytes.Length); + while (Unsafe.IsAddressLessThan(ref @ref, ref end)) + { + int header = @ref >> 4; + if (header < 0b1000) // One byte + { + } + else if (header < 0b1110) // Two bytes + { + @ref = ref Unsafe.Add(ref @ref, 1); + if ((@ref >> 6) != 0b10) + return false; + } + else if (header == 0b1110) // Three bytes + { + @ref = ref Unsafe.Add(ref @ref, 1); + if ((@ref >> 6) != 0b10) + return false; + @ref = ref Unsafe.Add(ref @ref, 1); + if ((@ref >> 6) != 0b10) + return false; + } + else // Invalid header + { + return false; + } + charCount++; + @ref = ref Unsafe.Add(ref @ref, 1); + } + + return true; + } + + private static void GetStringCommon(ReadOnlySpan bytes, int stringLength, ref char destination) + { + if (bytes.Length == stringLength) + { + if (!BitConverter.IsLittleEndian || stringLength < 16) + { + GetStringAsciiScalar(bytes, ref destination); + } + else + { + GetStringAsciiAvx2(bytes, ref destination); + } + } + else + { + GetStringScalar(bytes, ref destination); + } + } + + private static void GetStringAsciiAvx2(ReadOnlySpan bytes, ref char destination) + { + Debug.Assert(BitConverter.IsLittleEndian); + Debug.Assert(Avx2.IsSupported); + + ref byte target = ref Unsafe.As(ref destination); + ref byte source = ref MemoryMarshal.GetReference(bytes); + ref byte sourceEnd = ref Unsafe.Add(ref source, bytes.Length - 31); + while (Unsafe.IsAddressLessThan(ref source, ref sourceEnd)) + { + Vector256 vector = Unsafe.As>(ref source); + + Vector256 low = Avx2.UnpackLow(vector, Vector256.Zero); + Unsafe.WriteUnaligned(ref target, low); + target = ref Unsafe.Add(ref target, 32); + + Vector256 high = Avx2.UnpackHigh(vector, Vector256.Zero); + Unsafe.WriteUnaligned(ref target, high); + target = ref Unsafe.Add(ref target, 32); + + source = ref Unsafe.Add(ref source, 32); + } + + sourceEnd = ref Unsafe.Add(ref sourceEnd, 31); + destination = ref Unsafe.As(ref target); + while (Unsafe.IsAddressLessThan(ref source, ref sourceEnd)) + { + destination = (char)source; + + source = ref Unsafe.Add(ref source, 1); + destination = ref Unsafe.Add(ref destination, 1); + } + } + + private static void GetStringAsciiScalar(ReadOnlySpan bytes, ref char destination) + { + for (int i = 0; i < bytes.Length; i++) + { + destination = (char)bytes[i]; + destination = ref Unsafe.Add(ref destination, 1); + } + } + + private static void GetStringScalar(ReadOnlySpan bytes, ref char destination) + { + ref byte source = ref MemoryMarshal.GetReference(bytes); + ref byte sourceEnd = ref Unsafe.Add(ref source, bytes.Length); + while (Unsafe.IsAddressLessThan(ref source, ref sourceEnd)) + { + int c = source; + int header = c >> 4; + if (header < 0b1000) // One byte + { + destination = (char)c; + } + else if (header < 0b1110) // Two bytes + { + c = (c & 0b0001_1111) << 6; + + source = ref Unsafe.Add(ref source, 1); + c |= source & 0b0011_1111; + + destination = (char)c; + } + else // Three bytes + { + c = (c & 0b0000_1111) << 12; + + source = ref Unsafe.Add(ref source, 1); + c |= (source & 0b0011_1111) << 6; + + source = ref Unsafe.Add(ref source, 1); + c |= source & 0b0011_1111; + + destination = (char)c; + } + source = ref Unsafe.Add(ref source, 1); + destination = ref Unsafe.Add(ref destination, 1); + } + } + + /// + /// Calculates the number of bytes produced by encoding the specified character span, unless the number of bytes is more than the encoding supports. + /// + /// The span that contains the set of characters to encode. + /// The number of bytes produced by encoding the specified character span. + /// true if calculating the number of bytes was successful; otherwise false + public static bool TryGetByteCount(ReadOnlySpan chars, out int byteCount) + { + if (chars.Length > ushort.MaxValue) // Even for all-ASCII inputs, this produces > 65535 bytes + { + byteCount = default; + return false; + } + + if (chars.IsEmpty) + { + byteCount = 0; + return true; + } + + if (Avx2.IsSupported) + { + byteCount = GetByteCountAvx2(chars); + return byteCount <= ushort.MaxValue; + } + else + { + byteCount = GetByteCountScalar(chars); + return byteCount <= ushort.MaxValue; + } + } + + private static int GetByteCountAvx2(ReadOnlySpan chars) + { + Debug.Assert(Avx2.IsSupported); + Debug.Assert(chars.Length <= ushort.MaxValue); // Ensure that the counter can't overflow + + const short TwoBytesBorder = (0x0080 >> 1) - 1; + const short ThreeBytesBorder = (0x0800 >> 1) - 1; + + int byteCount = chars.Length; // Count all characters that will produce at least one byte + + ref byte ptr = ref Unsafe.As(ref MemoryMarshal.GetReference(chars)); + ref byte ptrEnd = ref Unsafe.Add(ref ptr, chars.Length * sizeof(char) - 31); + + Vector256 counter = Vector256.Zero; + for (; Unsafe.IsAddressLessThan(ref ptr, ref ptrEnd); ptr = ref Unsafe.Add(ref ptr, 32)) + { + // Transform all 0x0000s to 0x0080s + Vector256 vustr = Unsafe.ReadUnaligned>(ref ptr); + Vector256 mask = Avx2.CompareEqual(vustr, Vector256.Zero); + Vector256 blend = Avx2.BlendVariable(vustr, Vector256.Create((ushort)0x0080), mask); + + Vector256 vstr = Avx2.ShiftRightLogical(blend, 1).AsInt16(); + counter = Avx2.Subtract(counter, Avx2.CompareGreaterThan(vstr, Vector256.Create(TwoBytesBorder))); // Count all characters that will produce at least two bytes + counter = Avx2.Subtract(counter, Avx2.CompareGreaterThan(vstr, Vector256.Create(ThreeBytesBorder))); // Count all characters that will produce three bytes + } + + counter = Avx2.HorizontalAdd(counter, counter); // 128 + counter = Avx2.HorizontalAdd(counter, counter); // 64 + + // Here we must add by hand to avoid overflowing int16 + // Indexes of summing results are based on https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_hadd_epi16&ig_expand=3839 + ref short counterRef = ref Unsafe.As, short>(ref counter); + byteCount += counterRef; + byteCount += Unsafe.Add(ref counterRef, 1); + byteCount += Unsafe.Add(ref counterRef, 8); + byteCount += Unsafe.Add(ref counterRef, 9); + + // Count the rest as scalars + ptrEnd = ref Unsafe.Add(ref ptrEnd, 31); + while (Unsafe.IsAddressLessThan(ref ptr, ref ptrEnd)) + { + ushort c = Unsafe.ReadUnaligned(ref ptr); + if (c >= 0x80 || c == 0) + byteCount += (c >= 0x800) ? 2 : 1; + ptr = ref Unsafe.Add(ref ptr, sizeof(ushort)); + } + return byteCount; + } + + private static int GetByteCountScalar(ReadOnlySpan chars) + { + int byteCount = chars.Length; + for (int i = 0; i < chars.Length; i++) + { + char c = chars[i]; + if (c >= 0x80 || c == 0) + byteCount += (c >= 0x800) ? 2 : 1; + } + return byteCount; + } +} diff --git a/Obsidian.Nbt/NbtReader.Primitives.cs b/Obsidian.Nbt/NbtReader.Primitives.cs index bcf945e7..d57c7fa0 100644 --- a/Obsidian.Nbt/NbtReader.Primitives.cs +++ b/Obsidian.Nbt/NbtReader.Primitives.cs @@ -17,7 +17,7 @@ public string ReadString() this.BaseStream.ReadExactly(buffer); - return Encoding.UTF8.GetString(buffer); + return ModifiedUtf8.GetString(buffer); } public short ReadInt16() diff --git a/Obsidian.Nbt/NbtWriter.Primitives.cs b/Obsidian.Nbt/NbtWriter.Primitives.cs index 2254b44a..d64650c9 100644 --- a/Obsidian.Nbt/NbtWriter.Primitives.cs +++ b/Obsidian.Nbt/NbtWriter.Primitives.cs @@ -1,3 +1,4 @@ +using Obsidian.Nbt.Utilities; using System.Buffers.Binary; using System.Text; @@ -137,7 +138,8 @@ internal void WriteStringInternal(string value) if (value.Length > short.MaxValue) throw new InvalidOperationException($"value length must be less than {short.MaxValue}"); - var buffer = Encoding.UTF8.GetBytes(value); + if (!ModifiedUtf8.TryGetBytes(value, out var buffer)) + throw new InvalidOperationException("Failed to get bytes from string."); this.WriteShortInternal((short)buffer.Length); this.BaseStream.Write(buffer); diff --git a/Obsidian.Nbt/Utilities/ThrowHelper.cs b/Obsidian.Nbt/Utilities/ThrowHelper.cs new file mode 100644 index 00000000..a0106243 --- /dev/null +++ b/Obsidian.Nbt/Utilities/ThrowHelper.cs @@ -0,0 +1,47 @@ +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace Obsidian.Nbt.Utilities; + +internal class ThrowHelper +{ + internal static void ThrowInvalidOperationException_StringTooLong() + { + throw new InvalidOperationException("Received string is longer than allowed."); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static void ThrowOutOfRangeException_IfNegative(int value, [CallerArgumentExpression("value")] string? paramName = null) + { + if (value < 0) + { + ThrowOutOfRangeException_Negative(value, paramName!); + } + } + + internal static void ThrowOutOfRangeException_Negative(int value, string paramName) + { + throw new ArgumentOutOfRangeException($"Value of {paramName} must be positive or zero."); + } + + [DoesNotReturn] + internal static void ThrowInvalidOperationException_NotEnoughData() + { + throw new InvalidOperationException("There isn't enough buffered data for this operation."); + } + + internal static void ThrowInvalidOperationException_InvalidInstance() + { + throw new InvalidOperationException("Instance was not properly initialized."); + } + + internal static void ThrowInvalidOperationException_IncorrectTagType() + { + throw new InvalidOperationException("Tag type doesn't match requested data type."); + } + + internal static void ThrowException_InsufficientBufferSize() + { + throw new Exception("Acquired buffer did not have sufficient size."); + } +} From 87aa2f73e045db580d518fa90667827176ed3a0b Mon Sep 17 00:00:00 2001 From: Tides Date: Thu, 30 Jan 2025 07:43:04 -0600 Subject: [PATCH 2/3] Update NbtReader.Primitives.cs --- Obsidian.Nbt/NbtReader.Primitives.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/Obsidian.Nbt/NbtReader.Primitives.cs b/Obsidian.Nbt/NbtReader.Primitives.cs index d57c7fa0..a45a24ef 100644 --- a/Obsidian.Nbt/NbtReader.Primitives.cs +++ b/Obsidian.Nbt/NbtReader.Primitives.cs @@ -1,5 +1,4 @@ using System.Buffers.Binary; -using System.Text; namespace Obsidian.Nbt; public partial struct NbtReader From bca4e0ffc4635974e71238f7bf387ac01c08d750 Mon Sep 17 00:00:00 2001 From: Tides Date: Sat, 1 Feb 2025 04:18:12 -0600 Subject: [PATCH 3/3] Add avx2 supported check --- Obsidian.Nbt/ModifiedUtf8.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Obsidian.Nbt/ModifiedUtf8.cs b/Obsidian.Nbt/ModifiedUtf8.cs index 93e80978..63630b73 100644 --- a/Obsidian.Nbt/ModifiedUtf8.cs +++ b/Obsidian.Nbt/ModifiedUtf8.cs @@ -278,16 +278,16 @@ private static void GetStringCommon(ReadOnlySpan bytes, int stringLength, if (!BitConverter.IsLittleEndian || stringLength < 16) { GetStringAsciiScalar(bytes, ref destination); + return; } - else + else if (Avx2.IsSupported) { GetStringAsciiAvx2(bytes, ref destination); + return; } } - else - { - GetStringScalar(bytes, ref destination); - } + + GetStringScalar(bytes, ref destination); } private static void GetStringAsciiAvx2(ReadOnlySpan bytes, ref char destination)