diff --git a/benchmark/BDN.benchmark/Operations/OperationsBase.cs b/benchmark/BDN.benchmark/Operations/OperationsBase.cs index 7ba80eee8d..8d58631fe9 100644 --- a/benchmark/BDN.benchmark/Operations/OperationsBase.cs +++ b/benchmark/BDN.benchmark/Operations/OperationsBase.cs @@ -51,7 +51,8 @@ public virtual void GlobalSetup() { var opts = new GarnetServerOptions { - QuietMode = true + QuietMode = true, + EnableLua = true, }; if (Params.useAof) { diff --git a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs new file mode 100644 index 0000000000..f068a7c0ac --- /dev/null +++ b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using BenchmarkDotNet.Attributes; + +namespace BDN.benchmark.Operations +{ + /// + /// Benchmark for SCRIPT LOAD, SCRIPT EXISTS, EVAL, and EVALSHA + /// + [MemoryDiagnoser] + public unsafe class ScriptOperations : OperationsBase + { + static ReadOnlySpan SCRIPT_LOAD => "*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n$8\r\nreturn 1\r\n"u8; + byte[] scriptLoadRequestBuffer; + byte* scriptLoadRequestBufferPointer; + + static ReadOnlySpan SCRIPT_EXISTS_LOADED => "*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n$10\r\nreturn nil\r\n"u8; + + static ReadOnlySpan SCRIPT_EXISTS_TRUE => "*3\r\n$6\r\nSCRIPT\r\n$6\r\nEXISTS\r\n$40\r\n79cefb99366d8809d2e903c5f36f50c2b731913f\r\n"u8; + byte[] scriptExistsTrueRequestBuffer; + byte* scriptExistsTrueRequestBufferPointer; + + static ReadOnlySpan SCRIPT_EXISTS_FALSE => "*3\r\n$6\r\nSCRIPT\r\n$6\r\nEXISTS\r\n$40\r\n0000000000000000000000000000000000000000\r\n"u8; + byte[] scriptExistsFalseRequestBuffer; + byte* scriptExistsFalseRequestBufferPointer; + + static ReadOnlySpan EVAL => "*3\r\n$4\r\nEVAL\r\n$10\r\nreturn nil\r\n$1\r\n0\r\n"u8; + byte[] evalRequestBuffer; + byte* evalRequestBufferPointer; + + static ReadOnlySpan EVALSHA => 
"*3\r\n$7\r\nEVALSHA\r\n$40\r\n79cefb99366d8809d2e903c5f36f50c2b731913f\r\n$1\r\n0\r\n"u8; + byte[] evalShaRequestBuffer; + byte* evalShaRequestBufferPointer; + + public override void GlobalSetup() + { + base.GlobalSetup(); + + SetupOperation(ref scriptLoadRequestBuffer, ref scriptLoadRequestBufferPointer, SCRIPT_LOAD); + + byte[] scriptExistsLoadedBuffer = null; + byte* scriptExistsLoadedPointer = null; + SetupOperation(ref scriptExistsLoadedBuffer, ref scriptExistsLoadedPointer, SCRIPT_EXISTS_LOADED); + _ = session.TryConsumeMessages(scriptExistsLoadedPointer, scriptExistsLoadedBuffer.Length); + SetupOperation(ref scriptExistsTrueRequestBuffer, ref scriptExistsTrueRequestBufferPointer, SCRIPT_EXISTS_TRUE); + + SetupOperation(ref scriptExistsFalseRequestBuffer, ref scriptExistsFalseRequestBufferPointer, SCRIPT_EXISTS_FALSE); + + SetupOperation(ref evalRequestBuffer, ref evalRequestBufferPointer, EVAL); + + SetupOperation(ref evalShaRequestBuffer, ref evalShaRequestBufferPointer, EVALSHA); + } + + [Benchmark] + public void ScriptLoad() + { + _ = session.TryConsumeMessages(scriptLoadRequestBufferPointer, scriptLoadRequestBuffer.Length); + } + + [Benchmark] + public void ScriptExistsTrue() + { + _ = session.TryConsumeMessages(scriptExistsTrueRequestBufferPointer, scriptExistsTrueRequestBuffer.Length); + } + + [Benchmark] + public void ScriptExistsFalse() + { + _ = session.TryConsumeMessages(scriptExistsFalseRequestBufferPointer, scriptExistsFalseRequestBuffer.Length); + } + + [Benchmark] + public void Eval() + { + _ = session.TryConsumeMessages(evalRequestBufferPointer, evalRequestBuffer.Length); + } + + [Benchmark] + public void EvalSha() + { + _ = session.TryConsumeMessages(evalShaRequestBufferPointer, evalShaRequestBuffer.Length); + } + } +} \ No newline at end of file diff --git a/libs/cluster/Session/SlotVerification/ClusterSlotVerify.cs b/libs/cluster/Session/SlotVerification/ClusterSlotVerify.cs index 76a6135e32..de71cd7d08 100644 --- 
a/libs/cluster/Session/SlotVerification/ClusterSlotVerify.cs +++ b/libs/cluster/Session/SlotVerification/ClusterSlotVerify.cs @@ -153,9 +153,9 @@ ClusterSlotVerificationResult MultiKeySlotVerify(ClusterConfig config, ref Sessi ref var key = ref parseState.GetArgSliceByRef(csvi.firstKey); var slot = ArgSliceUtils.HashSlot(ref key); var verifyResult = SingleKeySlotVerify(ref config, ref key, csvi.readOnly, csvi.sessionAsking, slot); - var stride = csvi.firstKey + csvi.step; + var secondKey = csvi.firstKey + csvi.step; - for (var i = stride; i < csvi.lastKey; i += stride) + for (var i = secondKey; i < csvi.lastKey; i += csvi.step) { if (csvi.keyNumOffset == i) continue; diff --git a/libs/common/AsciiUtils.cs b/libs/common/AsciiUtils.cs index e9b20b20c9..b24f61799f 100644 --- a/libs/common/AsciiUtils.cs +++ b/libs/common/AsciiUtils.cs @@ -51,6 +51,14 @@ public static void ToUpperInPlace(Span command) Ascii.ToUpperInPlace(command, out _); } + /// + /// Convert ASCII Span to lower case + /// + public static void ToLowerInPlace(Span command) + { + Ascii.ToLowerInPlace(command, out _); + } + /// public static bool EqualsUpperCaseSpanIgnoringCase(this Span left, ReadOnlySpan right) => EqualsUpperCaseSpanIgnoringCase((ReadOnlySpan)left, right); diff --git a/libs/resources/RespCommandsInfo.json b/libs/resources/RespCommandsInfo.json index dfaf6dcd5e..4c06e3e447 100644 --- a/libs/resources/RespCommandsInfo.json +++ b/libs/resources/RespCommandsInfo.json @@ -1146,14 +1146,12 @@ { "Command": "EXPIREAT", "Name": "EXPIREAT", - "IsInternal": false, "Arity": -3, "Flags": "Fast, Write", "FirstKey": 1, "LastKey": 1, "Step": 1, "AclCategories": "Fast, KeySpace, Write", - "Tips": null, "KeySpecifications": [ { "BeginSearch": { @@ -1166,23 +1164,19 @@ "KeyStep": 1, "Limit": 0 }, - "Notes": null, "Flags": "RW, Update" } - ], - "SubCommands": null + ] }, { "Command": "EXPIRETIME", "Name": "EXPIRETIME", - "IsInternal": false, "Arity": 2, "Flags": "Fast, ReadOnly", "FirstKey": 1, "LastKey": 
1, "Step": 1, "AclCategories": "Fast, KeySpace, Read", - "Tips": null, "KeySpecifications": [ { "BeginSearch": { @@ -1195,11 +1189,9 @@ "KeyStep": 1, "Limit": 0 }, - "Notes": null, "Flags": "RO, Access" } - ], - "SubCommands": null + ] }, { "Command": "FAILOVER", @@ -1481,14 +1473,12 @@ { "Command": "GETEX", "Name": "GETEX", - "IsInternal": false, "Arity": -2, "Flags": "Fast, Write", "FirstKey": 1, "LastKey": 1, "Step": 1, "AclCategories": "Fast, String, Write", - "Tips": null, "KeySpecifications": [ { "BeginSearch": { @@ -1504,8 +1494,7 @@ "Notes": "RW and UPDATE because it changes the TTL", "Flags": "RW, Access, Update" } - ], - "SubCommands": null + ] }, { "Command": "GETRANGE", @@ -1535,14 +1524,12 @@ { "Command": "GETSET", "Name": "GETSET", - "IsInternal": false, "Arity": 3, "Flags": "DenyOom, Fast, Write", "FirstKey": 1, "LastKey": 1, "Step": 1, "AclCategories": "Fast, String, Write", - "Tips": null, "KeySpecifications": [ { "BeginSearch": { @@ -1555,11 +1542,9 @@ "KeyStep": 1, "Limit": 0 }, - "Notes": null, "Flags": "RW, Access, Update" } - ], - "SubCommands": null + ] }, { "Command": "HDEL", @@ -2036,14 +2021,12 @@ { "Command": "INCRBYFLOAT", "Name": "INCRBYFLOAT", - "IsInternal": false, "Arity": 3, "Flags": "DenyOom, Fast, Write", "FirstKey": 1, "LastKey": 1, "Step": 1, "AclCategories": "Fast, String, Write", - "Tips": null, "KeySpecifications": [ { "BeginSearch": { @@ -2056,11 +2039,9 @@ "KeyStep": 1, "Limit": 0 }, - "Notes": null, "Flags": "RW, Access, Update" } - ], - "SubCommands": null + ] }, { "Command": "INFO", @@ -2295,14 +2276,12 @@ { "Command": "LPOS", "Name": "LPOS", - "IsInternal": false, "Arity": -3, "Flags": "ReadOnly", "FirstKey": 1, "LastKey": 1, "Step": 1, "AclCategories": "List, Read, Slow", - "Tips": null, "KeySpecifications": [ { "BeginSearch": { @@ -2315,11 +2294,9 @@ "KeyStep": 1, "Limit": 0 }, - "Notes": null, "Flags": "RO, Access" } - ], - "SubCommands": null + ] }, { "Command": "LPUSH", @@ -2710,14 +2687,12 @@ { "Command": 
"PEXPIREAT", "Name": "PEXPIREAT", - "IsInternal": false, "Arity": -3, "Flags": "Fast, Write", "FirstKey": 1, "LastKey": 1, "Step": 1, "AclCategories": "Fast, KeySpace, Write", - "Tips": null, "KeySpecifications": [ { "BeginSearch": { @@ -2730,23 +2705,19 @@ "KeyStep": 1, "Limit": 0 }, - "Notes": null, "Flags": "RW, Update" } - ], - "SubCommands": null + ] }, { "Command": "PEXPIRETIME", "Name": "PEXPIRETIME", - "IsInternal": false, "Arity": 2, "Flags": "Fast, ReadOnly", "FirstKey": 1, "LastKey": 1, "Step": 1, "AclCategories": "Fast, KeySpace, Read", - "Tips": null, "KeySpecifications": [ { "BeginSearch": { @@ -2759,11 +2730,9 @@ "KeyStep": 1, "Limit": 0 }, - "Notes": null, "Flags": "RO, Access" } - ], - "SubCommands": null + ] }, { "Command": "PFADD", @@ -2971,16 +2940,12 @@ { "Command": "PURGEBP", "Name": "PURGEBP", - "IsInternal": false, "Arity": 2, "Flags": "Admin, NoMulti, NoScript, ReadOnly", "FirstKey": 1, "LastKey": 1, "Step": 1, - "AclCategories": "Admin, Garnet", - "Tips": null, - "KeySpecifications": null, - "SubCommands": null + "AclCategories": "Admin, Garnet" }, { "Command": "QUIT", @@ -3054,14 +3019,12 @@ { "Command": "RENAMENX", "Name": "RENAMENX", - "IsInternal": false, "Arity": 3, "Flags": "Fast, Write", "FirstKey": 1, "LastKey": 2, "Step": 1, "AclCategories": "Fast, KeySpace, Write", - "Tips": null, "KeySpecifications": [ { "BeginSearch": { @@ -3074,7 +3037,6 @@ "KeyStep": 1, "Limit": 0 }, - "Notes": null, "Flags": "RW, Access, Delete" }, { @@ -3088,11 +3050,9 @@ "KeyStep": 1, "Limit": 0 }, - "Notes": null, "Flags": "OW, Insert" } - ], - "SubCommands": null + ] }, { "Command": "REPLICAOF", @@ -3297,7 +3257,42 @@ "Command": "SCRIPT", "Name": "SCRIPT", "Arity": -2, - "AclCategories": "Slow" + "AclCategories": "Slow", + "SubCommands": [ + { + "Command": "SCRIPT_EXISTS", + "Name": "SCRIPT|EXISTS", + "Arity": -3, + "Flags": "NoScript", + "AclCategories": "Scripting, Slow", + "Tips": [ + "request_policy:all_shards", + "response_policy:agg_logical_and" + 
] + }, + { + "Command": "SCRIPT_FLUSH", + "Name": "SCRIPT|FLUSH", + "Arity": -2, + "Flags": "NoScript", + "AclCategories": "Scripting, Slow", + "Tips": [ + "request_policy:all_nodes", + "response_policy:all_succeeded" + ] + }, + { + "Command": "SCRIPT_LOAD", + "Name": "SCRIPT|LOAD", + "Arity": 3, + "Flags": "NoScript, Stale", + "AclCategories": "Scripting, Slow", + "Tips": [ + "request_policy:all_nodes", + "response_policy:all_succeeded" + ] + } + ] }, { "Command": "SDIFF", @@ -3458,14 +3453,12 @@ { "Command": "SETNX", "Name": "SETNX", - "IsInternal": false, "Arity": 3, "Flags": "DenyOom, Fast, Write", "FirstKey": 1, "LastKey": 1, "Step": 1, "AclCategories": "Fast, String, Write", - "Tips": null, "KeySpecifications": [ { "BeginSearch": { @@ -3478,11 +3471,9 @@ "KeyStep": 1, "Limit": 0 }, - "Notes": null, "Flags": "OW, Insert" } - ], - "SubCommands": null + ] }, { "Command": "SETRANGE", @@ -3864,14 +3855,12 @@ { "Command": "SUBSTR", "Name": "SUBSTR", - "IsInternal": false, "Arity": 4, "Flags": "ReadOnly", "FirstKey": 1, "LastKey": 1, "Step": 1, "AclCategories": "Read, Slow, String", - "Tips": null, "KeySpecifications": [ { "BeginSearch": { @@ -3884,11 +3873,9 @@ "KeyStep": 1, "Limit": 0 }, - "Notes": null, "Flags": "RO, Access" } - ], - "SubCommands": null + ] }, { "Command": "SUNION", diff --git a/libs/server/Custom/CustomCommandManager.cs b/libs/server/Custom/CustomCommandManager.cs index 580748544b..4ca58d7b9a 100644 --- a/libs/server/Custom/CustomCommandManager.cs +++ b/libs/server/Custom/CustomCommandManager.cs @@ -2,8 +2,8 @@ // Licensed under the MIT license. 
using System; -using System.Collections.Generic; -using System.Threading; +using System.Collections.Concurrent; +using System.Diagnostics; namespace Garnet.server { @@ -12,147 +12,104 @@ namespace Garnet.server /// public class CustomCommandManager { - internal static readonly ushort StartOffset = (ushort)(RespCommandExtensions.LastValidCommand + 1); - internal static readonly int MaxRegistrations = 512 - StartOffset; // Temporary fix to reduce map sizes - internal static readonly byte TypeIdStartOffset = (byte)(GarnetObjectTypeExtensions.LastObjectType + 1); - internal static readonly int MaxTypeRegistrations = (byte)(GarnetObjectTypeExtensions.FirstSpecialObjectType) - TypeIdStartOffset; - - internal readonly CustomRawStringCommand[] rawStringCommandMap; - internal readonly CustomObjectCommandWrapper[] objectCommandMap; - internal readonly CustomTransaction[] transactionProcMap; - internal readonly CustomProcedureWrapper[] customProcedureMap; - internal int RawStringCommandId = 0; - internal int ObjectTypeId = 0; - internal int TransactionProcId = 0; - internal int CustomProcedureId = 0; - - internal int CustomCommandsInfoCount => CustomCommandsInfo.Count; - internal readonly Dictionary CustomCommandsInfo = new(StringComparer.OrdinalIgnoreCase); - internal readonly Dictionary CustomCommandsDocs = new(StringComparer.OrdinalIgnoreCase); + internal static readonly int MinMapSize = 8; + internal static readonly byte TypeIdStartOffset = byte.MaxValue - (byte)GarnetObjectTypeExtensions.FirstSpecialObjectType; + + private ConcurrentExpandableMap rawStringCommandMap; + private ConcurrentExpandableMap objectCommandMap; + private ConcurrentExpandableMap transactionProcMap; + private ConcurrentExpandableMap customProcedureMap; + + internal int CustomCommandsInfoCount => customCommandsInfo.Count; + internal readonly ConcurrentDictionary customCommandsInfo = new(StringComparer.OrdinalIgnoreCase); + internal readonly ConcurrentDictionary customCommandsDocs = 
new(StringComparer.OrdinalIgnoreCase); /// /// Create new custom command manager /// public CustomCommandManager() { - rawStringCommandMap = new CustomRawStringCommand[MaxRegistrations]; - objectCommandMap = new CustomObjectCommandWrapper[MaxTypeRegistrations]; - transactionProcMap = new CustomTransaction[MaxRegistrations]; // can increase up to byte.MaxValue - customProcedureMap = new CustomProcedureWrapper[MaxRegistrations]; + rawStringCommandMap = new ConcurrentExpandableMap(MinMapSize, + (ushort)RespCommand.INVALID - 1, + (ushort)RespCommandExtensions.LastValidCommand + 1); + objectCommandMap = new ConcurrentExpandableMap(MinMapSize, + (byte)GarnetObjectTypeExtensions.FirstSpecialObjectType - 1, + (byte)GarnetObjectTypeExtensions.LastObjectType + 1); + transactionProcMap = new ConcurrentExpandableMap(MinMapSize, 0, byte.MaxValue); + customProcedureMap = new ConcurrentExpandableMap(MinMapSize, 0, byte.MaxValue); } internal int Register(string name, CommandType type, CustomRawStringFunctions customFunctions, RespCommandsInfo commandInfo, RespCommandDocs commandDocs, long expirationTicks) { - int id = Interlocked.Increment(ref RawStringCommandId) - 1; - if (id >= MaxRegistrations) + if (!rawStringCommandMap.TryGetNextId(out var cmdId)) throw new Exception("Out of registration space"); - - rawStringCommandMap[id] = new CustomRawStringCommand(name, (ushort)id, type, customFunctions, expirationTicks); - if (commandInfo != null) CustomCommandsInfo.Add(name, commandInfo); - if (commandDocs != null) CustomCommandsDocs.Add(name, commandDocs); - return id; + Debug.Assert(cmdId <= ushort.MaxValue); + var newCmd = new CustomRawStringCommand(name, (ushort)cmdId, type, customFunctions, expirationTicks); + var setSuccessful = rawStringCommandMap.TrySetValue(cmdId, ref newCmd); + Debug.Assert(setSuccessful); + if (commandInfo != null) customCommandsInfo.AddOrUpdate(name, commandInfo, (_, _) => commandInfo); + if (commandDocs != null) customCommandsDocs.AddOrUpdate(name, 
commandDocs, (_, _) => commandDocs); + return cmdId; } internal int Register(string name, Func proc, RespCommandsInfo commandInfo = null, RespCommandDocs commandDocs = null) { - int id = Interlocked.Increment(ref TransactionProcId) - 1; - if (id >= MaxRegistrations) + if (!transactionProcMap.TryGetNextId(out var cmdId)) throw new Exception("Out of registration space"); - - transactionProcMap[id] = new CustomTransaction(name, (byte)id, proc); - if (commandInfo != null) CustomCommandsInfo.Add(name, commandInfo); - if (commandDocs != null) CustomCommandsDocs.Add(name, commandDocs); - return id; + Debug.Assert(cmdId <= byte.MaxValue); + + var newCmd = new CustomTransaction(name, (byte)cmdId, proc); + var setSuccessful = transactionProcMap.TrySetValue(cmdId, ref newCmd); + Debug.Assert(setSuccessful); + if (commandInfo != null) customCommandsInfo.AddOrUpdate(name, commandInfo, (_, _) => commandInfo); + if (commandDocs != null) customCommandsDocs.AddOrUpdate(name, commandDocs, (_, _) => commandDocs); + return cmdId; } internal int RegisterType(CustomObjectFactory factory) { - for (int i = 0; i < ObjectTypeId; i++) - if (objectCommandMap[i].factory == factory) - throw new Exception($"Type already registered with ID {i}"); - - int type; - do - { - type = Interlocked.Increment(ref ObjectTypeId) - 1; - if (type >= MaxTypeRegistrations) - throw new Exception("Out of registration space"); - } while (objectCommandMap[type] != null); + if (objectCommandMap.TryGetFirstId(c => c.factory == factory, out var dupRegistrationId)) + throw new Exception($"Type already registered with ID {dupRegistrationId}"); - objectCommandMap[type] = new CustomObjectCommandWrapper((byte)type, factory); - - return type; - } - - internal void RegisterType(int objectTypeId, CustomObjectFactory factory) - { - if (objectTypeId >= MaxTypeRegistrations) - throw new Exception("Type is outside registration space"); + if (!objectCommandMap.TryGetNextId(out var cmdId)) + throw new Exception("Out of registration 
space"); + Debug.Assert(cmdId <= byte.MaxValue); - if (ObjectTypeId <= objectTypeId) ObjectTypeId = objectTypeId + 1; - for (int i = 0; i < ObjectTypeId; i++) - if (objectCommandMap[i].factory == factory) - throw new Exception($"Type already registered with ID {i}"); + var newCmd = new CustomObjectCommandWrapper((byte)cmdId, factory); + var setSuccessful = objectCommandMap.TrySetValue(cmdId, ref newCmd); + Debug.Assert(setSuccessful); - objectCommandMap[objectTypeId] = new CustomObjectCommandWrapper((byte)objectTypeId, factory); + return cmdId; } - internal (int objectTypeId, int subCommand) Register(string name, CommandType commandType, CustomObjectFactory factory, RespCommandsInfo commandInfo, RespCommandDocs commandDocs) + internal (int objectTypeId, int subCommand) Register(string name, CommandType commandType, CustomObjectFactory factory, RespCommandsInfo commandInfo, RespCommandDocs commandDocs, CustomObjectFunctions customObjectFunctions = null) { - int objectTypeId = -1; - for (int i = 0; i < ObjectTypeId; i++) - { - if (objectCommandMap[i].factory == factory) { objectTypeId = i; break; } - } - - if (objectTypeId == -1) + if (!objectCommandMap.TryGetFirstId(c => c.factory == factory, out var typeId)) { - objectTypeId = Interlocked.Increment(ref ObjectTypeId) - 1; - if (objectTypeId >= MaxTypeRegistrations) + if (!objectCommandMap.TryGetNextId(out typeId)) throw new Exception("Out of registration space"); - objectCommandMap[objectTypeId] = new CustomObjectCommandWrapper((byte)objectTypeId, factory); - } - var wrapper = objectCommandMap[objectTypeId]; + Debug.Assert(typeId <= byte.MaxValue); - int subCommand = Interlocked.Increment(ref wrapper.CommandId) - 1; - if (subCommand >= byte.MaxValue) - throw new Exception("Out of registration space"); - wrapper.commandMap[subCommand] = new CustomObjectCommand(name, (byte)objectTypeId, (byte)subCommand, commandType, wrapper.factory); - - if (commandInfo != null) CustomCommandsInfo.Add(name, commandInfo); - if 
(commandDocs != null) CustomCommandsDocs.Add(name, commandDocs); - - return (objectTypeId, subCommand); - } - - internal (int objectTypeId, int subCommand) Register(string name, CommandType commandType, CustomObjectFactory factory, CustomObjectFunctions customObjectFunctions, RespCommandsInfo commandInfo, RespCommandDocs commandDocs) - { - var objectTypeId = -1; - for (var i = 0; i < ObjectTypeId; i++) - { - if (objectCommandMap[i].factory == factory) { objectTypeId = i; break; } - } - - if (objectTypeId == -1) - { - objectTypeId = Interlocked.Increment(ref ObjectTypeId) - 1; - if (objectTypeId >= MaxTypeRegistrations) - throw new Exception("Out of registration space"); - objectCommandMap[objectTypeId] = new CustomObjectCommandWrapper((byte)objectTypeId, factory); + var newCmd = new CustomObjectCommandWrapper((byte)typeId, factory); + var setSuccessful = objectCommandMap.TrySetValue(typeId, ref newCmd); + Debug.Assert(setSuccessful); } - var wrapper = objectCommandMap[objectTypeId]; - - int subCommand = Interlocked.Increment(ref wrapper.CommandId) - 1; - if (subCommand >= byte.MaxValue) + objectCommandMap.TryGetValue(typeId, out var wrapper); + if (!wrapper.commandMap.TryGetNextId(out var scId)) throw new Exception("Out of registration space"); - wrapper.commandMap[subCommand] = new CustomObjectCommand(name, (byte)objectTypeId, (byte)subCommand, commandType, wrapper.factory, customObjectFunctions); - if (commandInfo != null) CustomCommandsInfo.Add(name, commandInfo); - if (commandDocs != null) CustomCommandsDocs.Add(name, commandDocs); + Debug.Assert(scId <= byte.MaxValue); + var newSubCmd = new CustomObjectCommand(name, (byte)typeId, (byte)scId, commandType, wrapper.factory, + customObjectFunctions); + var scSetSuccessful = wrapper.commandMap.TrySetValue(scId, ref newSubCmd); + Debug.Assert(scSetSuccessful); - return (objectTypeId, subCommand); + if (commandInfo != null) customCommandsInfo.AddOrUpdate(name, commandInfo, (_, _) => commandInfo); + if (commandDocs != 
null) customCommandsDocs.AddOrUpdate(name, commandDocs, (_, _) => commandDocs); + + return (typeId, scId); } /// @@ -166,80 +123,59 @@ internal void RegisterType(int objectTypeId, CustomObjectFactory factory) /// internal int Register(string name, Func customProcedure, RespCommandsInfo commandInfo = null, RespCommandDocs commandDocs = null) { - int id = Interlocked.Increment(ref CustomProcedureId) - 1; - if (id >= MaxRegistrations) + if (!customProcedureMap.TryGetNextId(out var cmdId)) throw new Exception("Out of registration space"); - customProcedureMap[id] = new CustomProcedureWrapper(name, (byte)id, customProcedure, this); - if (commandInfo != null) CustomCommandsInfo.Add(name, commandInfo); - if (commandDocs != null) CustomCommandsDocs.Add(name, commandDocs); - return id; + Debug.Assert(cmdId <= byte.MaxValue); + + var newCmd = new CustomProcedureWrapper(name, (byte)cmdId, customProcedure, this); + var setSuccessful = customProcedureMap.TrySetValue(cmdId, ref newCmd); + Debug.Assert(setSuccessful); + + if (commandInfo != null) customCommandsInfo.AddOrUpdate(name, commandInfo, (_, _) => commandInfo); + if (commandDocs != null) customCommandsDocs.AddOrUpdate(name, commandDocs, (_, _) => commandDocs); + return cmdId; } - internal bool Match(ReadOnlySpan command, out CustomRawStringCommand cmd) + internal bool TryGetCustomProcedure(int id, out CustomProcedureWrapper value) + => customProcedureMap.TryGetValue(id, out value); + + internal bool TryGetCustomTransactionProcedure(int id, out CustomTransaction value) + => transactionProcMap.TryGetValue(id, out value); + + internal bool TryGetCustomCommand(int id, out CustomRawStringCommand value) + => rawStringCommandMap.TryGetValue(id, out value); + + internal bool TryGetCustomObjectCommand(int id, out CustomObjectCommandWrapper value) + => objectCommandMap.TryGetValue(id, out value); + + internal bool TryGetCustomObjectSubCommand(int id, int subId, out CustomObjectCommand value) { - for (int i = 0; i < 
RawStringCommandId; i++) - { - cmd = rawStringCommandMap[i]; - if (cmd != null && command.SequenceEqual(new ReadOnlySpan(cmd.name))) - return true; - } - cmd = null; - return false; + value = default; + return objectCommandMap.TryGetValue(id, out var wrapper) && + wrapper.commandMap.TryGetValue(subId, out value); } + internal bool Match(ReadOnlySpan command, out CustomRawStringCommand cmd) + => rawStringCommandMap.MatchCommandSafe(command, out cmd); + internal bool Match(ReadOnlySpan command, out CustomTransaction cmd) - { - for (int i = 0; i < TransactionProcId; i++) - { - cmd = transactionProcMap[i]; - if (cmd != null && command.SequenceEqual(new ReadOnlySpan(cmd.name))) - return true; - } - cmd = null; - return false; - } + => transactionProcMap.MatchCommandSafe(command, out cmd); internal bool Match(ReadOnlySpan command, out CustomObjectCommand cmd) - { - for (int i = 0; i < ObjectTypeId; i++) - { - var wrapper = objectCommandMap[i]; - if (wrapper != null) - { - for (int j = 0; j < wrapper.CommandId; j++) - { - cmd = wrapper.commandMap[j]; - if (cmd != null && command.SequenceEqual(new ReadOnlySpan(cmd.name))) - return true; - } - } - else break; - } - cmd = null; - return false; - } + => objectCommandMap.MatchSubCommandSafe(command, out cmd); internal bool Match(ReadOnlySpan command, out CustomProcedureWrapper cmd) - { - for (int i = 0; i < CustomProcedureId; i++) - { - cmd = customProcedureMap[i]; - if (cmd != null && command.SequenceEqual(new ReadOnlySpan(cmd.Name))) - return true; - } - cmd = null; - return false; - } + => customProcedureMap.MatchCommandSafe(command, out cmd); internal bool TryGetCustomCommandInfo(string cmdName, out RespCommandsInfo respCommandsInfo) { - return this.CustomCommandsInfo.TryGetValue(cmdName, out respCommandsInfo); + return this.customCommandsInfo.TryGetValue(cmdName, out respCommandsInfo); } internal bool TryGetCustomCommandDocs(string cmdName, out RespCommandDocs respCommandsDocs) { - return 
this.CustomCommandsDocs.TryGetValue(cmdName, out respCommandsDocs); + return this.customCommandsDocs.TryGetValue(cmdName, out respCommandsDocs); } } } \ No newline at end of file diff --git a/libs/server/Custom/CustomCommandManagerSession.cs b/libs/server/Custom/CustomCommandManagerSession.cs index 8cf7e4ba1f..5664a561ac 100644 --- a/libs/server/Custom/CustomCommandManagerSession.cs +++ b/libs/server/Custom/CustomCommandManagerSession.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +using System.Diagnostics; using Garnet.common; namespace Garnet.server @@ -13,53 +14,81 @@ internal sealed class CustomCommandManagerSession readonly CustomCommandManager customCommandManager; // These session specific arrays are indexed by the same ID as the arrays in CustomCommandManager - readonly (CustomTransactionProcedure, int)[] sessionTransactionProcMap; - readonly CustomProcedure[] sessionCustomProcMap; - + ExpandableMap sessionTransactionProcMap; + ExpandableMap sessionCustomProcMap; public CustomCommandManagerSession(CustomCommandManager customCommandManager) { this.customCommandManager = customCommandManager; - sessionTransactionProcMap = new (CustomTransactionProcedure, int)[CustomCommandManager.MaxRegistrations]; - sessionCustomProcMap = new CustomProcedure[CustomCommandManager.MaxRegistrations]; + sessionTransactionProcMap = new ExpandableMap(CustomCommandManager.MinMapSize, 0, byte.MaxValue); + sessionCustomProcMap = new ExpandableMap(CustomCommandManager.MinMapSize, 0, byte.MaxValue); } public CustomProcedure GetCustomProcedure(int id, RespServerSession respServerSession) { - if (sessionCustomProcMap[id] == null) + if (!sessionCustomProcMap.TryGetValue(id, out var customProc)) { - var entry = customCommandManager.customProcedureMap[id] ?? 
throw new GarnetException($"Custom procedure {id} not found"); - sessionCustomProcMap[id] = entry.CustomProcedureFactory(); - sessionCustomProcMap[id].respServerSession = respServerSession; + if (!customCommandManager.TryGetCustomProcedure(id, out var entry)) + throw new GarnetException($"Custom procedure {id} not found"); + + customProc = entry.CustomProcedureFactory(); + customProc.respServerSession = respServerSession; + var setSuccessful = sessionCustomProcMap.TrySetValue(id, ref customProc); + Debug.Assert(setSuccessful); } - return sessionCustomProcMap[id]; + return customProc; } - public (CustomTransactionProcedure, int) GetCustomTransactionProcedure(int id, RespServerSession respServerSession, TransactionManager txnManager, ScratchBufferManager scratchBufferManager) + public CustomTransactionProcedure GetCustomTransactionProcedure(int id, RespServerSession respServerSession, TransactionManager txnManager, ScratchBufferManager scratchBufferManager, out int arity) { - if (sessionTransactionProcMap[id].Item1 == null) + if (sessionTransactionProcMap.Exists(id)) { - var entry = customCommandManager.transactionProcMap[id] ?? throw new GarnetException($"Transaction procedure {id} not found"); - _ = customCommandManager.CustomCommandsInfo.TryGetValue(entry.NameStr, out var cmdInfo); - return GetCustomTransactionProcedure(entry, respServerSession, txnManager, scratchBufferManager, cmdInfo?.Arity ?? 0); + ref var customTranProc = ref sessionTransactionProcMap.GetValueByRef(id); + if (customTranProc.Procedure != null) + { + arity = customTranProc.Arity; + return customTranProc.Procedure; + } } - return sessionTransactionProcMap[id]; + + if (!customCommandManager.TryGetCustomTransactionProcedure(id, out var entry)) + throw new GarnetException($"Transaction procedure {id} not found"); + _ = customCommandManager.customCommandsInfo.TryGetValue(entry.NameStr, out var cmdInfo); + arity = cmdInfo?.Arity ?? 
0; + return GetCustomTransactionProcedureAndSetArity(entry, respServerSession, txnManager, scratchBufferManager, cmdInfo?.Arity ?? 0); } - public (CustomTransactionProcedure, int) GetCustomTransactionProcedure(CustomTransaction entry, RespServerSession respServerSession, TransactionManager txnManager, ScratchBufferManager scratchBufferManager, int arity) + private CustomTransactionProcedure GetCustomTransactionProcedureAndSetArity(CustomTransaction entry, RespServerSession respServerSession, TransactionManager txnManager, ScratchBufferManager scratchBufferManager, int arity) { int id = entry.id; - if (sessionTransactionProcMap[id].Item1 == null) + + var customTranProc = new CustomTransactionProcedureWithArity(entry.proc(), arity) { - sessionTransactionProcMap[id].Item1 = entry.proc(); - sessionTransactionProcMap[id].Item2 = arity; + Procedure = + { + txnManager = txnManager, + scratchBufferManager = scratchBufferManager, + respServerSession = respServerSession + } + }; + var setSuccessful = sessionTransactionProcMap.TrySetValue(id, ref customTranProc); + Debug.Assert(setSuccessful); + + return customTranProc.Procedure; + } + + private struct CustomTransactionProcedureWithArity + { + public CustomTransactionProcedure Procedure { get; } - sessionTransactionProcMap[id].Item1.txnManager = txnManager; - sessionTransactionProcMap[id].Item1.scratchBufferManager = scratchBufferManager; - sessionTransactionProcMap[id].Item1.respServerSession = respServerSession; + public int Arity { get; } + + public CustomTransactionProcedureWithArity(CustomTransactionProcedure procedure, int arity) + { + this.Procedure = procedure; + this.Arity = arity; } - return sessionTransactionProcMap[id]; } } } \ No newline at end of file diff --git a/libs/server/Custom/CustomCommandRegistration.cs b/libs/server/Custom/CustomCommandRegistration.cs index ddf7f830a7..a03f4421df 100644 --- a/libs/server/Custom/CustomCommandRegistration.cs +++ b/libs/server/Custom/CustomCommandRegistration.cs @@ -233,9 
+233,9 @@ public override void Register(CustomCommandManager customCommandManager) RegisterArgs.Name, RegisterArgs.CommandType, factory, - RegisterArgs.ObjectCommand, RegisterArgs.CommandInfo, - RegisterArgs.CommandDocs); + RegisterArgs.CommandDocs, + RegisterArgs.ObjectCommand); } } diff --git a/libs/server/Custom/CustomObjectCommand.cs b/libs/server/Custom/CustomObjectCommand.cs index 96e7d168da..0f7ec804a5 100644 --- a/libs/server/Custom/CustomObjectCommand.cs +++ b/libs/server/Custom/CustomObjectCommand.cs @@ -3,10 +3,11 @@ namespace Garnet.server { - public class CustomObjectCommand + public class CustomObjectCommand : ICustomCommand { + public byte[] Name { get; } + public readonly string NameStr; - public readonly byte[] name; public readonly byte id; public readonly byte subid; public readonly CommandType type; @@ -16,14 +17,12 @@ public class CustomObjectCommand internal CustomObjectCommand(string name, byte id, byte subid, CommandType type, CustomObjectFactory factory, CustomObjectFunctions functions = null) { NameStr = name.ToUpperInvariant(); - this.name = System.Text.Encoding.ASCII.GetBytes(NameStr); + this.Name = System.Text.Encoding.ASCII.GetBytes(NameStr); this.id = id; this.subid = subid; this.type = type; this.factory = factory; this.functions = functions; } - - internal GarnetObjectType GetObjectType() => (GarnetObjectType)(id + CustomCommandManager.TypeIdStartOffset); } } \ No newline at end of file diff --git a/libs/server/Custom/CustomObjectCommandWrapper.cs b/libs/server/Custom/CustomObjectCommandWrapper.cs index 57b8ce4194..5c5a0d5ce8 100644 --- a/libs/server/Custom/CustomObjectCommandWrapper.cs +++ b/libs/server/Custom/CustomObjectCommandWrapper.cs @@ -8,16 +8,18 @@ namespace Garnet.server /// class CustomObjectCommandWrapper { + static readonly int MinMapSize = 8; + static readonly byte MaxSubId = 31; // RespInputHeader uses the 3 MSBs of SubId, so SubId must fit in the 5 LSBs + public readonly byte id; public readonly CustomObjectFactory 
factory; - public int CommandId = 0; - public readonly CustomObjectCommand[] commandMap; + public ConcurrentExpandableMap commandMap; public CustomObjectCommandWrapper(byte id, CustomObjectFactory functions) { this.id = id; this.factory = functions; - this.commandMap = new CustomObjectCommand[byte.MaxValue]; + this.commandMap = new ConcurrentExpandableMap(MinMapSize, 0, MaxSubId); } } } \ No newline at end of file diff --git a/libs/server/Custom/CustomProcedureWrapper.cs b/libs/server/Custom/CustomProcedureWrapper.cs index fa7b1e2349..aac96b0b93 100644 --- a/libs/server/Custom/CustomProcedureWrapper.cs +++ b/libs/server/Custom/CustomProcedureWrapper.cs @@ -22,10 +22,11 @@ public abstract bool Execute(TGarnetApi garnetApi, ref CustomProcedu where TGarnetApi : IGarnetApi; } - class CustomProcedureWrapper + class CustomProcedureWrapper : ICustomCommand { + public byte[] Name { get; } + public readonly string NameStr; - public readonly byte[] Name; public readonly byte Id; public readonly Func CustomProcedureFactory; diff --git a/libs/server/Custom/CustomRawStringCommand.cs b/libs/server/Custom/CustomRawStringCommand.cs index 0959cab9f1..1dec27cf9d 100644 --- a/libs/server/Custom/CustomRawStringCommand.cs +++ b/libs/server/Custom/CustomRawStringCommand.cs @@ -3,10 +3,11 @@ namespace Garnet.server { - public class CustomRawStringCommand + public class CustomRawStringCommand : ICustomCommand { + public byte[] Name { get; } + public readonly string NameStr; - public readonly byte[] name; public readonly ushort id; public readonly CommandType type; public readonly CustomRawStringFunctions functions; @@ -15,13 +16,11 @@ public class CustomRawStringCommand internal CustomRawStringCommand(string name, ushort id, CommandType type, CustomRawStringFunctions functions, long expirationTicks) { NameStr = name.ToUpperInvariant(); - this.name = System.Text.Encoding.ASCII.GetBytes(NameStr); + this.Name = System.Text.Encoding.ASCII.GetBytes(NameStr); this.id = id; this.type = type; 
this.functions = functions; this.expirationTicks = expirationTicks; } - - internal RespCommand GetRespCommand() => (RespCommand)(id + CustomCommandManager.StartOffset); } } \ No newline at end of file diff --git a/libs/server/Custom/CustomRespCommands.cs b/libs/server/Custom/CustomRespCommands.cs index f8e7e5f4b1..1c1362923f 100644 --- a/libs/server/Custom/CustomRespCommands.cs +++ b/libs/server/Custom/CustomRespCommands.cs @@ -52,7 +52,7 @@ private bool TryTransactionProc(byte id, CustomTransactionProcedure proc, int st public bool RunTransactionProc(byte id, ref CustomProcedureInput procInput, ref MemoryResult output) { var proc = customCommandManagerSession - .GetCustomTransactionProcedure(id, this, txnManager, scratchBufferManager).Item1; + .GetCustomTransactionProcedure(id, this, txnManager, scratchBufferManager, out _); return txnManager.RunTransactionProc(id, ref procInput, proc, ref output); } @@ -226,7 +226,7 @@ public bool InvokeCustomRawStringCommand(ref TGarnetApi storageApi, var sbKey = key.SpanByte; var inputArg = customCommand.expirationTicks > 0 ? 
DateTimeOffset.UtcNow.Ticks + customCommand.expirationTicks : customCommand.expirationTicks; customCommandParseState.InitializeWithArguments(args); - var rawStringInput = new RawStringInput(customCommand.GetRespCommand(), ref customCommandParseState, arg1: inputArg); + var rawStringInput = new RawStringInput((RespCommand)customCommand.id, ref customCommandParseState, arg1: inputArg); var _output = new SpanByteAndMemory(null); GarnetStatus status; @@ -290,7 +290,7 @@ public bool InvokeCustomObjectCommand(ref TGarnetApi storageApi, Cus var keyBytes = key.ToArray(); // Prepare input - var header = new RespInputHeader(customObjCommand.GetObjectType()) { SubId = customObjCommand.subid }; + var header = new RespInputHeader((GarnetObjectType)customObjCommand.id) { SubId = customObjCommand.subid }; customCommandParseState.InitializeWithArguments(args); var input = new ObjectInput(header, ref customCommandParseState); diff --git a/libs/server/Custom/CustomTransaction.cs b/libs/server/Custom/CustomTransaction.cs index 7e42170444..0a7a851a23 100644 --- a/libs/server/Custom/CustomTransaction.cs +++ b/libs/server/Custom/CustomTransaction.cs @@ -6,10 +6,11 @@ namespace Garnet.server { - class CustomTransaction + class CustomTransaction : ICustomCommand { + public byte[] Name { get; } + public readonly string NameStr; - public readonly byte[] name; public readonly byte id; public readonly Func proc; @@ -18,7 +19,7 @@ internal CustomTransaction(string name, byte id, Func + /// This interface describes an API for a map of items of type T whose keys are a specified range of IDs (can be descending / ascending) + /// The size of the underlying array containing the items doubles in size as needed. 
+ /// + /// + internal interface IExpandableMap + { + /// + /// Checks if ID is mapped to a value in underlying array + /// + /// Item ID + /// True if ID exists + bool Exists(int id); + + /// + /// Try to get item by ID + /// + /// Item ID + /// Item value + /// True if item found + bool TryGetValue(int id, out T value); + + /// + /// Try to get item by ref by ID + /// + /// Item ID + /// Item value + ref T GetValueByRef(int id); + + /// + /// Try to set item by ID + /// + /// Item ID + /// Item value + /// True if actual size of map should be updated (true by default) + /// True if assignment succeeded + bool TrySetValue(int id, ref T value, bool updateSize = true); + + /// + /// Get next item ID for assignment + /// + /// Item ID + /// True if item ID available + bool TryGetNextId(out int id); + + /// + /// Find first ID in map of item that fulfills specified predicate + /// + /// Predicate + /// ID if found, otherwise -1 + /// True if ID found + bool TryGetFirstId(Func predicate, out int id); + } + + /// + /// This struct defines a map of items of type T whose keys are a specified range of IDs (can be descending / ascending) + /// The size of the underlying array containing the items doubles in size as needed. + /// This struct is not thread-safe, for a thread-safe option see ConcurrentExpandableMap. + /// + /// Type of item to store + internal struct ExpandableMap : IExpandableMap + { + /// + /// The underlying array containing the items + /// + internal T[] Map { get; private set; } + + /// + /// The actual size of the map + /// i.e. 
the max index of an inserted item + 1 (not the size of the underlying array) + /// + internal int ActualSize { get; private set; } + + // The last requested index for assignment + int currIndex = -1; + // Initial array size + readonly int minSize; + // Value of min item ID + readonly int minId; + // Value of max item ID + readonly int maxSize; + // True if item IDs are in descending order + readonly bool descIds; + + /// + /// Creates a new instance of ExpandableMap + /// + /// Initial size of underlying array + /// The minimal item ID value + /// The maximal item ID value (can be smaller than minId for descending order of IDs) + public ExpandableMap(int minSize, int minId, int maxId) + { + this.Map = null; + this.minSize = minSize; + this.minId = minId; + this.maxSize = Math.Abs(maxId - minId) + 1; + this.descIds = minId > maxId; + } + + /// + public bool TryGetValue(int id, out T value) + { + value = default; + var idx = GetIndexFromId(id); + if (idx < 0 || idx >= ActualSize) + return false; + + value = Map[idx]; + return true; + } + + /// + public bool Exists(int id) + { + var idx = GetIndexFromId(id); + return idx >= 0 && idx < ActualSize; + } + + /// + public ref T GetValueByRef(int id) + { + var idx = GetIndexFromId(id); + if (idx < 0 || idx >= ActualSize) + throw new ArgumentOutOfRangeException(nameof(idx)); + + return ref Map[idx]; + } + + /// + public bool TrySetValue(int id, ref T value, bool updateSize = true) => + TrySetValue(id, ref value, false, updateSize); + + /// + public bool TryGetNextId(out int id) + { + id = -1; + var nextIdx = ++currIndex; + + if (nextIdx >= maxSize) + return false; + id = GetIdFromIndex(nextIdx); + + return true; + } + + /// + public bool TryGetFirstId(Func predicate, out int id) + { + id = -1; + for (var i = 0; i < ActualSize; i++) + { + if (predicate(Map[i])) + { + id = GetIdFromIndex(i); + return true; + } + } + + return false; + } + + /// + /// Get next item ID for assignment with atomic incrementation of underlying index 
+ /// + /// Item ID + /// True if item ID available + public bool TryGetNextIdSafe(out int id) + { + id = -1; + var nextIdx = Interlocked.Increment(ref currIndex); + + if (nextIdx >= maxSize) + return false; + id = GetIdFromIndex(nextIdx); + + return true; + } + + /// + /// Try to update the actual size of the map based on the inserted item ID + /// + /// The inserted item ID + /// True if should not do actual update + /// True if actual size should be updated (or was updated if noUpdate is false) + internal bool TryUpdateSize(int id, bool noUpdate = false) + { + var idx = GetIndexFromId(id); + + // Should not update the size if the index is out of bounds + // or if index is smaller than the current actual size + if (idx < 0 || idx < ActualSize || idx >= maxSize) return false; + + if (!noUpdate) + ActualSize = idx + 1; + + return true; + } + + /// + /// Try to set item by ID + /// + /// Item ID + /// Item value + /// True if should not attempt to expand the underlying array + /// True if should update actual size of the map + /// True if assignment succeeded + internal bool TrySetValue(int id, ref T value, bool noExpansion, bool updateSize) + { + var idx = GetIndexFromId(id); + if (idx < 0 || idx >= maxSize) return false; + + // If index within array bounds, set item + if (Map != null && idx < Map.Length) + { + Map[idx] = value; + if (updateSize) TryUpdateSize(id); + return true; + } + + if (noExpansion) return false; + + // Double new array size until item can fit + var newSize = Map != null ? 
Math.Max(Map.Length, minSize) : minSize; + while (idx >= newSize) + { + newSize = Math.Min(maxSize, newSize * 2); + } + + // Create new array, copy existing items and set new item + var newMap = new T[newSize]; + if (Map != null) + { + Array.Copy(Map, newMap, Map.Length); + } + + Map = newMap; + Map[idx] = value; + if (updateSize) TryUpdateSize(id); + return true; + } + + /// + /// Maps map index to item ID + /// + /// Map index + /// Item ID + private int GetIdFromIndex(int index) => descIds ? minId - index : index; + + /// + /// Maps an item ID to a map index + /// + /// Item ID + /// Map index + private int GetIndexFromId(int id) => descIds ? minId - id : id; + } + + /// + /// This struct defines a map of items of type T whose keys are a specified range of IDs (can be descending / ascending) + /// The size of the underlying array containing the items doubles in size as needed + /// This struct is thread-safe with regard to the underlying array pointer. + /// + /// Type of item to store + internal struct ConcurrentExpandableMap : IExpandableMap + { + /// + /// Reader-writer lock for the underlying item array + /// + internal SingleWriterMultiReaderLock eMapLock = new(); + + /// + /// The underlying non-concurrent ExpandableMap (should be accessed using the eMapLock) + /// + internal ExpandableMap eMapUnsafe; + + /// + /// Creates a new instance of ConcurrentExpandableMap + /// + /// Initial size of underlying array + /// The minimal item ID value + /// The maximal item ID value (can be smaller than minId for descending order of IDs) + public ConcurrentExpandableMap(int minSize, int minId, int maxId) + { + this.eMapUnsafe = new ExpandableMap(minSize, minId, maxId); + } + + /// + public bool TryGetValue(int id, out T value) + { + value = default; + eMapLock.ReadLock(); + try + { + return eMapUnsafe.TryGetValue(id, out value); + } + finally + { + eMapLock.ReadUnlock(); + } + } + + /// + public bool Exists(int id) + { + eMapLock.ReadLock(); + try + { + return 
eMapUnsafe.Exists(id); + } + finally + { + eMapLock.ReadUnlock(); + } + } + + /// + public ref T GetValueByRef(int id) + { + eMapLock.ReadLock(); + try + { + return ref eMapUnsafe.GetValueByRef(id); + } + finally + { + eMapLock.ReadUnlock(); + } + } + + /// + public bool TrySetValue(int id, ref T value, bool updateSize = true) + { + var shouldUpdateSize = false; + + // Try to perform set without taking a write lock first + eMapLock.ReadLock(); + try + { + // Try to set value without expanding map + if (eMapUnsafe.TrySetValue(id, ref value, true, false)) + { + // Check if map size should be updated + if (!updateSize || !eMapUnsafe.TryUpdateSize(id, true)) + return true; + shouldUpdateSize = true; + } + } + finally + { + eMapLock.ReadUnlock(); + } + + eMapLock.WriteLock(); + try + { + // Value already set, just update map size + if (shouldUpdateSize) + { + eMapUnsafe.TryUpdateSize(id); + return true; + } + + // Try to set value with expanding the map, if needed + return eMapUnsafe.TrySetValue(id, ref value, false, true); + } + finally + { + eMapLock.WriteUnlock(); + } + } + + /// + public bool TryGetNextId(out int id) + { + return eMapUnsafe.TryGetNextIdSafe(out id); + } + + /// + public bool TryGetFirstId(Func predicate, out int id) + { + id = -1; + eMapLock.ReadLock(); + try + { + return eMapUnsafe.TryGetFirstId(predicate, out id); + } + finally + { + eMapLock.ReadUnlock(); + } + } + } + + /// + /// Extension methods for ConcurrentExpandableMap + /// + internal static class ConcurrentExpandableMapExtensions + { + /// + /// Match command name with existing commands in map and return first matching instance + /// + /// Type of command + /// Current instance of ConcurrentExpandableMap + /// Command name to match + /// Value of command found + /// True if command found + internal static bool MatchCommandSafe(this ConcurrentExpandableMap eMap, ReadOnlySpan cmd, out T value) + where T : ICustomCommand + { + value = default; + eMap.eMapLock.ReadLock(); + try + { + for 
(var i = 0; i < eMap.eMapUnsafe.ActualSize; i++) + { + var currCmd = eMap.eMapUnsafe.Map[i]; + if (currCmd != null && cmd.SequenceEqual(new ReadOnlySpan(currCmd.Name))) + { + value = currCmd; + return true; + } + } + } + finally + { + eMap.eMapLock.ReadUnlock(); + } + + return false; + } + + /// + /// Match sub-command name with existing sub-commands in map and return first matching instance + /// + /// Type of command + /// Current instance of ConcurrentExpandableMap + /// Sub-command name to match + /// Value of sub-command found + /// + internal static bool MatchSubCommandSafe(this ConcurrentExpandableMap eMap, ReadOnlySpan cmd, out CustomObjectCommand value) + where T : CustomObjectCommandWrapper + { + value = default; + eMap.eMapLock.ReadLock(); + try + { + for (var i = 0; i < eMap.eMapUnsafe.ActualSize; i++) + { + if (eMap.eMapUnsafe.Map[i] != null && eMap.eMapUnsafe.Map[i].commandMap.MatchCommandSafe(cmd, out value)) + return true; + } + } + finally + { + eMap.eMapLock.ReadUnlock(); + } + + return false; + } + } +} \ No newline at end of file diff --git a/libs/server/Custom/ICustomCommand.cs b/libs/server/Custom/ICustomCommand.cs new file mode 100644 index 0000000000..7a0a21a4f5 --- /dev/null +++ b/libs/server/Custom/ICustomCommand.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +namespace Garnet.server +{ + /// + /// Interface for custom commands + /// + interface ICustomCommand + { + /// + /// Name of command + /// + byte[] Name { get; } + } +} \ No newline at end of file diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs index a0e61368a4..f306262719 100644 --- a/libs/server/Lua/LuaCommands.cs +++ b/libs/server/Lua/LuaCommands.cs @@ -2,11 +2,12 @@ // Licensed under the MIT license. 
using System; -using System.Text; +using System.Buffers; using Garnet.common; using Microsoft.Extensions.Logging; using NLua; using NLua.Exceptions; +using Tsavorite.core; namespace Garnet.server { @@ -18,10 +19,8 @@ internal sealed unsafe partial class RespServerSession : ServerSessionBase /// private unsafe bool TryEVALSHA() { - if (!storeWrapper.serverOptions.EnableLua) + if (!CheckLuaEnabled()) { - while (!RespWriteUtils.WriteError(CmdStrings.RESP_ERR_LUA_DISABLED, ref dcurr, dend)) - SendAndReset(); return true; } @@ -30,23 +29,28 @@ private unsafe bool TryEVALSHA() { return AbortWithWrongNumberOfArguments("EVALSHA"); } - var digest = parseState.GetArgSliceByRef(0).ReadOnlySpan; + + ref var digest = ref parseState.GetArgSliceByRef(0); + AsciiUtils.ToLowerInPlace(digest.Span); + + var digestAsSpanByteMem = new SpanByteAndMemory(digest.SpanByte); var result = false; - if (!sessionScriptCache.TryGetFromDigest(digest, out var runner)) + if (!sessionScriptCache.TryGetFromDigest(digestAsSpanByteMem, out var runner)) { - var d = digest.ToArray(); - if (storeWrapper.storeScriptCache.TryGetValue(d, out var source)) + if (storeWrapper.storeScriptCache.TryGetValue(digestAsSpanByteMem, out var source)) { - if (!sessionScriptCache.TryLoad(source, d, out runner, out var error)) + if (!sessionScriptCache.TryLoad(source, digestAsSpanByteMem, out runner, out var error)) { while (!RespWriteUtils.WriteError(error, ref dcurr, dend)) SendAndReset(); - _ = storeWrapper.storeScriptCache.TryRemove(d, out _); + + _ = storeWrapper.storeScriptCache.TryRemove(digestAsSpanByteMem, out _); return result; } } } + if (runner == null) { while (!RespWriteUtils.WriteError(CmdStrings.RESP_ERR_NO_SCRIPT, ref dcurr, dend)) @@ -56,6 +60,7 @@ private unsafe bool TryEVALSHA() { result = ExecuteScript(count - 1, runner); } + return result; } @@ -66,10 +71,8 @@ private unsafe bool TryEVALSHA() /// private unsafe bool TryEVAL() { - if (!storeWrapper.serverOptions.EnableLua) + if (!CheckLuaEnabled()) { - 
while (!RespWriteUtils.WriteError(CmdStrings.RESP_ERR_LUA_DISABLED, ref dcurr, dend)) - SendAndReset(); return true; } @@ -78,16 +81,22 @@ private unsafe bool TryEVAL() { return AbortWithWrongNumberOfArguments("EVAL"); } - var script = parseState.GetArgSliceByRef(0).ReadOnlySpan; - var digest = sessionScriptCache.GetScriptDigest(script); + + var script = parseState.GetArgSliceByRef(0).ToArray(); + + // that this is stack allocated is load bearing - if it moves, things will break + Span digest = stackalloc byte[SessionScriptCache.SHA1Len]; + sessionScriptCache.GetScriptDigest(script, digest); var result = false; - if (!sessionScriptCache.TryLoad(script, digest, out var runner, out var error)) + if (!sessionScriptCache.TryLoad(script, new SpanByteAndMemory(SpanByte.FromPinnedSpan(digest)), out var runner, out var error)) { while (!RespWriteUtils.WriteError(error, ref dcurr, dend)) SendAndReset(); + return result; } + if (runner == null) { while (!RespWriteUtils.WriteError(CmdStrings.RESP_ERR_NO_SCRIPT, ref dcurr, dend)) @@ -97,90 +106,137 @@ private unsafe bool TryEVAL() { result = ExecuteScript(count - 1, runner); } + return result; } /// - /// SCRIPT Commands (load, exists, flush) + /// SCRIPT|EXISTS /// - /// - private unsafe bool TrySCRIPT() + private bool NetworkScriptExists() { - if (!storeWrapper.serverOptions.EnableLua) + if (!CheckLuaEnabled()) { - while (!RespWriteUtils.WriteError(CmdStrings.RESP_ERR_LUA_DISABLED, ref dcurr, dend)) - SendAndReset(); return true; } - var count = parseState.Count; - if (count < 1) + if (parseState.Count == 0) { - return AbortWithWrongNumberOfArguments("SCRIPT"); + return AbortWithWrongNumberOfArguments("script|exists"); } - var option = parseState.GetArgSliceByRef(0).ReadOnlySpan; - if (option.EqualsUpperCaseSpanIgnoringCase("LOAD"u8)) + + // returns an array where each element is a 0 if the script does not exist, and a 1 if it does + + while (!RespWriteUtils.WriteArrayLength(parseState.Count, ref dcurr, dend)) + 
SendAndReset(); + + for (var shaIx = 0; shaIx < parseState.Count; shaIx++) { - if (count != 2) - { - return AbortWithWrongNumberOfArguments("SCRIPT"); - } - var source = parseState.GetArgSliceByRef(1).ReadOnlySpan; - if (!sessionScriptCache.TryLoad(source, out var digest, out _, out var error)) - { - while (!RespWriteUtils.WriteError(error, ref dcurr, dend)) - SendAndReset(); - return true; - } + ref var sha1 = ref parseState.GetArgSliceByRef(shaIx); + AsciiUtils.ToLowerInPlace(sha1.Span); - // Add script to the store dictionary - storeWrapper.storeScriptCache.TryAdd(digest, source.ToArray()); + var sha1Arg = new SpanByteAndMemory(sha1.SpanByte); - while (!RespWriteUtils.WriteBulkString(digest, ref dcurr, dend)) + var exists = storeWrapper.storeScriptCache.ContainsKey(sha1Arg) ? 1 : 0; + + while (!RespWriteUtils.WriteArrayItem(exists, ref dcurr, dend)) SendAndReset(); } - else if (option.EqualsUpperCaseSpanIgnoringCase("EXISTS"u8)) + + return true; + } + + /// + /// SCRIPT|FLUSH + /// + private bool NetworkScriptFlush() + { + if (!CheckLuaEnabled()) { - if (count != 2) - { - return AbortWithWrongNumberOfArguments("SCRIPT"); - } - var sha1Exists = parseState.GetArgSliceByRef(1).ToArray(); + return true; + } - // Check whether script exists at the store level - if (storeWrapper.storeScriptCache.ContainsKey(sha1Exists)) - { - while (!RespWriteUtils.WriteBulkString(CmdStrings.RESP_OK.ToArray(), ref dcurr, dend)) - SendAndReset(); - } - else - { - while (!RespWriteUtils.WriteBulkString(CmdStrings.RESP_RETURN_VAL_N1.ToArray(), ref dcurr, dend)) - SendAndReset(); - } + if (parseState.Count > 1) + { + return AbortWithErrorMessage(CmdStrings.RESP_ERR_SCRIPT_FLUSH_OPTIONS); } - else if (option.EqualsUpperCaseSpanIgnoringCase("FLUSH"u8)) + else if (parseState.Count == 1) { - if (count != 1) + // we ignore this, but should validate it + ref var arg = ref parseState.GetArgSliceByRef(0); + + AsciiUtils.ToUpperInPlace(arg.Span); + + var valid = 
arg.Span.SequenceEqual(CmdStrings.ASYNC) || arg.Span.SequenceEqual(CmdStrings.SYNC); + + if (!valid) { - return AbortWithWrongNumberOfArguments("SCRIPT"); + return AbortWithErrorMessage(CmdStrings.RESP_ERR_SCRIPT_FLUSH_OPTIONS); } - // Flush store script cache - storeWrapper.storeScriptCache.Clear(); + } + + // Flush store script cache + storeWrapper.storeScriptCache.Clear(); + + // Flush session script cache + sessionScriptCache.Clear(); - // Flush session script cache - sessionScriptCache.Clear(); + while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); - while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_OK.ToArray(), ref dcurr, dend)) + return true; + } + + /// + /// SCRIPT|LOAD + /// + private bool NetworkScriptLoad() + { + if (!CheckLuaEnabled()) + { + return true; + } + + if (parseState.Count != 1) + { + return AbortWithWrongNumberOfArguments("script|load"); + } + + var source = parseState.GetArgSliceByRef(0).ToArray(); + if (!sessionScriptCache.TryLoad(source, out var digest, out _, out var error)) + { + while (!RespWriteUtils.WriteError(error, ref dcurr, dend)) SendAndReset(); } else { - // Unknown subcommand - var errorMsg = string.Format(CmdStrings.GenericErrUnknownSubCommand, Encoding.ASCII.GetString(option), nameof(RespCommand.SCRIPT)); - while (!RespWriteUtils.WriteError(errorMsg, ref dcurr, dend)) + + // Add script to the store dictionary + var scriptKey = new SpanByteAndMemory(new ScriptHashOwner(digest.AsMemory()), digest.Length); + _ = storeWrapper.storeScriptCache.TryAdd(scriptKey, source); + + while (!RespWriteUtils.WriteBulkString(digest, ref dcurr, dend)) + SendAndReset(); + } + + return true; + } + + /// + /// Returns true if Lua is enabled. + /// + /// Otherwise writes out an error and returns false. 
+ /// + private bool CheckLuaEnabled() + { + if (!storeWrapper.serverOptions.EnableLua) + { + while (!RespWriteUtils.WriteError(CmdStrings.RESP_ERR_LUA_DISABLED, ref dcurr, dend)) SendAndReset(); + + return false; } + return true; } diff --git a/libs/server/Lua/SessionScriptCache.cs b/libs/server/Lua/SessionScriptCache.cs index 8d5ab8ff33..eddc0d9ffc 100644 --- a/libs/server/Lua/SessionScriptCache.cs +++ b/libs/server/Lua/SessionScriptCache.cs @@ -3,34 +3,38 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Security.Cryptography; +using Garnet.common; using Garnet.server.ACL; using Garnet.server.Auth; using Microsoft.Extensions.Logging; +using Tsavorite.core; namespace Garnet.server { /// /// Cache of Lua scripts, per session /// - internal sealed unsafe class SessionScriptCache : IDisposable + internal sealed class SessionScriptCache : IDisposable { // Important to keep the hash length to this value // for compatibility - const int SHA1Len = 40; + internal const int SHA1Len = 40; readonly RespServerSession processor; readonly ScratchBufferNetworkSender scratchBufferNetworkSender; readonly StoreWrapper storeWrapper; readonly ILogger logger; - readonly Dictionary scriptCache = new(new ByteArrayComparer()); + readonly Dictionary scriptCache = new(SpanByteAndMemoryComparer.Instance); readonly byte[] hash = new byte[SHA1Len / 2]; public SessionScriptCache(StoreWrapper storeWrapper, IGarnetAuthenticator authenticator, ILogger logger = null) { - this.scratchBufferNetworkSender = new ScratchBufferNetworkSender(); this.storeWrapper = storeWrapper; - this.processor = new RespServerSession(0, scratchBufferNetworkSender, storeWrapper, null, authenticator, false); this.logger = logger; + + scratchBufferNetworkSender = new ScratchBufferNetworkSender(); + processor = new RespServerSession(0, scratchBufferNetworkSender, storeWrapper, null, authenticator, false); } public void Dispose() @@ -48,22 +52,23 @@ public void SetUser(User user) /// 
/// Try get script runner for given digest /// - public bool TryGetFromDigest(ReadOnlySpan digest, out LuaRunner scriptRunner) - => scriptCache.TryGetValue(digest.ToArray(), out scriptRunner); + public bool TryGetFromDigest(SpanByteAndMemory digest, out LuaRunner scriptRunner) + => scriptCache.TryGetValue(digest, out scriptRunner); /// /// Load script into the cache /// - public bool TryLoad(ReadOnlySpan source, out byte[] digest, out LuaRunner runner, out string error) + public bool TryLoad(byte[] source, out byte[] digest, out LuaRunner runner, out string error) { - digest = GetScriptDigest(source); - return TryLoad(source, digest, out runner, out error); + digest = new byte[SHA1Len]; + GetScriptDigest(source, digest); + + return TryLoad(source, new SpanByteAndMemory(new ScriptHashOwner(digest), digest.Length), out runner, out error); } - internal bool TryLoad(ReadOnlySpan source, byte[] digest, out LuaRunner runner, out string error) + internal bool TryLoad(byte[] source, SpanByteAndMemory digest, out LuaRunner runner, out string error) { error = null; - runner = null; if (scriptCache.TryGetValue(digest, out runner)) return true; @@ -72,13 +77,25 @@ internal bool TryLoad(ReadOnlySpan source, byte[] digest, out LuaRunner ru { runner = new LuaRunner(source, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, logger); runner.Compile(); - scriptCache.TryAdd(digest, runner); + + // need to make sure the key is on the heap, so move it over if needed + var storeKeyDigest = digest; + if (storeKeyDigest.IsSpanByte) + { + var into = new byte[storeKeyDigest.Length]; + storeKeyDigest.AsReadOnlySpan().CopyTo(into); + + storeKeyDigest = new SpanByteAndMemory(new ScriptHashOwner(into), into.Length); + } + + _ = scriptCache.TryAdd(storeKeyDigest, runner); } catch (Exception ex) { error = ex.Message; return false; } + return true; } @@ -91,21 +108,23 @@ public void Clear() { runner.Dispose(); } + scriptCache.Clear(); } static ReadOnlySpan 
HEX_CHARS => "0123456789abcdef"u8; - public byte[] GetScriptDigest(ReadOnlySpan source) + public void GetScriptDigest(ReadOnlySpan source, Span into) { - var digest = new byte[SHA1Len]; - SHA1.HashData(source, new Span(hash)); - for (int i = 0; i < 20; i++) + Debug.Assert(into.Length >= SHA1Len, "into must be large enough for the hash"); + + _ = SHA1.HashData(source, new Span(hash)); + + for (var i = 0; i < hash.Length; i++) { - digest[i * 2] = HEX_CHARS[hash[i] >> 4]; - digest[i * 2 + 1] = HEX_CHARS[hash[i] & 0x0F]; + into[i * 2] = HEX_CHARS[hash[i] >> 4]; + into[i * 2 + 1] = HEX_CHARS[hash[i] & 0x0F]; } - return digest; } } } \ No newline at end of file diff --git a/libs/server/Module/ModuleRegistrar.cs b/libs/server/Module/ModuleRegistrar.cs index 891b2df33d..623384b96c 100644 --- a/libs/server/Module/ModuleRegistrar.cs +++ b/libs/server/Module/ModuleRegistrar.cs @@ -152,7 +152,7 @@ public ModuleActionStatus RegisterCommand(string name, CustomObjectFactory facto if (string.IsNullOrEmpty(name) || factory == null || command == null) return ModuleActionStatus.InvalidRegistrationInfo; - customCommandManager.Register(name, type, factory, command, commandInfo, commandDocs); + customCommandManager.Register(name, type, factory, commandInfo, commandDocs, command); return ModuleActionStatus.Success; } diff --git a/libs/server/Objects/ItemBroker/CollectionItemBroker.cs b/libs/server/Objects/ItemBroker/CollectionItemBroker.cs index ac3aba7a9d..82a2158b07 100644 --- a/libs/server/Objects/ItemBroker/CollectionItemBroker.cs +++ b/libs/server/Objects/ItemBroker/CollectionItemBroker.cs @@ -34,7 +34,7 @@ public class CollectionItemBroker : IDisposable private readonly Lazy> brokerEventsQueueLazy = new(); private readonly Lazy> sessionIdToObserverLazy = new(); private readonly Lazy>> keysToObserversLazy = - new(() => new Dictionary>(new ByteArrayComparer())); + new(() => new Dictionary>(ByteArrayComparer.Instance)); // Cancellation token for the main loop private readonly 
CancellationTokenSource cts = new(); diff --git a/libs/server/Objects/Types/GarnetObjectSerializer.cs b/libs/server/Objects/Types/GarnetObjectSerializer.cs index 2563371269..1ebdc7b20d 100644 --- a/libs/server/Objects/Types/GarnetObjectSerializer.cs +++ b/libs/server/Objects/Types/GarnetObjectSerializer.cs @@ -13,14 +13,14 @@ namespace Garnet.server /// public sealed class GarnetObjectSerializer : BinaryObjectSerializer { - readonly CustomObjectCommandWrapper[] customCommands; + readonly CustomCommandManager customCommandManager; /// /// Constructor /// public GarnetObjectSerializer(CustomCommandManager customCommandManager) { - this.customCommands = customCommandManager.objectCommandMap; + this.customCommandManager = customCommandManager; } /// @@ -58,8 +58,9 @@ private IGarnetObject DeserializeInternal(BinaryReader binaryReader) private IGarnetObject CustomDeserialize(byte type, BinaryReader binaryReader) { - if (type < CustomCommandManager.TypeIdStartOffset) return null; - return customCommands[type - CustomCommandManager.TypeIdStartOffset].factory.Deserialize(type, binaryReader); + if (type < CustomCommandManager.TypeIdStartOffset || + !customCommandManager.TryGetCustomObjectCommand(type, out var cmd)) return null; + return cmd.factory.Deserialize(type, binaryReader); } /// diff --git a/libs/server/Objects/Types/GarnetObjectType.cs b/libs/server/Objects/Types/GarnetObjectType.cs index 69ad2e793b..ddbc40f8f4 100644 --- a/libs/server/Objects/Types/GarnetObjectType.cs +++ b/libs/server/Objects/Types/GarnetObjectType.cs @@ -33,6 +33,11 @@ public enum GarnetObjectType : byte // Any new special type inserted here should update GarnetObjectTypeExtensions.FirstSpecialObjectType + /// + /// Special type indicating PEXPIRE command + /// + PExpire = 0xf8, + /// /// Special type indicating EXPIRETIME command /// @@ -44,40 +49,35 @@ public enum GarnetObjectType : byte PExpireTime = 0xfa, /// - /// Special type indicating PERSIST command - /// - Persist = 0xfd, - - /// - /// 
Special type indicating TTL command + /// Indicating a Custom Object command /// - Ttl = 0xfe, + All = 0xfb, /// - /// Special type indicating EXPIRE command + /// Special type indicating PTTL command /// - Expire = 0xff, + PTtl = 0xfc, /// - /// Special type indicating PEXPIRE command + /// Special type indicating PERSIST command /// - PExpire = 0xf8, + Persist = 0xfd, /// - /// Special type indicating PTTL command + /// Special type indicating TTL command /// - PTtl = 0xfc, + Ttl = 0xfe, /// - /// Indicating a Custom Object command + /// Special type indicating EXPIRE command /// - All = 0xfb + Expire = 0xff, } public static class GarnetObjectTypeExtensions { internal const GarnetObjectType LastObjectType = GarnetObjectType.Set; - internal const GarnetObjectType FirstSpecialObjectType = GarnetObjectType.ExpireTime; + internal const GarnetObjectType FirstSpecialObjectType = GarnetObjectType.PExpire; } } \ No newline at end of file diff --git a/libs/server/Resp/BasicCommands.cs b/libs/server/Resp/BasicCommands.cs index cbb2008504..860dd691b2 100644 --- a/libs/server/Resp/BasicCommands.cs +++ b/libs/server/Resp/BasicCommands.cs @@ -988,7 +988,7 @@ private void WriteCOMMANDResponse() var resultSb = new StringBuilder(); var cmdCount = 0; - foreach (var customCmd in storeWrapper.customCommandManager.CustomCommandsInfo.Values) + foreach (var customCmd in storeWrapper.customCommandManager.customCommandsInfo.Values) { cmdCount++; resultSb.Append(customCmd.RespFormat); @@ -1082,7 +1082,7 @@ private bool NetworkCOMMAND_DOCS() resultSb.Append(cmdDocs.RespFormat); } - foreach (var customCmd in storeWrapper.customCommandManager.CustomCommandsDocs.Values) + foreach (var customCmd in storeWrapper.customCommandManager.customCommandsDocs.Values) { docsCount++; resultSb.Append(customCmd.RespFormat); diff --git a/libs/server/Resp/ByteArrayComparer.cs b/libs/server/Resp/ByteArrayComparer.cs index f1f3fd891c..b726652759 100644 --- a/libs/server/Resp/ByteArrayComparer.cs +++ 
b/libs/server/Resp/ByteArrayComparer.cs @@ -22,6 +22,8 @@ public sealed class ByteArrayComparer : IEqualityComparer public bool Equals(byte[] left, byte[] right) => new ReadOnlySpan(left).SequenceEqual(new ReadOnlySpan(right)); + private ByteArrayComparer() { } + /// public unsafe int GetHashCode(byte[] key) { diff --git a/libs/server/Resp/CmdStrings.cs b/libs/server/Resp/CmdStrings.cs index e0e40afddf..aad5e5e43f 100644 --- a/libs/server/Resp/CmdStrings.cs +++ b/libs/server/Resp/CmdStrings.cs @@ -211,6 +211,7 @@ static partial class CmdStrings public static ReadOnlySpan RESP_ERR_GT_LT_NX_NOT_COMPATIBLE => "ERR GT, LT, and/or NX options at the same time are not compatible"u8; public static ReadOnlySpan RESP_ERR_INCR_SUPPORTS_ONLY_SINGLE_PAIR => "ERR INCR option supports a single increment-element pair"u8; public static ReadOnlySpan RESP_ERR_INVALID_BITFIELD_TYPE => "ERR Invalid bitfield type. Use something like i16 u8. Note that u64 is not supported but i64 is"u8; + public static ReadOnlySpan RESP_ERR_SCRIPT_FLUSH_OPTIONS => "ERR SCRIPT FLUSH only support SYNC|ASYNC option"u8; /// /// Response string templates diff --git a/libs/server/Resp/Parser/RespCommand.cs b/libs/server/Resp/Parser/RespCommand.cs index 1473a014b1..5cf5077592 100644 --- a/libs/server/Resp/Parser/RespCommand.cs +++ b/libs/server/Resp/Parser/RespCommand.cs @@ -243,6 +243,9 @@ public enum RespCommand : ushort // Script commands SCRIPT, + SCRIPT_EXISTS, + SCRIPT_FLUSH, + SCRIPT_LOAD, ACL, ACL_CAT, @@ -1161,7 +1164,39 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan } else if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("SCRIPT\r\n"u8)) { - return RespCommand.SCRIPT; + // SCRIPT EXISTS => "$6\r\nEXISTS\r\n".Length == 12 + // SCRIPT FLUSH => "$5\r\nFLUSH\r\n".Length == 11 + // SCRIPT LOAD => "$4\r\nLOAD\r\n".Length == 10 + + if (remainingBytes >= length + 10) + { + if (*(ulong*)(ptr + 4 + 8) == MemoryMarshal.Read("$4\r\nLOAD"u8) && *(ulong*)(ptr + 4 + 8 + 2) == 
MemoryMarshal.Read("\r\nLOAD\r\n"u8)) + { + count--; + readHead += 10; + return RespCommand.SCRIPT_LOAD; + } + + if (remainingBytes >= length + 11) + { + if (*(ulong*)(ptr + 4 + 8) == MemoryMarshal.Read("$5\r\nFLUS"u8) && *(ulong*)(ptr + 4 + 8 + 3) == MemoryMarshal.Read("\nFLUSH\r\n"u8)) + { + count--; + readHead += 11; + return RespCommand.SCRIPT_FLUSH; + } + + if (remainingBytes >= length + 12) + { + if (*(ulong*)(ptr + 4 + 8) == MemoryMarshal.Read("$6\r\nEXIS"u8) && *(ulong*)(ptr + 4 + 8 + 4) == MemoryMarshal.Read("EXISTS\r\n"u8)) + { + count--; + readHead += 12; + return RespCommand.SCRIPT_EXISTS; + } + } + } + } } break; diff --git a/libs/server/Resp/RespCommandsInfo.cs b/libs/server/Resp/RespCommandsInfo.cs index a2ddfbf147..0da1797398 100644 --- a/libs/server/Resp/RespCommandsInfo.cs +++ b/libs/server/Resp/RespCommandsInfo.cs @@ -200,11 +200,11 @@ private static bool TryInitializeRespCommandsInfo(ILogger logger = null) ) ); - FastBasicRespCommandsInfo = new RespCommandsInfo[(int)RespCommandExtensions.LastDataCommand - (int)RespCommandExtensions.FirstReadCommand]; - for (var i = (int)RespCommandExtensions.FirstReadCommand; i < (int)RespCommandExtensions.LastDataCommand; i++) + FastBasicRespCommandsInfo = new RespCommandsInfo[(int)RespCommandExtensions.LastDataCommand - (int)RespCommandExtensions.FirstReadCommand + 1]; + for (var i = (int)RespCommandExtensions.FirstReadCommand; i <= (int)RespCommandExtensions.LastDataCommand; i++) { FlattenedRespCommandsInfo.TryGetValue((RespCommand)i, out var commandInfo); - FastBasicRespCommandsInfo[i - 1] = commandInfo; + FastBasicRespCommandsInfo[i - (int)RespCommandExtensions.FirstReadCommand] = commandInfo; } return true; @@ -344,7 +344,7 @@ public static bool TryFastGetRespCommandInfo(RespCommand cmd, out RespCommandsIn respCommandsInfo = null; if (!IsInitialized && !TryInitialize(logger)) return false; - var offset = (int)cmd - 1; + var offset = (int)cmd - (int)RespCommandExtensions.FirstReadCommand; if (offset < 0 || 
offset >= FastBasicRespCommandsInfo.Length) return false; diff --git a/libs/server/Resp/RespServerSession.cs b/libs/server/Resp/RespServerSession.cs index 9ccd49a9f9..a5b00f020c 100644 --- a/libs/server/Resp/RespServerSession.cs +++ b/libs/server/Resp/RespServerSession.cs @@ -746,7 +746,10 @@ private bool ProcessOtherCommands(RespCommand command, ref TGarnetAp RespCommand.SCAN => NetworkSCAN(ref storageApi), RespCommand.TYPE => NetworkTYPE(ref storageApi), // Script Commands - RespCommand.SCRIPT => TrySCRIPT(), + RespCommand.SCRIPT_EXISTS => NetworkScriptExists(), + RespCommand.SCRIPT_FLUSH => NetworkScriptFlush(), + RespCommand.SCRIPT_LOAD => NetworkScriptLoad(), + RespCommand.EVAL => TryEVAL(), RespCommand.EVALSHA => TryEVALSHA(), _ => Process(command) @@ -776,8 +779,8 @@ bool NetworkCustomTxn() // Perform the operation TryTransactionProc(currentCustomTransaction.id, customCommandManagerSession - .GetCustomTransactionProcedure(currentCustomTransaction.id, this, txnManager, scratchBufferManager) - .Item1); + .GetCustomTransactionProcedure(currentCustomTransaction.id, this, txnManager, + scratchBufferManager, out _)); currentCustomTransaction = null; return true; } @@ -816,7 +819,7 @@ private bool NetworkCustomRawStringCmd(ref TGarnetApi storageApi) } // Perform the operation - TryCustomRawStringCommand(currentCustomRawStringCommand.GetRespCommand(), + TryCustomRawStringCommand((RespCommand)currentCustomRawStringCommand.id, currentCustomRawStringCommand.expirationTicks, currentCustomRawStringCommand.type, ref storageApi); currentCustomRawStringCommand = null; return true; @@ -832,7 +835,7 @@ bool NetworkCustomObjCmd(ref TGarnetApi storageApi) } // Perform the operation - TryCustomObjectCommand(currentCustomObjectCommand.GetObjectType(), currentCustomObjectCommand.subid, + TryCustomObjectCommand((GarnetObjectType)currentCustomObjectCommand.id, currentCustomObjectCommand.subid, currentCustomObjectCommand.type, ref storageApi); currentCustomObjectCommand = null; 
return true; @@ -840,7 +843,7 @@ bool NetworkCustomObjCmd(ref TGarnetApi storageApi) private bool IsCommandArityValid(string cmdName, int count) { - if (storeWrapper.customCommandManager.CustomCommandsInfo.TryGetValue(cmdName, out var cmdInfo)) + if (storeWrapper.customCommandManager.customCommandsInfo.TryGetValue(cmdName, out var cmdInfo)) { Debug.Assert(cmdInfo != null, "Custom command info should not be null"); if ((cmdInfo.Arity > 0 && count != cmdInfo.Arity - 1) || diff --git a/libs/server/Resp/ScriptHashOwner.cs b/libs/server/Resp/ScriptHashOwner.cs new file mode 100644 index 0000000000..c980ed7cca --- /dev/null +++ b/libs/server/Resp/ScriptHashOwner.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Buffers; + +namespace Garnet.server +{ + /// + /// Owner for memory used to store Lua script hashes on the heap. + /// + internal sealed class ScriptHashOwner : IMemoryOwner + { + private readonly Memory mem; + + /// + public Memory Memory => mem; + + internal ScriptHashOwner(Memory hashMem) + { + mem = hashMem; + } + + /// + public void Dispose() + { + } + } +} \ No newline at end of file diff --git a/libs/server/Resp/SpanByteAndMemoryComparer.cs b/libs/server/Resp/SpanByteAndMemoryComparer.cs new file mode 100644 index 0000000000..e3ecc4dded --- /dev/null +++ b/libs/server/Resp/SpanByteAndMemoryComparer.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Collections.Generic; +using Tsavorite.core; + +namespace Garnet.server +{ + /// + /// equality comparer. + /// + public sealed class SpanByteAndMemoryComparer : IEqualityComparer + { + /// + /// The default instance. + /// + /// Used to avoid allocating new comparers. 
+ public static readonly SpanByteAndMemoryComparer Instance = new(); + + private SpanByteAndMemoryComparer() { } + + /// + public bool Equals(SpanByteAndMemory left, SpanByteAndMemory right) + => left.AsReadOnlySpan().SequenceEqual(right.AsReadOnlySpan()); + + /// + public unsafe int GetHashCode(SpanByteAndMemory key) + { + var hash = new HashCode(); + hash.AddBytes(key.AsReadOnlySpan()); + + var ret = hash.ToHashCode(); + + return ret; + } + } +} \ No newline at end of file diff --git a/libs/server/Servers/RegisterApi.cs b/libs/server/Servers/RegisterApi.cs index ed2995280b..fd90cf9073 100644 --- a/libs/server/Servers/RegisterApi.cs +++ b/libs/server/Servers/RegisterApi.cs @@ -57,14 +57,6 @@ public int NewTransactionProc(string name, Func proc public int NewType(CustomObjectFactory factory) => provider.StoreWrapper.customCommandManager.RegisterType(factory); - /// - /// Register object type with server, with specific type ID [0-55] - /// - /// Type ID for factory - /// Factory for object type - public void NewType(int type, CustomObjectFactory factory) - => provider.StoreWrapper.customCommandManager.RegisterType(type, factory); - /// /// Register custom command with Garnet /// @@ -76,7 +68,7 @@ public void NewType(int type, CustomObjectFactory factory) /// RESP command docs /// ID of the registered command public (int objectTypeId, int subCommandId) NewCommand(string name, CommandType commandType, CustomObjectFactory factory, CustomObjectFunctions customObjectFunctions, RespCommandsInfo commandInfo = null, RespCommandDocs commandDocs = null) - => provider.StoreWrapper.customCommandManager.Register(name, commandType, factory, customObjectFunctions, commandInfo, commandDocs); + => provider.StoreWrapper.customCommandManager.Register(name, commandType, factory, commandInfo, commandDocs, customObjectFunctions); /// /// Register custom procedure with Garnet diff --git a/libs/server/Storage/Functions/FunctionsState.cs b/libs/server/Storage/Functions/FunctionsState.cs 
index bb2aa8e16e..055ad9f675 100644 --- a/libs/server/Storage/Functions/FunctionsState.cs +++ b/libs/server/Storage/Functions/FunctionsState.cs @@ -11,25 +11,33 @@ namespace Garnet.server /// internal sealed class FunctionsState { + private readonly CustomCommandManager customCommandManager; + public readonly TsavoriteLog appendOnlyFile; - public readonly CustomRawStringCommand[] customCommands; - public readonly CustomObjectCommandWrapper[] customObjectCommands; public readonly WatchVersionMap watchVersionMap; public readonly MemoryPool memoryPool; public readonly CacheSizeTracker objectStoreSizeTracker; public readonly GarnetObjectSerializer garnetObjectSerializer; public bool StoredProcMode; - public FunctionsState(TsavoriteLog appendOnlyFile, WatchVersionMap watchVersionMap, CustomRawStringCommand[] customCommands, CustomObjectCommandWrapper[] customObjectCommands, + public FunctionsState(TsavoriteLog appendOnlyFile, WatchVersionMap watchVersionMap, CustomCommandManager customCommandManager, MemoryPool memoryPool, CacheSizeTracker objectStoreSizeTracker, GarnetObjectSerializer garnetObjectSerializer) { this.appendOnlyFile = appendOnlyFile; this.watchVersionMap = watchVersionMap; - this.customCommands = customCommands; - this.customObjectCommands = customObjectCommands; + this.customCommandManager = customCommandManager; this.memoryPool = memoryPool ?? MemoryPool.Shared; this.objectStoreSizeTracker = objectStoreSizeTracker; this.garnetObjectSerializer = garnetObjectSerializer; } + + public CustomRawStringFunctions GetCustomCommandFunctions(int id) + => customCommandManager.TryGetCustomCommand(id, out var cmd) ? cmd.functions : null; + + public CustomObjectFactory GetCustomObjectFactory(int id) + => customCommandManager.TryGetCustomObjectCommand(id, out var cmd) ? cmd.factory : null; + + public CustomObjectFunctions GetCustomObjectSubCommandFunctions(int id, int subId) + => customCommandManager.TryGetCustomObjectSubCommand(id, subId, out var cmd) ? 
cmd.functions : null; } } \ No newline at end of file diff --git a/libs/server/Storage/Functions/MainStore/RMWMethods.cs b/libs/server/Storage/Functions/MainStore/RMWMethods.cs index ec5b8b1462..9d875ac357 100644 --- a/libs/server/Storage/Functions/MainStore/RMWMethods.cs +++ b/libs/server/Storage/Functions/MainStore/RMWMethods.cs @@ -30,11 +30,10 @@ public bool NeedInitialUpdate(ref SpanByte key, ref RawStringInput input, ref Sp case RespCommand.GETEX: return false; default: - if ((ushort)input.header.cmd >= CustomCommandManager.StartOffset) + if (input.header.cmd > RespCommandExtensions.LastValidCommand) { (IMemoryOwner Memory, int Length) outp = (output.Memory, 0); - var ret = functionsState - .customCommands[(ushort)input.header.cmd - CustomCommandManager.StartOffset].functions + var ret = functionsState.GetCustomCommandFunctions((ushort)input.header.cmd) .NeedInitialUpdate(key.AsReadOnlySpan(), ref input, ref outp); output.Memory = outp.Memory; output.Length = outp.Length; @@ -178,9 +177,9 @@ public bool InitialUpdater(ref SpanByte key, ref RawStringInput input, ref SpanB default: value.UnmarkExtraMetadata(); - if ((ushort)input.header.cmd >= CustomCommandManager.StartOffset) + if (input.header.cmd > RespCommandExtensions.LastValidCommand) { - var functions = functionsState.customCommands[(ushort)input.header.cmd - CustomCommandManager.StartOffset].functions; + var functions = functionsState.GetCustomCommandFunctions((ushort)input.header.cmd); // compute metadata size for result var expiration = input.arg1; metadataSize = expiration switch @@ -505,10 +504,10 @@ private bool InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput input, re return false; default: - var cmd = (ushort)input.header.cmd; - if (cmd >= CustomCommandManager.StartOffset) + var cmd = input.header.cmd; + if (cmd > RespCommandExtensions.LastValidCommand) { - var functions = functionsState.customCommands[cmd - CustomCommandManager.StartOffset].functions; + var functions = 
functionsState.GetCustomCommandFunctions((ushort)cmd); var expiration = input.arg1; if (expiration == -1) { @@ -583,10 +582,10 @@ public bool NeedCopyUpdate(ref SpanByte key, ref RawStringInput input, ref SpanB } return true; default: - if ((ushort)input.header.cmd >= CustomCommandManager.StartOffset) + if (input.header.cmd > RespCommandExtensions.LastValidCommand) { (IMemoryOwner Memory, int Length) outp = (output.Memory, 0); - var ret = functionsState.customCommands[(ushort)input.header.cmd - CustomCommandManager.StartOffset].functions + var ret = functionsState.GetCustomCommandFunctions((ushort)input.header.cmd) .NeedCopyUpdate(key.AsReadOnlySpan(), ref input, oldValue.AsReadOnlySpan(), ref outp); output.Memory = outp.Memory; output.Length = outp.Length; @@ -818,9 +817,9 @@ public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte break; default: - if ((ushort)input.header.cmd >= CustomCommandManager.StartOffset) + if (input.header.cmd > RespCommandExtensions.LastValidCommand) { - var functions = functionsState.customCommands[(ushort)input.header.cmd - CustomCommandManager.StartOffset].functions; + var functions = functionsState.GetCustomCommandFunctions((ushort)input.header.cmd); var expiration = input.arg1; if (expiration == 0) { diff --git a/libs/server/Storage/Functions/MainStore/ReadMethods.cs b/libs/server/Storage/Functions/MainStore/ReadMethods.cs index cd0a0be785..5447708a09 100644 --- a/libs/server/Storage/Functions/MainStore/ReadMethods.cs +++ b/libs/server/Storage/Functions/MainStore/ReadMethods.cs @@ -19,11 +19,11 @@ public bool SingleReader(ref SpanByte key, ref RawStringInput input, ref SpanByt return false; var cmd = input.header.cmd; - if ((ushort)cmd >= CustomCommandManager.StartOffset) + if (cmd > RespCommandExtensions.LastValidCommand) { var valueLength = value.LengthWithoutMetadata; (IMemoryOwner Memory, int Length) output = (dst.Memory, 0); - var ret = functionsState.customCommands[(ushort)cmd - 
CustomCommandManager.StartOffset].functions + var ret = functionsState.GetCustomCommandFunctions((ushort)cmd) .Reader(key.AsReadOnlySpan(), ref input, value.AsReadOnlySpan(), ref output, ref readInfo); Debug.Assert(valueLength <= value.LengthWithoutMetadata); dst.Memory = output.Memory; @@ -50,11 +50,11 @@ public bool ConcurrentReader(ref SpanByte key, ref RawStringInput input, ref Spa } var cmd = input.header.cmd; - if ((ushort)cmd >= CustomCommandManager.StartOffset) + if (cmd > RespCommandExtensions.LastValidCommand) { var valueLength = value.LengthWithoutMetadata; (IMemoryOwner Memory, int Length) output = (dst.Memory, 0); - var ret = functionsState.customCommands[(ushort)cmd - CustomCommandManager.StartOffset].functions + var ret = functionsState.GetCustomCommandFunctions((ushort)cmd) .Reader(key.AsReadOnlySpan(), ref input, value.AsReadOnlySpan(), ref output, ref readInfo); Debug.Assert(valueLength <= value.LengthWithoutMetadata); dst.Memory = output.Memory; diff --git a/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs b/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs index b0b3803465..442cf7a769 100644 --- a/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs +++ b/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs @@ -119,9 +119,9 @@ public int GetRMWInitialValueLength(ref RawStringInput input) return sizeof(int) + ndigits; default: - if ((ushort)cmd >= CustomCommandManager.StartOffset) + if (cmd > RespCommandExtensions.LastValidCommand) { - var functions = functionsState.customCommands[(ushort)cmd - CustomCommandManager.StartOffset].functions; + var functions = functionsState.GetCustomCommandFunctions((ushort)cmd); // Compute metadata size for result int metadataSize = input.arg1 switch { @@ -236,9 +236,9 @@ public int GetRMWModifiedValueLength(ref SpanByte t, ref RawStringInput input) return sizeof(int) + t.Length + valueLength; default: - if ((ushort)cmd >= CustomCommandManager.StartOffset) + if (cmd > 
RespCommandExtensions.LastValidCommand) { - var functions = functionsState.customCommands[(ushort)cmd - CustomCommandManager.StartOffset].functions; + var functions = functionsState.GetCustomCommandFunctions((ushort)cmd); // compute metadata for result var metadataSize = input.arg1 switch { diff --git a/libs/server/Storage/Functions/ObjectStore/PrivateMethods.cs b/libs/server/Storage/Functions/ObjectStore/PrivateMethods.cs index 91698a8f00..88f08e9d53 100644 --- a/libs/server/Storage/Functions/ObjectStore/PrivateMethods.cs +++ b/libs/server/Storage/Functions/ObjectStore/PrivateMethods.cs @@ -184,9 +184,8 @@ static bool EvaluateObjectExpireInPlace(ExpireOption optionType, bool expiryExis [MethodImpl(MethodImplOptions.AggressiveInlining)] private CustomObjectFunctions GetCustomObjectCommand(ref ObjectInput input, GarnetObjectType type) { - var objectId = (byte)((byte)type - CustomCommandManager.TypeIdStartOffset); var cmdId = input.header.SubId; - var customObjectCommand = functionsState.customObjectCommands[objectId].commandMap[cmdId].functions; + var customObjectCommand = functionsState.GetCustomObjectSubCommandFunctions((byte)type, cmdId); return customObjectCommand; } diff --git a/libs/server/Storage/Functions/ObjectStore/RMWMethods.cs b/libs/server/Storage/Functions/ObjectStore/RMWMethods.cs index 8a28bc1e1e..01d8c562bb 100644 --- a/libs/server/Storage/Functions/ObjectStore/RMWMethods.cs +++ b/libs/server/Storage/Functions/ObjectStore/RMWMethods.cs @@ -55,8 +55,7 @@ public bool InitialUpdater(ref byte[] key, ref ObjectInput input, ref IGarnetObj Debug.Assert(type != GarnetObjectType.Expire && type != GarnetObjectType.PExpire && type != GarnetObjectType.Persist, "Expire and Persist commands should have been handled already by NeedInitialUpdate."); var customObjectCommand = GetCustomObjectCommand(ref input, type); - var objectId = (byte)((byte)type - CustomCommandManager.TypeIdStartOffset); - value = 
functionsState.customObjectCommands[objectId].factory.Create((byte)type); + value = functionsState.GetCustomObjectFactory((byte)type).Create((byte)type); (IMemoryOwner Memory, int Length) outp = (output.spanByteAndMemory.Memory, 0); var result = customObjectCommand.InitialUpdater(key, ref input, value, ref outp, ref rmwInfo); diff --git a/libs/server/StoreWrapper.cs b/libs/server/StoreWrapper.cs index b94833916f..49b30eff39 100644 --- a/libs/server/StoreWrapper.cs +++ b/libs/server/StoreWrapper.cs @@ -100,7 +100,7 @@ public sealed class StoreWrapper private SingleWriterMultiReaderLock _checkpointTaskLock; // Lua script cache - public readonly ConcurrentDictionary storeScriptCache; + public readonly ConcurrentDictionary storeScriptCache; public readonly TimeSpan loggingFrequncy; @@ -153,7 +153,7 @@ public StoreWrapper( // Initialize store scripting cache if (serverOptions.EnableLua) - this.storeScriptCache = new ConcurrentDictionary(new ByteArrayComparer()); + this.storeScriptCache = new(SpanByteAndMemoryComparer.Instance); if (accessControlList == null) { @@ -217,7 +217,7 @@ public string GetIp() } internal FunctionsState CreateFunctionsState() - => new(appendOnlyFile, versionMap, customCommandManager.rawStringCommandMap, customCommandManager.objectCommandMap, null, objectStoreSizeTracker, GarnetObjectSerializer); + => new(appendOnlyFile, versionMap, customCommandManager, null, objectStoreSizeTracker, GarnetObjectSerializer); internal void Recover() { diff --git a/libs/server/Transaction/TxnRespCommands.cs b/libs/server/Transaction/TxnRespCommands.cs index 14186a0537..e18d97a7c7 100644 --- a/libs/server/Transaction/TxnRespCommands.cs +++ b/libs/server/Transaction/TxnRespCommands.cs @@ -266,7 +266,7 @@ private bool NetworkRUNTXP() try { - (proc, arity) = customCommandManagerSession.GetCustomTransactionProcedure(txId, this, txnManager, scratchBufferManager); + proc = customCommandManagerSession.GetCustomTransactionProcedure(txId, this, txnManager, 
scratchBufferManager, out arity); } catch (Exception e) { diff --git a/libs/storage/Tsavorite/cs/src/core/Allocator/AllocatorScan.cs b/libs/storage/Tsavorite/cs/src/core/Allocator/AllocatorScan.cs index 213b67956c..fc732747b3 100644 --- a/libs/storage/Tsavorite/cs/src/core/Allocator/AllocatorScan.cs +++ b/libs/storage/Tsavorite/cs/src/core/Allocator/AllocatorScan.cs @@ -186,25 +186,28 @@ internal unsafe bool GetFromDiskAndPushToReader(ref TKey key, re /// Currently we load an entire page, which while inefficient in performance, allows us to make the cursor safe (by ensuring we align to a valid record) if it is not /// the last one returned. We could optimize this to load only the subset of a page that is pointed to by the cursor and do GetRequiredRecordSize/RetrievedFullRecord as in /// AsyncGetFromDiskCallback. However, this would not validate the cursor and would therefore require maintaining a cursor history. - internal abstract bool ScanCursor(TsavoriteKV store, ScanCursorState scanCursorState, ref long cursor, long count, TScanFunctions scanFunctions, long endAddress, bool validateCursor) + internal abstract bool ScanCursor(TsavoriteKV store, ScanCursorState scanCursorState, ref long cursor, long count, TScanFunctions scanFunctions, long endAddress, bool validateCursor, long maxAddress) where TScanFunctions : IScanIteratorFunctions; private protected bool ScanLookup(TsavoriteKV store, - ScanCursorState scanCursorState, ref long cursor, long count, TScanFunctions scanFunctions, TScanIterator iter, bool validateCursor) + ScanCursorState scanCursorState, ref long cursor, long count, TScanFunctions scanFunctions, TScanIterator iter, bool validateCursor, long maxAddress) where TScanFunctions : IScanIteratorFunctions where TScanIterator : ITsavoriteScanIterator, IPushScanIterator { using var session = store.NewSession>(new LogScanCursorFunctions()); var bContext = session.BasicContext; - if (cursor >= GetTailAddress()) - goto IterationComplete; - if (cursor < 
BeginAddress) // This includes 0, which means to start the Scan cursor = BeginAddress; else if (validateCursor) iter.SnapCursorToLogicalAddress(ref cursor); + if (!scanFunctions.OnStart(cursor, iter.EndAddress)) + return false; + + if (cursor >= GetTailAddress()) + goto IterationComplete; + scanCursorState.Initialize(scanFunctions); long numPending = 0; @@ -214,7 +217,7 @@ private protected bool ScanLookup= count || scanCursorState.endBatch) + { + scanFunctions.OnStop(true, scanCursorState.acceptedCount); return true; + } } // Drain any pending pushes. We have ended the iteration; there are no more records, so drop through to end it. @@ -242,12 +251,13 @@ private protected bool ScanLookup(TSessionFunctionsWrapper sessionFunctions, ScanCursorState scanCursorState, RecordInfo recordInfo, - ref TKey key, ref TValue value, long currentAddress, long minAddress) + ref TKey key, ref TValue value, long currentAddress, long minAddress, long maxAddress) where TSessionFunctionsWrapper : ISessionFunctionsWrapper { Debug.Assert(epoch.ThisInstanceProtected(), "This is called only from ScanLookup so the epoch should be protected"); @@ -259,7 +269,7 @@ internal Status ConditionalScanPush(sessionFunctions, ref key, ref stackCtx, currentAddress, minAddress, out internalStatus, out needIO)) + if (sessionFunctions.Store.TryFindRecordInMainLogForConditionalOperation(sessionFunctions, ref key, ref stackCtx, currentAddress, minAddress, maxAddress, out internalStatus, out needIO)) return Status.CreateFound(); } while (sessionFunctions.Store.HandleImmediateNonPendingRetryStatus(internalStatus, sessionFunctions)); @@ -270,7 +280,7 @@ internal Status ConditionalScanPush(TSessionFunctionsWrapper sessionFunctions, ref TsavoriteKV.PendingContext pendingContext, ref TKey key, ref TInput input, ref TValue value, ref TOutput output, TContext userContext, - ref OperationStackContext stackCtx, long minAddress, ScanCursorState scanCursorState) + ref OperationStackContext stackCtx, long minAddress, 
long maxAddress, ScanCursorState scanCursorState) where TSessionFunctionsWrapper : ISessionFunctionsWrapper { // WriteReason is not surfaced for this operation, so pick anything. var status = sessionFunctions.Store.PrepareIOForConditionalOperation(sessionFunctions, ref pendingContext, ref key, ref input, ref value, ref output, - userContext, ref stackCtx, minAddress, WriteReason.Compaction, OperationType.CONDITIONAL_SCAN_PUSH); + userContext, ref stackCtx, minAddress, maxAddress, WriteReason.Compaction, OperationType.CONDITIONAL_SCAN_PUSH); pendingContext.scanCursorState = scanCursorState; return status; } diff --git a/libs/storage/Tsavorite/cs/src/core/Allocator/BlittableAllocatorImpl.cs b/libs/storage/Tsavorite/cs/src/core/Allocator/BlittableAllocatorImpl.cs index 1e6763c6c0..d60df042f4 100644 --- a/libs/storage/Tsavorite/cs/src/core/Allocator/BlittableAllocatorImpl.cs +++ b/libs/storage/Tsavorite/cs/src/core/Allocator/BlittableAllocatorImpl.cs @@ -273,10 +273,10 @@ internal override bool Scan(TsavoriteKV internal override bool ScanCursor(TsavoriteKV> store, - ScanCursorState scanCursorState, ref long cursor, long count, TScanFunctions scanFunctions, long endAddress, bool validateCursor) + ScanCursorState scanCursorState, ref long cursor, long count, TScanFunctions scanFunctions, long endAddress, bool validateCursor, long maxAddress) { using BlittableScanIterator iter = new(store, this, cursor, endAddress, ScanBufferingMode.SinglePageBuffering, false, epoch, logger: logger); - return ScanLookup>(store, scanCursorState, ref cursor, count, scanFunctions, iter, validateCursor); + return ScanLookup>(store, scanCursorState, ref cursor, count, scanFunctions, iter, validateCursor, maxAddress); } /// diff --git a/libs/storage/Tsavorite/cs/src/core/Allocator/GenericAllocatorImpl.cs b/libs/storage/Tsavorite/cs/src/core/Allocator/GenericAllocatorImpl.cs index 9c5a7b6b9c..faf50f1aa4 100644 --- a/libs/storage/Tsavorite/cs/src/core/Allocator/GenericAllocatorImpl.cs +++ 
b/libs/storage/Tsavorite/cs/src/core/Allocator/GenericAllocatorImpl.cs @@ -1013,10 +1013,10 @@ internal override bool Scan(TsavoriteKV internal override bool ScanCursor(TsavoriteKV> store, - ScanCursorState scanCursorState, ref long cursor, long count, TScanFunctions scanFunctions, long endAddress, bool validateCursor) + ScanCursorState scanCursorState, ref long cursor, long count, TScanFunctions scanFunctions, long endAddress, bool validateCursor, long maxAddress) { using GenericScanIterator iter = new(store, this, cursor, endAddress, ScanBufferingMode.SinglePageBuffering, false, epoch, logger: logger); - return ScanLookup>(store, scanCursorState, ref cursor, count, scanFunctions, iter, validateCursor); + return ScanLookup>(store, scanCursorState, ref cursor, count, scanFunctions, iter, validateCursor, maxAddress); } /// diff --git a/libs/storage/Tsavorite/cs/src/core/Allocator/IScanIteratorFunctions.cs b/libs/storage/Tsavorite/cs/src/core/Allocator/IScanIteratorFunctions.cs index fb1ca34f9d..e163c6b72d 100644 --- a/libs/storage/Tsavorite/cs/src/core/Allocator/IScanIteratorFunctions.cs +++ b/libs/storage/Tsavorite/cs/src/core/Allocator/IScanIteratorFunctions.cs @@ -24,7 +24,12 @@ public enum CursorRecordResult /// /// End the current cursor batch (as if "count" had been met); return a valid cursor for the next ScanCursor call /// - EndBatch = 4 + EndBatch = 4, + + /// + /// Retry the last record when returning a valid cursor + /// + RetryLastRecord = 8, } /// @@ -42,7 +47,7 @@ public interface IScanIteratorFunctions /// Reference to the current record's key /// Reference to the current record's Value /// Record metadata, including and the current record's logical address - /// The number of records returned so far, including the current one. + /// The number of records accepted so far, not including the current one. /// Indicates whether the current record was accepted, or whether to end the current ScanCursor call. /// Ignored for non-cursor Scans; set to . 
/// True to continue iteration, else false @@ -52,7 +57,7 @@ public interface IScanIteratorFunctions /// Reference to the current record's key /// Reference to the current record's Value /// Record metadata, including and the current record's logical address - /// The number of records returned so far, including the current one. + /// The number of records accepted so far, not including the current one. /// Indicates whether the current record was accepted, or whether to end the current ScanCursor call. /// Ignored for non-cursor Scans; set to . /// True to continue iteration, else false diff --git a/libs/storage/Tsavorite/cs/src/core/Allocator/IStreamingSnapshotIteratorFunctions.cs b/libs/storage/Tsavorite/cs/src/core/Allocator/IStreamingSnapshotIteratorFunctions.cs new file mode 100644 index 0000000000..cd95e696be --- /dev/null +++ b/libs/storage/Tsavorite/cs/src/core/Allocator/IStreamingSnapshotIteratorFunctions.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; + +namespace Tsavorite.core +{ + /// + /// Callback functions for streaming snapshot iteration + /// + public interface IStreamingSnapshotIteratorFunctions + { + /// Iteration is starting. + /// Checkpoint token + /// Current version of database + /// Target version of database + /// True to continue iteration, else false + bool OnStart(Guid checkpointToken, long currentVersion, long targetVersion); + + /// Next record in the streaming snapshot. + /// Reference to the current record's key + /// Reference to the current record's Value + /// Record metadata, including and the current record's logical address + /// The number of records returned so far, not including the current one. + /// True to continue iteration, else false + bool Reader(ref TKey key, ref TValue value, RecordMetadata recordMetadata, long numberOfRecords); + + /// Iteration is complete. 
+ /// If true, the iteration completed; else OnStart() or Reader() returned false to stop the iteration. + /// The number of records returned before the iteration stopped. + void OnStop(bool completed, long numberOfRecords); + + /// An exception was thrown on iteration (likely during . + /// The exception that was thrown. + /// The number of records returned before the exception. + void OnException(Exception exception, long numberOfRecords); + } +} \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Allocator/ScanCursorState.cs b/libs/storage/Tsavorite/cs/src/core/Allocator/ScanCursorState.cs index ce53b1474a..eeddbebc45 100644 --- a/libs/storage/Tsavorite/cs/src/core/Allocator/ScanCursorState.cs +++ b/libs/storage/Tsavorite/cs/src/core/Allocator/ScanCursorState.cs @@ -6,15 +6,17 @@ namespace Tsavorite.core internal sealed class ScanCursorState { internal IScanIteratorFunctions functions; - internal long acceptedCount; // Number of records pushed to and accepted by the caller - internal bool endBatch; // End the batch (but return a valid cursor for the next batch, as of "count" records had been returned) - internal bool stop; // Stop the operation (as if all records in the db had been returned) + internal long acceptedCount; // Number of records pushed to and accepted by the caller + internal bool endBatch; // End the batch (but return a valid cursor for the next batch, as if "count" records had been returned) + internal bool retryLastRecord; // Retry the last record when returning a valid cursor + internal bool stop; // Stop the operation (as if all records in the db had been returned) internal void Initialize(IScanIteratorFunctions scanIteratorFunctions) { functions = scanIteratorFunctions; acceptedCount = 0; endBatch = false; + retryLastRecord = false; stop = false; } } diff --git a/libs/storage/Tsavorite/cs/src/core/Allocator/SpanByteAllocatorImpl.cs b/libs/storage/Tsavorite/cs/src/core/Allocator/SpanByteAllocatorImpl.cs index 
e25ac476a7..f01495c1fd 100644 --- a/libs/storage/Tsavorite/cs/src/core/Allocator/SpanByteAllocatorImpl.cs +++ b/libs/storage/Tsavorite/cs/src/core/Allocator/SpanByteAllocatorImpl.cs @@ -359,10 +359,10 @@ internal override bool Scan(TsavoriteKV internal override bool ScanCursor(TsavoriteKV> store, - ScanCursorState scanCursorState, ref long cursor, long count, TScanFunctions scanFunctions, long endAddress, bool validateCursor) + ScanCursorState scanCursorState, ref long cursor, long count, TScanFunctions scanFunctions, long endAddress, bool validateCursor, long maxAddress) { using SpanByteScanIterator iter = new(store, this, cursor, endAddress, ScanBufferingMode.SinglePageBuffering, false, epoch, logger: logger); - return ScanLookup>(store, scanCursorState, ref cursor, count, scanFunctions, iter, validateCursor); + return ScanLookup>(store, scanCursorState, ref cursor, count, scanFunctions, iter, validateCursor, maxAddress); } /// diff --git a/libs/storage/Tsavorite/cs/src/core/ClientSession/BasicContext.cs b/libs/storage/Tsavorite/cs/src/core/ClientSession/BasicContext.cs index 8f610740dd..1f791debe1 100644 --- a/libs/storage/Tsavorite/cs/src/core/ClientSession/BasicContext.cs +++ b/libs/storage/Tsavorite/cs/src/core/ClientSession/BasicContext.cs @@ -456,13 +456,13 @@ internal Status CompactionCopyToTail(ref TKey key, ref TInput input, ref TValue /// LogicalAddress of the record to be copied /// Lower-bound address (addresses are searched from tail (high) to head (low); do not search for "future records" earlier than this) [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal Status ConditionalScanPush(ScanCursorState scanCursorState, RecordInfo recordInfo, ref TKey key, ref TValue value, long currentAddress, long untilAddress) + internal Status ConditionalScanPush(ScanCursorState scanCursorState, RecordInfo recordInfo, ref TKey key, ref TValue value, long currentAddress, long untilAddress, long maxAddress) { UnsafeResumeThread(); try { return 
store.hlogBase.ConditionalScanPush, TStoreFunctions, TAllocator>>( - sessionFunctions, scanCursorState, recordInfo, ref key, ref value, currentAddress, untilAddress); + sessionFunctions, scanCursorState, recordInfo, ref key, ref value, currentAddress, untilAddress, maxAddress); } finally { diff --git a/libs/storage/Tsavorite/cs/src/core/ClientSession/ClientSession.cs b/libs/storage/Tsavorite/cs/src/core/ClientSession/ClientSession.cs index adc297ba11..38df35ab13 100644 --- a/libs/storage/Tsavorite/cs/src/core/ClientSession/ClientSession.cs +++ b/libs/storage/Tsavorite/cs/src/core/ClientSession/ClientSession.cs @@ -503,13 +503,14 @@ public bool IterateLookup(ref TScanFunctions scanFunctions, long /// the pending IO process. /// A specific end address; otherwise we scan until we hit the current TailAddress, which may yield duplicates in the event of RCUs. /// This may be set to the TailAddress at the start of the scan, which may lose records that are RCU'd during the scan (because they are moved above the starting - /// TailAddress). A snapshot can be taken by calling ShiftReadOnlyToTail() and then using that TailAddress as endAddress. + /// TailAddress). A snapshot can be taken by calling ShiftReadOnlyToTail() and then using that TailAddress as endAddress and maxAddress. /// If true, validate that the cursor is on a valid address boundary, and snap it to the highest lower address if it is not. + /// Maximum address for determining liveness, records after this address are not considered when checking validity. 
/// True if Scan completed and pushed records; false if Scan ended early due to finding less than records /// or one of the TScanIterator reader functions returning false - public bool ScanCursor(ref long cursor, long count, TScanFunctions scanFunctions, long endAddress = long.MaxValue, bool validateCursor = false) + public bool ScanCursor(ref long cursor, long count, TScanFunctions scanFunctions, long endAddress = long.MaxValue, bool validateCursor = false, long maxAddress = long.MaxValue) where TScanFunctions : IScanIteratorFunctions - => store.hlogBase.ScanCursor(store, scanCursorState ??= new(), ref cursor, count, scanFunctions, endAddress, validateCursor); + => store.hlogBase.ScanCursor(store, scanCursorState ??= new(), ref cursor, count, scanFunctions, endAddress, validateCursor, maxAddress); /// /// Resume session on current thread. IMPORTANT: Call SuspendThread before any async op. diff --git a/libs/storage/Tsavorite/cs/src/core/Index/CheckpointManagement/RecoveryInfo.cs b/libs/storage/Tsavorite/cs/src/core/Index/CheckpointManagement/RecoveryInfo.cs index 32897fa21f..34dd5874ab 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/CheckpointManagement/RecoveryInfo.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/CheckpointManagement/RecoveryInfo.cs @@ -48,7 +48,7 @@ public struct HybridLogRecoveryInfo /// public long finalLogicalAddress; /// - /// Snapshot end logical address: snaphot is [startLogicalAddress, snapshotFinalLogicalAddress) + /// Snapshot end logical address: snapshot is [startLogicalAddress, snapshotFinalLogicalAddress) /// Note that finalLogicalAddress may be higher due to delta records /// public long snapshotFinalLogicalAddress; diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Common/CheckpointSettings.cs b/libs/storage/Tsavorite/cs/src/core/Index/Common/CheckpointSettings.cs index c50261a3d2..ac12e9cb0f 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Common/CheckpointSettings.cs +++ 
b/libs/storage/Tsavorite/cs/src/core/Index/Common/CheckpointSettings.cs @@ -17,7 +17,12 @@ public enum CheckpointType /// Flush current log (move read-only to tail) /// (enables incremental checkpointing, but log grows faster) /// - FoldOver + FoldOver, + + /// + /// Yield a stream of key-value records in version (v), that can be used to rebuild the store + /// + StreamingSnapshot, } /// diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Common/PendingContext.cs b/libs/storage/Tsavorite/cs/src/core/Index/Common/PendingContext.cs index 0f4fbe7fe8..bcb1063712 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Common/PendingContext.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Common/PendingContext.cs @@ -37,6 +37,7 @@ internal struct PendingContext internal RecordInfo recordInfo; internal long minAddress; + internal long maxAddress; // For flushing head pages on tail allocation. internal CompletionEvent flushEvent; diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Recovery/Recovery.cs b/libs/storage/Tsavorite/cs/src/core/Index/Recovery/Recovery.cs index 10c5717730..db85d4fa0e 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Recovery/Recovery.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Recovery/Recovery.cs @@ -538,6 +538,16 @@ private void DoPostRecovery(IndexCheckpointInfo recoveredICInfo, HybridLogCheckp recoveredHLCInfo.Dispose(); } + /// + /// Set store version directly. Useful if manually recovering by re-inserting data. + /// Warning: use only when the system is not taking a checkpoint. 
+ /// + /// Version to set the store to + public void SetVersion(long version) + { + systemState = SystemState.Make(Phase.REST, version); + } + /// /// Compute recovery address and determine where to recover to /// diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/FoldOverCheckpointTask.cs b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/FoldOverCheckpointTask.cs new file mode 100644 index 0000000000..e8f77ae21e --- /dev/null +++ b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/FoldOverCheckpointTask.cs @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Tsavorite.core +{ + /// + /// A FoldOver checkpoint persists a version by setting the read-only marker past the last entry of that + /// version on the log and waiting until it is flushed to disk. It is simple and fast, but can result + /// in garbage entries on the log, and a slower recovery of performance. 
+ /// + internal sealed class FoldOverCheckpointTask : HybridLogCheckpointOrchestrationTask + where TStoreFunctions : IStoreFunctions + where TAllocator : IAllocator + { + /// + public override void GlobalBeforeEnteringState(SystemState next, + TsavoriteKV store) + { + base.GlobalBeforeEnteringState(next, store); + + if (next.Phase == Phase.PREPARE) + { + store._lastSnapshotCheckpoint.Dispose(); + } + + if (next.Phase == Phase.IN_PROGRESS) + base.GlobalBeforeEnteringState(next, store); + + if (next.Phase != Phase.WAIT_FLUSH) return; + + _ = store.hlogBase.ShiftReadOnlyToTail(out var tailAddress, out store._hybridLogCheckpoint.flushedSemaphore); + store._hybridLogCheckpoint.info.finalLogicalAddress = tailAddress; + } + + /// + public override void OnThreadState( + SystemState current, + SystemState prev, + TsavoriteKV store, + TsavoriteKV.TsavoriteExecutionContext ctx, + TSessionFunctionsWrapper sessionFunctions, + List valueTasks, + CancellationToken token = default) + { + base.OnThreadState(current, prev, store, ctx, sessionFunctions, valueTasks, token); + + if (current.Phase != Phase.WAIT_FLUSH) return; + + if (ctx is null || !ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush]) + { + var s = store._hybridLogCheckpoint.flushedSemaphore; + + var notify = store.hlogBase.FlushedUntilAddress >= store._hybridLogCheckpoint.info.finalLogicalAddress; + notify = notify || !store.SameCycle(ctx, current) || s == null; + + if (valueTasks != null && !notify) + { + valueTasks.Add(new ValueTask(s.WaitAsync(token).ContinueWith(t => s.Release()))); + } + + if (!notify) return; + + if (ctx is not null) + ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush] = true; + } + + store.epoch.Mark(EpochPhaseIdx.WaitFlush, current.Version); + if (store.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.Version)) + store.GlobalStateMachineStep(current); + } + } +} \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/HybridLogCheckpointOrchestrationTask.cs 
b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/HybridLogCheckpointOrchestrationTask.cs new file mode 100644 index 0000000000..762468ad92 --- /dev/null +++ b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/HybridLogCheckpointOrchestrationTask.cs @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Tsavorite.core +{ + /// + /// This task is the base class for a checkpoint "backend", which decides how a captured version is + /// persisted on disk. + /// + internal abstract class HybridLogCheckpointOrchestrationTask : ISynchronizationTask + where TStoreFunctions : IStoreFunctions + where TAllocator : IAllocator + { + private long lastVersion; + /// + public virtual void GlobalBeforeEnteringState(SystemState next, + TsavoriteKV store) + { + switch (next.Phase) + { + case Phase.PREPARE: + lastVersion = store.systemState.Version; + if (store._hybridLogCheckpoint.IsDefault()) + { + store._hybridLogCheckpointToken = Guid.NewGuid(); + store.InitializeHybridLogCheckpoint(store._hybridLogCheckpointToken, next.Version); + } + store._hybridLogCheckpoint.info.version = next.Version; + store._hybridLogCheckpoint.info.startLogicalAddress = store.hlogBase.GetTailAddress(); + // Capture begin address before checkpoint starts + store._hybridLogCheckpoint.info.beginAddress = store.hlogBase.BeginAddress; + break; + case Phase.IN_PROGRESS: + store.CheckpointVersionShift(lastVersion, next.Version); + break; + case Phase.WAIT_FLUSH: + store._hybridLogCheckpoint.info.headAddress = store.hlogBase.HeadAddress; + store._hybridLogCheckpoint.info.nextVersion = next.Version; + break; + case Phase.PERSISTENCE_CALLBACK: + CollectMetadata(next, store); + store.WriteHybridLogMetaInfo(); + store.lastVersion = lastVersion; + break; + case Phase.REST: + store._hybridLogCheckpoint.Dispose(); + var nextTcs = new 
TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + store.checkpointTcs.SetResult(new LinkedCheckpointInfo { NextTask = nextTcs.Task }); + store.checkpointTcs = nextTcs; + break; + } + } + + protected static void CollectMetadata(SystemState next, TsavoriteKV store) + { + // Collect object log offsets only after flushes + // are completed + var seg = store.hlog.GetSegmentOffsets(); + if (seg != null) + { + store._hybridLogCheckpoint.info.objectLogSegmentOffsets = new long[seg.Length]; + Array.Copy(seg, store._hybridLogCheckpoint.info.objectLogSegmentOffsets, seg.Length); + } + + // Temporarily block new sessions from starting, which may add an entry to the table and resize the + // dictionary. There should be minimal contention here. + lock (store._activeSessions) + { + List toDelete = null; + + // write dormant sessions to checkpoint + foreach (var kvp in store._activeSessions) + { + kvp.Value.session.AtomicSwitch(next.Version - 1); + if (!kvp.Value.isActive) + { + toDelete ??= new(); + toDelete.Add(kvp.Key); + } + } + + // delete any sessions that ended during checkpoint cycle + if (toDelete != null) + { + foreach (var key in toDelete) + _ = store._activeSessions.Remove(key); + } + } + } + + /// + public virtual void GlobalAfterEnteringState(SystemState next, + TsavoriteKV store) + { + } + + /// + public virtual void OnThreadState( + SystemState current, + SystemState prev, TsavoriteKV store, + TsavoriteKV.TsavoriteExecutionContext ctx, + TSessionFunctionsWrapper sessionFunctions, + List valueTasks, + CancellationToken token = default) + where TSessionFunctionsWrapper : ISessionEpochControl + { + if (current.Phase != Phase.PERSISTENCE_CALLBACK) + return; + + store.epoch.Mark(EpochPhaseIdx.CheckpointCompletionCallback, current.Version); + if (store.epoch.CheckIsComplete(EpochPhaseIdx.CheckpointCompletionCallback, current.Version)) + { + store.storeFunctions.OnCheckpointCompleted(); + store.GlobalStateMachineStep(current); + } + } + } +} \ No 
newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/HybridLogCheckpointStateMachine.cs b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/HybridLogCheckpointStateMachine.cs new file mode 100644 index 0000000000..36cd360610 --- /dev/null +++ b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/HybridLogCheckpointStateMachine.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +namespace Tsavorite.core +{ + /// + /// Hybrid log checkpoint state machine. + /// + internal class HybridLogCheckpointStateMachine : VersionChangeStateMachine + where TStoreFunctions : IStoreFunctions + where TAllocator : IAllocator + { + /// + /// Construct a new HybridLogCheckpointStateMachine to use the given checkpoint backend (either fold-over or + /// snapshot), drawing boundary at targetVersion. + /// + /// A task that encapsulates the logic to persist the checkpoint + /// upper limit (inclusive) of the version included + public HybridLogCheckpointStateMachine(ISynchronizationTask checkpointBackend, long targetVersion = -1) + : base(targetVersion, new VersionChangeTask(), checkpointBackend) { } + + /// + /// Construct a new HybridLogCheckpointStateMachine with the given tasks. Does not load any tasks by default. 
+ /// + /// upper limit (inclusive) of the version included + /// The tasks to load onto the state machine + protected HybridLogCheckpointStateMachine(long targetVersion, params ISynchronizationTask[] tasks) + : base(targetVersion, tasks) { } + + /// + public override SystemState NextState(SystemState start) + { + var result = SystemState.Copy(ref start); + switch (start.Phase) + { + case Phase.IN_PROGRESS: + result.Phase = Phase.WAIT_FLUSH; + break; + case Phase.WAIT_FLUSH: + result.Phase = Phase.PERSISTENCE_CALLBACK; + break; + case Phase.PERSISTENCE_CALLBACK: + result.Phase = Phase.REST; + break; + default: + result = base.NextState(start); + break; + } + + return result; + } + } +} \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/HybridLogCheckpointTask.cs b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/HybridLogCheckpointTask.cs deleted file mode 100644 index b751bf610d..0000000000 --- a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/HybridLogCheckpointTask.cs +++ /dev/null @@ -1,442 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Threading; -using System.Threading.Tasks; - -namespace Tsavorite.core -{ - /// - /// This task is the base class for a checkpoint "backend", which decides how a captured version is - /// persisted on disk. 
- /// - internal abstract class HybridLogCheckpointOrchestrationTask : ISynchronizationTask - where TStoreFunctions : IStoreFunctions - where TAllocator : IAllocator - { - private long lastVersion; - /// - public virtual void GlobalBeforeEnteringState(SystemState next, - TsavoriteKV store) - { - switch (next.Phase) - { - case Phase.PREPARE: - lastVersion = store.systemState.Version; - if (store._hybridLogCheckpoint.IsDefault()) - { - store._hybridLogCheckpointToken = Guid.NewGuid(); - store.InitializeHybridLogCheckpoint(store._hybridLogCheckpointToken, next.Version); - } - store._hybridLogCheckpoint.info.version = next.Version; - store._hybridLogCheckpoint.info.startLogicalAddress = store.hlogBase.GetTailAddress(); - // Capture begin address before checkpoint starts - store._hybridLogCheckpoint.info.beginAddress = store.hlogBase.BeginAddress; - break; - case Phase.IN_PROGRESS: - store.CheckpointVersionShift(lastVersion, next.Version); - break; - case Phase.WAIT_FLUSH: - store._hybridLogCheckpoint.info.headAddress = store.hlogBase.HeadAddress; - store._hybridLogCheckpoint.info.nextVersion = next.Version; - break; - case Phase.PERSISTENCE_CALLBACK: - CollectMetadata(next, store); - store.WriteHybridLogMetaInfo(); - store.lastVersion = lastVersion; - break; - case Phase.REST: - store._hybridLogCheckpoint.Dispose(); - var nextTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - store.checkpointTcs.SetResult(new LinkedCheckpointInfo { NextTask = nextTcs.Task }); - store.checkpointTcs = nextTcs; - break; - } - } - - protected static void CollectMetadata(SystemState next, TsavoriteKV store) - { - // Collect object log offsets only after flushes - // are completed - var seg = store.hlog.GetSegmentOffsets(); - if (seg != null) - { - store._hybridLogCheckpoint.info.objectLogSegmentOffsets = new long[seg.Length]; - Array.Copy(seg, store._hybridLogCheckpoint.info.objectLogSegmentOffsets, seg.Length); - } - - // Temporarily block new sessions 
from starting, which may add an entry to the table and resize the - // dictionary. There should be minimal contention here. - lock (store._activeSessions) - { - List toDelete = null; - - // write dormant sessions to checkpoint - foreach (var kvp in store._activeSessions) - { - kvp.Value.session.AtomicSwitch(next.Version - 1); - if (!kvp.Value.isActive) - { - toDelete ??= new(); - toDelete.Add(kvp.Key); - } - } - - // delete any sessions that ended during checkpoint cycle - if (toDelete != null) - { - foreach (var key in toDelete) - _ = store._activeSessions.Remove(key); - } - } - } - - /// - public virtual void GlobalAfterEnteringState(SystemState next, - TsavoriteKV store) - { - } - - /// - public virtual void OnThreadState( - SystemState current, - SystemState prev, TsavoriteKV store, - TsavoriteKV.TsavoriteExecutionContext ctx, - TSessionFunctionsWrapper sessionFunctions, - List valueTasks, - CancellationToken token = default) - where TSessionFunctionsWrapper : ISessionEpochControl - { - if (current.Phase != Phase.PERSISTENCE_CALLBACK) - return; - - store.epoch.Mark(EpochPhaseIdx.CheckpointCompletionCallback, current.Version); - if (store.epoch.CheckIsComplete(EpochPhaseIdx.CheckpointCompletionCallback, current.Version)) - { - store.storeFunctions.OnCheckpointCompleted(); - store.GlobalStateMachineStep(current); - } - } - } - - /// - /// A FoldOver checkpoint persists a version by setting the read-only marker past the last entry of that - /// version on the log and waiting until it is flushed to disk. It is simple and fast, but can result - /// in garbage entries on the log, and a slower recovery of performance. 
- /// - internal sealed class FoldOverCheckpointTask : HybridLogCheckpointOrchestrationTask - where TStoreFunctions : IStoreFunctions - where TAllocator : IAllocator - { - /// - public override void GlobalBeforeEnteringState(SystemState next, - TsavoriteKV store) - { - base.GlobalBeforeEnteringState(next, store); - - if (next.Phase == Phase.PREPARE) - { - store._lastSnapshotCheckpoint.Dispose(); - } - - if (next.Phase == Phase.IN_PROGRESS) - base.GlobalBeforeEnteringState(next, store); - - if (next.Phase != Phase.WAIT_FLUSH) return; - - _ = store.hlogBase.ShiftReadOnlyToTail(out var tailAddress, out store._hybridLogCheckpoint.flushedSemaphore); - store._hybridLogCheckpoint.info.finalLogicalAddress = tailAddress; - } - - /// - public override void OnThreadState( - SystemState current, - SystemState prev, - TsavoriteKV store, - TsavoriteKV.TsavoriteExecutionContext ctx, - TSessionFunctionsWrapper sessionFunctions, - List valueTasks, - CancellationToken token = default) - { - base.OnThreadState(current, prev, store, ctx, sessionFunctions, valueTasks, token); - - if (current.Phase != Phase.WAIT_FLUSH) return; - - if (ctx is null || !ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush]) - { - var s = store._hybridLogCheckpoint.flushedSemaphore; - - var notify = store.hlogBase.FlushedUntilAddress >= store._hybridLogCheckpoint.info.finalLogicalAddress; - notify = notify || !store.SameCycle(ctx, current) || s == null; - - if (valueTasks != null && !notify) - { - valueTasks.Add(new ValueTask(s.WaitAsync(token).ContinueWith(t => s.Release()))); - } - - if (!notify) return; - - if (ctx is not null) - ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush] = true; - } - - store.epoch.Mark(EpochPhaseIdx.WaitFlush, current.Version); - if (store.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.Version)) - store.GlobalStateMachineStep(current); - } - } - - /// - /// A Snapshot persists a version by making a copy for every entry of that version separate from the log. 
It is - /// slower and more complex than a foldover, but more space-efficient on the log, and retains in-place - /// update performance as it does not advance the readonly marker unnecessarily. - /// - internal sealed class SnapshotCheckpointTask : HybridLogCheckpointOrchestrationTask - where TStoreFunctions : IStoreFunctions - where TAllocator : IAllocator - { - /// - public override void GlobalBeforeEnteringState(SystemState next, TsavoriteKV store) - { - switch (next.Phase) - { - case Phase.PREPARE: - store._lastSnapshotCheckpoint.Dispose(); - base.GlobalBeforeEnteringState(next, store); - store._hybridLogCheckpoint.info.useSnapshotFile = 1; - break; - case Phase.WAIT_FLUSH: - base.GlobalBeforeEnteringState(next, store); - store._hybridLogCheckpoint.info.finalLogicalAddress = store.hlogBase.GetTailAddress(); - store._hybridLogCheckpoint.info.snapshotFinalLogicalAddress = store._hybridLogCheckpoint.info.finalLogicalAddress; - - store._hybridLogCheckpoint.snapshotFileDevice = - store.checkpointManager.GetSnapshotLogDevice(store._hybridLogCheckpointToken); - store._hybridLogCheckpoint.snapshotFileObjectLogDevice = - store.checkpointManager.GetSnapshotObjectLogDevice(store._hybridLogCheckpointToken); - store._hybridLogCheckpoint.snapshotFileDevice.Initialize(store.hlogBase.GetSegmentSize()); - store._hybridLogCheckpoint.snapshotFileObjectLogDevice.Initialize(-1); - - // If we are using a NullDevice then storage tier is not enabled and FlushedUntilAddress may be ReadOnlyAddress; get all records in memory. - store._hybridLogCheckpoint.info.snapshotStartFlushedLogicalAddress = store.hlogBase.IsNullDevice ? 
store.hlogBase.HeadAddress : store.hlogBase.FlushedUntilAddress; - - long startPage = store.hlogBase.GetPage(store._hybridLogCheckpoint.info.snapshotStartFlushedLogicalAddress); - long endPage = store.hlogBase.GetPage(store._hybridLogCheckpoint.info.finalLogicalAddress); - if (store._hybridLogCheckpoint.info.finalLogicalAddress > - store.hlog.GetStartLogicalAddress(endPage)) - { - endPage++; - } - - // We are writing pages outside epoch protection, so callee should be able to - // handle corrupted or unexpected concurrent page changes during the flush, e.g., by - // resuming epoch protection if necessary. Correctness is not affected as we will - // only read safe pages during recovery. - store.hlogBase.AsyncFlushPagesToDevice( - startPage, - endPage, - store._hybridLogCheckpoint.info.finalLogicalAddress, - store._hybridLogCheckpoint.info.startLogicalAddress, - store._hybridLogCheckpoint.snapshotFileDevice, - store._hybridLogCheckpoint.snapshotFileObjectLogDevice, - out store._hybridLogCheckpoint.flushedSemaphore, - store.ThrottleCheckpointFlushDelayMs); - break; - case Phase.PERSISTENCE_CALLBACK: - // Set actual FlushedUntil to the latest possible data in main log that is on disk - // If we are using a NullDevice then storage tier is not enabled and FlushedUntilAddress may be ReadOnlyAddress; get all records in memory. - store._hybridLogCheckpoint.info.flushedLogicalAddress = store.hlogBase.IsNullDevice ? 
store.hlogBase.HeadAddress : store.hlogBase.FlushedUntilAddress; - base.GlobalBeforeEnteringState(next, store); - store._lastSnapshotCheckpoint = store._hybridLogCheckpoint.Transfer(); - break; - default: - base.GlobalBeforeEnteringState(next, store); - break; - } - } - - /// - public override void OnThreadState( - SystemState current, - SystemState prev, TsavoriteKV store, - TsavoriteKV.TsavoriteExecutionContext ctx, - TSessionFunctionsWrapper sessionFunctions, - List valueTasks, - CancellationToken token = default) - { - base.OnThreadState(current, prev, store, ctx, sessionFunctions, valueTasks, token); - - if (current.Phase != Phase.WAIT_FLUSH) return; - - if (ctx is null || !ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush]) - { - var s = store._hybridLogCheckpoint.flushedSemaphore; - - var notify = s != null && s.CurrentCount > 0; - notify = notify || !store.SameCycle(ctx, current) || s == null; - - if (valueTasks != null && !notify) - { - Debug.Assert(s != null); - valueTasks.Add(new ValueTask(s.WaitAsync(token).ContinueWith(t => s.Release()))); - } - - if (!notify) return; - - if (ctx is not null) - ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush] = true; - } - - store.epoch.Mark(EpochPhaseIdx.WaitFlush, current.Version); - if (store.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.Version)) - store.GlobalStateMachineStep(current); - } - } - - /// - /// A Incremental Snapshot makes a copy of only changes that have happened since the last full Snapshot. It is - /// slower and more complex than a foldover, but more space-efficient on the log, and retains in-place - /// update performance as it does not advance the readonly marker unnecessarily. 
- /// - internal sealed class IncrementalSnapshotCheckpointTask : HybridLogCheckpointOrchestrationTask - where TStoreFunctions : IStoreFunctions - where TAllocator : IAllocator - { - /// - public override void GlobalBeforeEnteringState(SystemState next, TsavoriteKV store) - { - switch (next.Phase) - { - case Phase.PREPARE: - store._hybridLogCheckpoint = store._lastSnapshotCheckpoint; - base.GlobalBeforeEnteringState(next, store); - store._hybridLogCheckpoint.prevVersion = next.Version; - break; - case Phase.IN_PROGRESS: - base.GlobalBeforeEnteringState(next, store); - break; - case Phase.WAIT_FLUSH: - base.GlobalBeforeEnteringState(next, store); - store._hybridLogCheckpoint.info.finalLogicalAddress = store.hlogBase.GetTailAddress(); - - if (store._hybridLogCheckpoint.deltaLog == null) - { - store._hybridLogCheckpoint.deltaFileDevice = store.checkpointManager.GetDeltaLogDevice(store._hybridLogCheckpointToken); - store._hybridLogCheckpoint.deltaFileDevice.Initialize(-1); - store._hybridLogCheckpoint.deltaLog = new DeltaLog(store._hybridLogCheckpoint.deltaFileDevice, store.hlogBase.LogPageSizeBits, -1); - store._hybridLogCheckpoint.deltaLog.InitializeForWrites(store.hlogBase.bufferPool); - } - - // We are writing delta records outside epoch protection, so callee should be able to - // handle corrupted or unexpected concurrent page changes during the flush, e.g., by - // resuming epoch protection if necessary. Correctness is not affected as we will - // only read safe pages during recovery. 
- store.hlogBase.AsyncFlushDeltaToDevice( - store.hlogBase.FlushedUntilAddress, - store._hybridLogCheckpoint.info.finalLogicalAddress, - store._lastSnapshotCheckpoint.info.finalLogicalAddress, - store._hybridLogCheckpoint.prevVersion, - store._hybridLogCheckpoint.deltaLog, - out store._hybridLogCheckpoint.flushedSemaphore, - store.ThrottleCheckpointFlushDelayMs); - break; - case Phase.PERSISTENCE_CALLBACK: - CollectMetadata(next, store); - store._hybridLogCheckpoint.info.deltaTailAddress = store._hybridLogCheckpoint.deltaLog.TailAddress; - store.WriteHybridLogIncrementalMetaInfo(store._hybridLogCheckpoint.deltaLog); - store._hybridLogCheckpoint.info.deltaTailAddress = store._hybridLogCheckpoint.deltaLog.TailAddress; - store._lastSnapshotCheckpoint = store._hybridLogCheckpoint.Transfer(); - store._hybridLogCheckpoint.Dispose(); - break; - } - } - - /// - public override void OnThreadState( - SystemState current, - SystemState prev, TsavoriteKV store, - TsavoriteKV.TsavoriteExecutionContext ctx, - TSessionFunctionsWrapper sessionFunctions, - List valueTasks, - CancellationToken token = default) - { - base.OnThreadState(current, prev, store, ctx, sessionFunctions, valueTasks, token); - - if (current.Phase != Phase.WAIT_FLUSH) return; - - if (ctx is null || !ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush]) - { - var s = store._hybridLogCheckpoint.flushedSemaphore; - - var notify = s != null && s.CurrentCount > 0; - notify = notify || !store.SameCycle(ctx, current) || s == null; - - if (valueTasks != null && !notify) - { - Debug.Assert(s != null); - valueTasks.Add(new ValueTask(s.WaitAsync(token).ContinueWith(t => s.Release()))); - } - - if (!notify) return; - - if (ctx is not null) - ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush] = true; - } - - store.epoch.Mark(EpochPhaseIdx.WaitFlush, current.Version); - if (store.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.Version)) - store.GlobalStateMachineStep(current); - } - } - - /// - /// - /// - internal class 
HybridLogCheckpointStateMachine : VersionChangeStateMachine - where TStoreFunctions : IStoreFunctions - where TAllocator : IAllocator - { - /// - /// Construct a new HybridLogCheckpointStateMachine to use the given checkpoint backend (either fold-over or - /// snapshot), drawing boundary at targetVersion. - /// - /// A task that encapsulates the logic to persist the checkpoint - /// upper limit (inclusive) of the version included - public HybridLogCheckpointStateMachine(ISynchronizationTask checkpointBackend, long targetVersion = -1) - : base(targetVersion, new VersionChangeTask(), checkpointBackend) { } - - /// - /// Construct a new HybridLogCheckpointStateMachine with the given tasks. Does not load any tasks by default. - /// - /// upper limit (inclusive) of the version included - /// The tasks to load onto the state machine - protected HybridLogCheckpointStateMachine(long targetVersion, params ISynchronizationTask[] tasks) - : base(targetVersion, tasks) { } - - /// - public override SystemState NextState(SystemState start) - { - var result = SystemState.Copy(ref start); - switch (start.Phase) - { - case Phase.IN_PROGRESS: - result.Phase = Phase.WAIT_FLUSH; - break; - case Phase.WAIT_FLUSH: - result.Phase = Phase.PERSISTENCE_CALLBACK; - break; - case Phase.PERSISTENCE_CALLBACK: - result.Phase = Phase.REST; - break; - default: - result = base.NextState(start); - break; - } - - return result; - } - } -} \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/IncrementalSnapshotCheckpointTask.cs b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/IncrementalSnapshotCheckpointTask.cs new file mode 100644 index 0000000000..99f09dc2da --- /dev/null +++ b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/IncrementalSnapshotCheckpointTask.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; + +namespace Tsavorite.core +{ + /// + /// An Incremental Snapshot makes a copy of only changes that have happened since the last full Snapshot. It is + /// slower and more complex than a foldover, but more space-efficient on the log, and retains in-place + /// update performance as it does not advance the readonly marker unnecessarily. + /// + internal sealed class IncrementalSnapshotCheckpointTask : HybridLogCheckpointOrchestrationTask + where TStoreFunctions : IStoreFunctions + where TAllocator : IAllocator + { + /// + public override void GlobalBeforeEnteringState(SystemState next, TsavoriteKV store) + { + switch (next.Phase) + { + case Phase.PREPARE: + store._hybridLogCheckpoint = store._lastSnapshotCheckpoint; + base.GlobalBeforeEnteringState(next, store); + store._hybridLogCheckpoint.prevVersion = next.Version; + break; + case Phase.IN_PROGRESS: + base.GlobalBeforeEnteringState(next, store); + break; + case Phase.WAIT_FLUSH: + base.GlobalBeforeEnteringState(next, store); + store._hybridLogCheckpoint.info.finalLogicalAddress = store.hlogBase.GetTailAddress(); + + if (store._hybridLogCheckpoint.deltaLog == null) + { + store._hybridLogCheckpoint.deltaFileDevice = store.checkpointManager.GetDeltaLogDevice(store._hybridLogCheckpointToken); + store._hybridLogCheckpoint.deltaFileDevice.Initialize(-1); + store._hybridLogCheckpoint.deltaLog = new DeltaLog(store._hybridLogCheckpoint.deltaFileDevice, store.hlogBase.LogPageSizeBits, -1); + store._hybridLogCheckpoint.deltaLog.InitializeForWrites(store.hlogBase.bufferPool); + } + + // We are writing delta records outside epoch protection, so callee should be able to + // handle corrupted or unexpected concurrent page changes during the flush, e.g., by + // resuming epoch protection if necessary. Correctness is not affected as we will + // only read safe pages during recovery. 
+ store.hlogBase.AsyncFlushDeltaToDevice( + store.hlogBase.FlushedUntilAddress, + store._hybridLogCheckpoint.info.finalLogicalAddress, + store._lastSnapshotCheckpoint.info.finalLogicalAddress, + store._hybridLogCheckpoint.prevVersion, + store._hybridLogCheckpoint.deltaLog, + out store._hybridLogCheckpoint.flushedSemaphore, + store.ThrottleCheckpointFlushDelayMs); + break; + case Phase.PERSISTENCE_CALLBACK: + CollectMetadata(next, store); + store._hybridLogCheckpoint.info.deltaTailAddress = store._hybridLogCheckpoint.deltaLog.TailAddress; + store.WriteHybridLogIncrementalMetaInfo(store._hybridLogCheckpoint.deltaLog); + store._hybridLogCheckpoint.info.deltaTailAddress = store._hybridLogCheckpoint.deltaLog.TailAddress; + store._lastSnapshotCheckpoint = store._hybridLogCheckpoint.Transfer(); + store._hybridLogCheckpoint.Dispose(); + break; + } + } + + /// + public override void OnThreadState( + SystemState current, + SystemState prev, TsavoriteKV store, + TsavoriteKV.TsavoriteExecutionContext ctx, + TSessionFunctionsWrapper sessionFunctions, + List valueTasks, + CancellationToken token = default) + { + base.OnThreadState(current, prev, store, ctx, sessionFunctions, valueTasks, token); + + if (current.Phase != Phase.WAIT_FLUSH) return; + + if (ctx is null || !ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush]) + { + var s = store._hybridLogCheckpoint.flushedSemaphore; + + var notify = s != null && s.CurrentCount > 0; + notify = notify || !store.SameCycle(ctx, current) || s == null; + + if (valueTasks != null && !notify) + { + Debug.Assert(s != null); + valueTasks.Add(new ValueTask(s.WaitAsync(token).ContinueWith(t => s.Release()))); + } + + if (!notify) return; + + if (ctx is not null) + ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush] = true; + } + + store.epoch.Mark(EpochPhaseIdx.WaitFlush, current.Version); + if (store.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.Version)) + store.GlobalStateMachineStep(current); + } + } +} \ No newline at end of file diff --git 
a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/SnapshotCheckpointTask.cs b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/SnapshotCheckpointTask.cs new file mode 100644 index 0000000000..e850e156a1 --- /dev/null +++ b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/SnapshotCheckpointTask.cs @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; + +namespace Tsavorite.core +{ + /// + /// A Snapshot persists a version by making a copy for every entry of that version separate from the log. It is + /// slower and more complex than a foldover, but more space-efficient on the log, and retains in-place + /// update performance as it does not advance the readonly marker unnecessarily. + /// + internal sealed class SnapshotCheckpointTask : HybridLogCheckpointOrchestrationTask + where TStoreFunctions : IStoreFunctions + where TAllocator : IAllocator + { + /// + public override void GlobalBeforeEnteringState(SystemState next, TsavoriteKV store) + { + switch (next.Phase) + { + case Phase.PREPARE: + store._lastSnapshotCheckpoint.Dispose(); + base.GlobalBeforeEnteringState(next, store); + store._hybridLogCheckpoint.info.useSnapshotFile = 1; + break; + case Phase.WAIT_FLUSH: + base.GlobalBeforeEnteringState(next, store); + store._hybridLogCheckpoint.info.finalLogicalAddress = store.hlogBase.GetTailAddress(); + store._hybridLogCheckpoint.info.snapshotFinalLogicalAddress = store._hybridLogCheckpoint.info.finalLogicalAddress; + + store._hybridLogCheckpoint.snapshotFileDevice = + store.checkpointManager.GetSnapshotLogDevice(store._hybridLogCheckpointToken); + store._hybridLogCheckpoint.snapshotFileObjectLogDevice = + store.checkpointManager.GetSnapshotObjectLogDevice(store._hybridLogCheckpointToken); + store._hybridLogCheckpoint.snapshotFileDevice.Initialize(store.hlogBase.GetSegmentSize()); + 
store._hybridLogCheckpoint.snapshotFileObjectLogDevice.Initialize(-1); + + // If we are using a NullDevice then storage tier is not enabled and FlushedUntilAddress may be ReadOnlyAddress; get all records in memory. + store._hybridLogCheckpoint.info.snapshotStartFlushedLogicalAddress = store.hlogBase.IsNullDevice ? store.hlogBase.HeadAddress : store.hlogBase.FlushedUntilAddress; + + long startPage = store.hlogBase.GetPage(store._hybridLogCheckpoint.info.snapshotStartFlushedLogicalAddress); + long endPage = store.hlogBase.GetPage(store._hybridLogCheckpoint.info.finalLogicalAddress); + if (store._hybridLogCheckpoint.info.finalLogicalAddress > + store.hlog.GetStartLogicalAddress(endPage)) + { + endPage++; + } + + // We are writing pages outside epoch protection, so callee should be able to + // handle corrupted or unexpected concurrent page changes during the flush, e.g., by + // resuming epoch protection if necessary. Correctness is not affected as we will + // only read safe pages during recovery. + store.hlogBase.AsyncFlushPagesToDevice( + startPage, + endPage, + store._hybridLogCheckpoint.info.finalLogicalAddress, + store._hybridLogCheckpoint.info.startLogicalAddress, + store._hybridLogCheckpoint.snapshotFileDevice, + store._hybridLogCheckpoint.snapshotFileObjectLogDevice, + out store._hybridLogCheckpoint.flushedSemaphore, + store.ThrottleCheckpointFlushDelayMs); + break; + case Phase.PERSISTENCE_CALLBACK: + // Set actual FlushedUntil to the latest possible data in main log that is on disk + // If we are using a NullDevice then storage tier is not enabled and FlushedUntilAddress may be ReadOnlyAddress; get all records in memory. + store._hybridLogCheckpoint.info.flushedLogicalAddress = store.hlogBase.IsNullDevice ? 
store.hlogBase.HeadAddress : store.hlogBase.FlushedUntilAddress; + base.GlobalBeforeEnteringState(next, store); + store._lastSnapshotCheckpoint = store._hybridLogCheckpoint.Transfer(); + break; + default: + base.GlobalBeforeEnteringState(next, store); + break; + } + } + + /// + public override void OnThreadState( + SystemState current, + SystemState prev, TsavoriteKV store, + TsavoriteKV.TsavoriteExecutionContext ctx, + TSessionFunctionsWrapper sessionFunctions, + List valueTasks, + CancellationToken token = default) + { + base.OnThreadState(current, prev, store, ctx, sessionFunctions, valueTasks, token); + + if (current.Phase != Phase.WAIT_FLUSH) return; + + if (ctx is null || !ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush]) + { + var s = store._hybridLogCheckpoint.flushedSemaphore; + + var notify = s != null && s.CurrentCount > 0; + notify = notify || !store.SameCycle(ctx, current) || s == null; + + if (valueTasks != null && !notify) + { + Debug.Assert(s != null); + valueTasks.Add(new ValueTask(s.WaitAsync(token).ContinueWith(t => s.Release()))); + } + + if (!notify) return; + + if (ctx is not null) + ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush] = true; + } + + store.epoch.Mark(EpochPhaseIdx.WaitFlush, current.Version); + if (store.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.Version)) + store.GlobalStateMachineStep(current); + } + } +} \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StateTransitions.cs b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StateTransitions.cs index 0e75362419..e845f559d3 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StateTransitions.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StateTransitions.cs @@ -47,6 +47,9 @@ public enum Phase : int /// Wait for an index-only checkpoint to complete WAIT_INDEX_ONLY_CHECKPOINT, + /// Wait for pre-scan (until ReadOnlyAddress) to complete for streaming snapshot + 
PREP_STREAMING_SNAPSHOT_CHECKPOINT, + /// Prepare for a checkpoint, still in (v) version PREPARE, diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StreamingSnapshotCheckpointStateMachine.cs b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StreamingSnapshotCheckpointStateMachine.cs new file mode 100644 index 0000000000..29128bc656 --- /dev/null +++ b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StreamingSnapshotCheckpointStateMachine.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +namespace Tsavorite.core +{ + /// + /// StreamingSnapshot checkpoint state machine. + /// + class StreamingSnapshotCheckpointStateMachine : VersionChangeStateMachine + where TStoreFunctions : IStoreFunctions + where TAllocator : IAllocator + { + /// + /// Construct a new StreamingSnapshotCheckpointStateMachine, drawing boundary at targetVersion. + /// + /// upper limit (inclusive) of the version included + public StreamingSnapshotCheckpointStateMachine(long targetVersion) + : base(targetVersion, + new VersionChangeTask(), + new StreamingSnapshotCheckpointTask(targetVersion)) + { } + + /// + public override SystemState NextState(SystemState start) + { + var result = SystemState.Copy(ref start); + switch (start.Phase) + { + case Phase.REST: + result.Phase = Phase.PREP_STREAMING_SNAPSHOT_CHECKPOINT; + break; + case Phase.PREP_STREAMING_SNAPSHOT_CHECKPOINT: + result.Phase = Phase.PREPARE; + break; + case Phase.IN_PROGRESS: + result.Phase = Phase.WAIT_FLUSH; + break; + case Phase.WAIT_FLUSH: + result.Phase = Phase.REST; + break; + default: + result = base.NextState(start); + break; + } + + return result; + } + } +} \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StreamingSnapshotCheckpointTask.cs b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StreamingSnapshotCheckpointTask.cs new file mode 100644 index 0000000000..2c932f196b --- /dev/null 
+++ b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StreamingSnapshotCheckpointTask.cs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; + +namespace Tsavorite.core +{ + + /// + /// A Streaming Snapshot persists a version by yielding a stream of key-value pairs that correspond to + /// a consistent snapshot of the database, for the old version (v). Unlike Snapshot, StreamingSnapshot + /// is designed to not require tail growth even during the WAIT_FLUSH phase of checkpointing. Further, + /// it does not require a snapshot of the index. Recovery is achieved by replaying the yielded log + /// of key-value pairs and inserting each record into an empty database. + /// + sealed class StreamingSnapshotCheckpointTask : HybridLogCheckpointOrchestrationTask + where TStoreFunctions : IStoreFunctions + where TAllocator : IAllocator + { + readonly long targetVersion; + + public StreamingSnapshotCheckpointTask(long targetVersion) + { + this.targetVersion = targetVersion; + } + + /// + public override void GlobalBeforeEnteringState(SystemState next, TsavoriteKV store) + { + switch (next.Phase) + { + case Phase.PREP_STREAMING_SNAPSHOT_CHECKPOINT: + base.GlobalBeforeEnteringState(next, store); + store._hybridLogCheckpointToken = Guid.NewGuid(); + store._hybridLogCheckpoint.info.version = next.Version; + store._hybridLogCheckpoint.info.nextVersion = targetVersion == -1 ? 
next.Version + 1 : targetVersion; + store._lastSnapshotCheckpoint.Dispose(); + _ = Task.Run(store.StreamingSnapshotScanPhase1); + break; + case Phase.PREPARE: + store.InitializeHybridLogCheckpoint(store._hybridLogCheckpointToken, next.Version); + base.GlobalBeforeEnteringState(next, store); + break; + case Phase.WAIT_FLUSH: + base.GlobalBeforeEnteringState(next, store); + store._hybridLogCheckpoint.flushedSemaphore = new SemaphoreSlim(0); + var finalLogicalAddress = store.hlogBase.GetTailAddress(); + Task.Run(() => store.StreamingSnapshotScanPhase2(finalLogicalAddress)); + break; + default: + base.GlobalBeforeEnteringState(next, store); + break; + } + } + + /// + public override void OnThreadState( + SystemState current, + SystemState prev, TsavoriteKV store, + TsavoriteKV.TsavoriteExecutionContext ctx, + TSessionFunctionsWrapper sessionFunctions, + List valueTasks, + CancellationToken token = default) + { + base.OnThreadState(current, prev, store, ctx, sessionFunctions, valueTasks, token); + + if (current.Phase != Phase.WAIT_FLUSH) return; + + if (ctx is null) + { + var s = store._hybridLogCheckpoint.flushedSemaphore; + + var notify = s != null && s.CurrentCount > 0; + notify = notify || !store.SameCycle(ctx, current) || s == null; + + if (valueTasks != null && !notify) + { + Debug.Assert(s != null); + valueTasks.Add(new ValueTask(s.WaitAsync(token).ContinueWith(t => s.Release()))); + } + } + } + } +} \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StreamingSnapshotTsavoriteKV.cs b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StreamingSnapshotTsavoriteKV.cs new file mode 100644 index 0000000000..36c4c1aa75 --- /dev/null +++ b/libs/storage/Tsavorite/cs/src/core/Index/Synchronization/StreamingSnapshotTsavoriteKV.cs @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +using System; +using System.Diagnostics; + +namespace Tsavorite.core +{ + public partial class TsavoriteKV : TsavoriteBase + where TStoreFunctions : IStoreFunctions + where TAllocator : IAllocator + { + IStreamingSnapshotIteratorFunctions streamingSnapshotIteratorFunctions; + long scannedUntilAddressCursor; + long numberOfRecords; + + class StreamingSnapshotSessionFunctions : SessionFunctionsBase + { + + } + + class ScanPhase1Functions : IScanIteratorFunctions + { + readonly IStreamingSnapshotIteratorFunctions streamingSnapshotIteratorFunctions; + readonly Guid checkpointToken; + readonly long currentVersion; + readonly long targetVersion; + public long numberOfRecords; + + public ScanPhase1Functions(IStreamingSnapshotIteratorFunctions streamingSnapshotIteratorFunctions, Guid checkpointToken, long currentVersion, long targetVersion) + { + this.streamingSnapshotIteratorFunctions = streamingSnapshotIteratorFunctions; + this.checkpointToken = checkpointToken; + this.currentVersion = currentVersion; + this.targetVersion = targetVersion; + } + + /// + public bool SingleReader(ref TKey key, ref TValue value, RecordMetadata recordMetadata, long numberOfRecords, out CursorRecordResult cursorRecordResult) + { + cursorRecordResult = CursorRecordResult.Accept; + return streamingSnapshotIteratorFunctions.Reader(ref key, ref value, recordMetadata, numberOfRecords); + } + + /// + public bool ConcurrentReader(ref TKey key, ref TValue value, RecordMetadata recordMetadata, long numberOfRecords, out CursorRecordResult cursorRecordResult) + => SingleReader(ref key, ref value, recordMetadata, numberOfRecords, out cursorRecordResult); + + /// + public void OnException(Exception exception, long numberOfRecords) + => streamingSnapshotIteratorFunctions.OnException(exception, numberOfRecords); + + /// + public bool OnStart(long beginAddress, long endAddress) + => streamingSnapshotIteratorFunctions.OnStart(checkpointToken, currentVersion, targetVersion); + + /// + public void OnStop(bool 
completed, long numberOfRecords) + { + this.numberOfRecords = numberOfRecords; + } + } + + internal void StreamingSnapshotScanPhase1() + { + try + { + Debug.Assert(systemState.Phase == Phase.PREP_STREAMING_SNAPSHOT_CHECKPOINT); + + // Iterate all the read-only records in the store + scannedUntilAddressCursor = Log.SafeReadOnlyAddress; + var scanFunctions = new ScanPhase1Functions(streamingSnapshotIteratorFunctions, _hybridLogCheckpointToken, _hybridLogCheckpoint.info.version, _hybridLogCheckpoint.info.nextVersion); + using var s = NewSession(new()); + long cursor = 0; + _ = s.ScanCursor(ref cursor, long.MaxValue, scanFunctions, scannedUntilAddressCursor); + this.numberOfRecords = scanFunctions.numberOfRecords; + } + finally + { + Debug.Assert(systemState.Phase == Phase.PREP_STREAMING_SNAPSHOT_CHECKPOINT); + GlobalStateMachineStep(systemState); + } + } + + class ScanPhase2Functions : IScanIteratorFunctions + { + readonly IStreamingSnapshotIteratorFunctions streamingSnapshotIteratorFunctions; + readonly long phase1NumberOfRecords; + + public ScanPhase2Functions(IStreamingSnapshotIteratorFunctions streamingSnapshotIteratorFunctions, long acceptedRecordCount) + { + this.streamingSnapshotIteratorFunctions = streamingSnapshotIteratorFunctions; + this.phase1NumberOfRecords = acceptedRecordCount; + } + + /// + public bool SingleReader(ref TKey key, ref TValue value, RecordMetadata recordMetadata, long numberOfRecords, out CursorRecordResult cursorRecordResult) + { + cursorRecordResult = CursorRecordResult.Accept; + return streamingSnapshotIteratorFunctions.Reader(ref key, ref value, recordMetadata, numberOfRecords); + } + + /// + public bool ConcurrentReader(ref TKey key, ref TValue value, RecordMetadata recordMetadata, long numberOfRecords, out CursorRecordResult cursorRecordResult) + => SingleReader(ref key, ref value, recordMetadata, numberOfRecords, out cursorRecordResult); + + /// + public void OnException(Exception exception, long numberOfRecords) + => 
streamingSnapshotIteratorFunctions.OnException(exception, numberOfRecords); + + /// + public bool OnStart(long beginAddress, long endAddress) => true; + + /// + public void OnStop(bool completed, long numberOfRecords) + => streamingSnapshotIteratorFunctions.OnStop(completed, phase1NumberOfRecords + numberOfRecords); + } + + internal void StreamingSnapshotScanPhase2(long untilAddress) + { + try + { + Debug.Assert(systemState.Phase == Phase.WAIT_FLUSH); + + // Iterate all the (v) records in the store + var scanFunctions = new ScanPhase2Functions(streamingSnapshotIteratorFunctions, this.numberOfRecords); + using var s = NewSession(new()); + + // TODO: This requires ScanCursor to provide a consistent snapshot considering only records up to untilAddress + // There is a bug in the current implementation of ScanCursor, where it does not provide such a consistent snapshot + _ = s.ScanCursor(ref scannedUntilAddressCursor, long.MaxValue, scanFunctions, endAddress: untilAddress, maxAddress: untilAddress); + + // Reset the cursor to 0 + scannedUntilAddressCursor = 0; + numberOfRecords = 0; + + // Reset the callback functions + streamingSnapshotIteratorFunctions = null; + + // Release the semaphore to allow the checkpoint waiting task to proceed + _hybridLogCheckpoint.flushedSemaphore.Release(); + } + finally + { + Debug.Assert(systemState.Phase == Phase.WAIT_FLUSH); + GlobalStateMachineStep(systemState); + } + } + } +} \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/ConditionalCopyToTail.cs b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/ConditionalCopyToTail.cs index 04bf1aeff3..1c66b5bbb0 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/ConditionalCopyToTail.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/ConditionalCopyToTail.cs @@ -29,7 +29,7 @@ public unsafe partial class TsavoriteKV(TSessionFunctionsWrapper sessionFunctions, ref 
PendingContext pendingContext, ref TKey key, ref TInput input, ref TValue value, ref TOutput output, TContext userContext, - ref OperationStackContext stackCtx, WriteReason writeReason, bool wantIO = true) + ref OperationStackContext stackCtx, WriteReason writeReason, bool wantIO = true, long maxAddress = long.MaxValue) where TSessionFunctionsWrapper : ISessionFunctionsWrapper { bool callerHasTransientLock = stackCtx.recSrc.HasTransientSLock; @@ -70,7 +70,7 @@ private OperationStatus ConditionalCopyToTail(sessionFunctions, ref key, ref stackCtx2, stackCtx.recSrc.LogicalAddress, minAddress, out status, out needIO)) + if (TryFindRecordInMainLogForConditionalOperation(sessionFunctions, ref key, ref stackCtx2, stackCtx.recSrc.LogicalAddress, minAddress, maxAddress, out status, out needIO)) return OperationStatus.SUCCESS; } while (HandleImmediateNonPendingRetryStatus(status, sessionFunctions)); @@ -84,13 +84,13 @@ private OperationStatus ConditionalCopyToTail(TSessionFunctionsWrapper sessionFunctions, ref TKey key, ref TInput input, ref TValue value, - ref TOutput output, long currentAddress, long minAddress) + ref TOutput output, long currentAddress, long minAddress, long maxAddress = long.MaxValue) where TSessionFunctionsWrapper : ISessionFunctionsWrapper { Debug.Assert(epoch.ThisInstanceProtected(), "This is called only from Compaction so the epoch should be protected"); @@ -101,16 +101,16 @@ internal Status CompactionConditionalCopyToTail(sessionFunctions, ref key, ref stackCtx, currentAddress, minAddress, out status, out needIO)) + if (TryFindRecordInMainLogForConditionalOperation(sessionFunctions, ref key, ref stackCtx, currentAddress, minAddress, maxAddress, out status, out needIO)) return Status.CreateFound(); } while (sessionFunctions.Store.HandleImmediateNonPendingRetryStatus(status, sessionFunctions)); if (needIO) status = PrepareIOForConditionalOperation(sessionFunctions, ref pendingContext, ref key, ref input, ref value, ref output, default, - ref stackCtx, 
minAddress, WriteReason.Compaction); + ref stackCtx, minAddress, maxAddress, WriteReason.Compaction); else - status = ConditionalCopyToTail(sessionFunctions, ref pendingContext, ref key, ref input, ref value, ref output, default, ref stackCtx, WriteReason.Compaction); + status = ConditionalCopyToTail(sessionFunctions, ref pendingContext, ref key, ref input, ref value, ref output, default, ref stackCtx, WriteReason.Compaction, true, maxAddress); return HandleOperationStatus(sessionFunctions.Ctx, ref pendingContext, status, out _); } @@ -118,12 +118,13 @@ internal Status CompactionConditionalCopyToTail(TSessionFunctionsWrapper sessionFunctions, ref PendingContext pendingContext, ref TKey key, ref TInput input, ref TValue value, ref TOutput output, TContext userContext, - ref OperationStackContext stackCtx, long minAddress, WriteReason writeReason, + ref OperationStackContext stackCtx, long minAddress, long maxAddress, WriteReason writeReason, OperationType opType = OperationType.CONDITIONAL_INSERT) where TSessionFunctionsWrapper : ISessionFunctionsWrapper { pendingContext.type = opType; pendingContext.minAddress = minAddress; + pendingContext.maxAddress = maxAddress; pendingContext.writeReason = writeReason; pendingContext.InitialEntryAddress = Constants.kInvalidAddress; pendingContext.InitialLatestLogicalAddress = stackCtx.recSrc.LatestLogicalAddress; diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/ContinuePending.cs b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/ContinuePending.cs index a1c2a547d8..2d0b56fb76 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/ContinuePending.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/ContinuePending.cs @@ -293,14 +293,14 @@ internal OperationStatus ContinuePendingConditionalCopyToTail(sessionFunctions, ref key, ref stackCtx, currentAddress: request.logicalAddress, minAddress, out internalStatus, out bool needIO)) + if 
(TryFindRecordInMainLogForConditionalOperation(sessionFunctions, ref key, ref stackCtx, currentAddress: request.logicalAddress, minAddress, pendingContext.maxAddress, out internalStatus, out bool needIO)) return OperationStatus.SUCCESS; if (!OperationStatusUtils.IsRetry(internalStatus)) { // HeadAddress may have risen above minAddress; if so, we need IO. internalStatus = needIO ? PrepareIOForConditionalOperation(sessionFunctions, ref pendingContext, ref key, ref pendingContext.input.Get(), ref pendingContext.value.Get(), - ref pendingContext.output, pendingContext.userContext, ref stackCtx, minAddress, WriteReason.Compaction) + ref pendingContext.output, pendingContext.userContext, ref stackCtx, minAddress, pendingContext.maxAddress, WriteReason.Compaction) : ConditionalCopyToTail(sessionFunctions, ref pendingContext, ref key, ref pendingContext.input.Get(), ref pendingContext.value.Get(), ref pendingContext.output, pendingContext.userContext, ref stackCtx, pendingContext.writeReason); } @@ -344,7 +344,7 @@ internal OperationStatus ContinuePendingConditionalScanPush(sessionFunctions, pendingContext.scanCursorState, pendingContext.recordInfo, ref pendingContext.key.Get(), ref pendingContext.value.Get(), - currentAddress: request.logicalAddress, minAddress: pendingContext.InitialLatestLogicalAddress + 1); + currentAddress: request.logicalAddress, minAddress: pendingContext.InitialLatestLogicalAddress + 1, maxAddress: pendingContext.maxAddress); // ConditionalScanPush has already called HandleOperationStatus, so return SUCCESS here. 
return OperationStatus.SUCCESS; diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/FindRecord.cs b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/FindRecord.cs index 335197d956..0c5a82c5fb 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/FindRecord.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/FindRecord.cs @@ -48,10 +48,22 @@ internal bool TryFindRecordInMainLog(ref TKey key, ref OperationStackContext stackCtx, long minAddress, long maxAddress) + { + Debug.Assert(!stackCtx.recSrc.HasInMemorySrc, "Should not have found record before this call"); + if (stackCtx.recSrc.LogicalAddress >= minAddress) + { + stackCtx.recSrc.SetPhysicalAddress(); + TraceBackForKeyMatch(ref key, ref stackCtx.recSrc, minAddress, maxAddress); + } + return stackCtx.recSrc.HasInMemorySrc; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] // Return true if the record is found in the log, else false and an indication of whether we need to do IO to continue the search internal bool TryFindRecordInMainLogForConditionalOperation(TSessionFunctionsWrapper sessionFunctions, - ref TKey key, ref OperationStackContext stackCtx, long currentAddress, long minAddress, out OperationStatus internalStatus, out bool needIO) + ref TKey key, ref OperationStackContext stackCtx, long currentAddress, long minAddress, long maxAddress, out OperationStatus internalStatus, out bool needIO) where TSessionFunctionsWrapper : ISessionFunctionsWrapper { if (!FindTag(ref stackCtx.hei)) @@ -94,7 +106,7 @@ internal bool TryFindRecordInMainLogForConditionalOperation= minAddress && stackCtx.recSrc.LogicalAddress < hlogBase.HeadAddress && stackCtx.recSrc.LogicalAddress >= hlogBase.BeginAddress; @@ -130,6 +142,27 @@ private bool TraceBackForKeyMatch(ref TKey key, ref RecordSource recSrc, long minAddress, long maxAddress) + { + // PhysicalAddress must already be populated by callers. 
+ ref var recordInfo = ref recSrc.GetInfo(); + if (IsValidTracebackRecord(recordInfo) && recSrc.LogicalAddress < maxAddress && storeFunctions.KeysEqual(ref key, ref recSrc.GetKey())) + { + recSrc.SetHasMainLogSrc(); + return true; + } + + recSrc.LogicalAddress = recordInfo.PreviousAddress; + if (TraceBackForKeyMatch(ref key, recSrc.LogicalAddress, minAddress, maxAddress, out recSrc.LogicalAddress, out recSrc.PhysicalAddress)) + { + recSrc.SetHasMainLogSrc(); + return true; + } + return false; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private bool TraceBackForKeyMatch(ref TKey key, long fromLogicalAddress, long minAddress, out long foundLogicalAddress, out long foundPhysicalAddress) { @@ -149,6 +182,26 @@ private bool TraceBackForKeyMatch(ref TKey key, long fromLogicalAddress, long mi return false; } + // Overload with maxAddress to avoid the extra condition - TODO: check that this duplication saves on IL/perf + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private bool TraceBackForKeyMatch(ref TKey key, long fromLogicalAddress, long minAddress, long maxAddress, out long foundLogicalAddress, out long foundPhysicalAddress) + { + // This overload is called when the record at the "current" logical address does not match 'key'; fromLogicalAddress is its .PreviousAddress. 
+ foundLogicalAddress = fromLogicalAddress; + while (foundLogicalAddress >= minAddress) + { + foundPhysicalAddress = hlog.GetPhysicalAddress(foundLogicalAddress); + + ref var recordInfo = ref hlog.GetInfo(foundPhysicalAddress); + if (IsValidTracebackRecord(recordInfo) && foundLogicalAddress < maxAddress && storeFunctions.KeysEqual(ref key, ref hlog.GetKey(foundPhysicalAddress))) + return true; + + foundLogicalAddress = recordInfo.PreviousAddress; + } + foundPhysicalAddress = Constants.kInvalidAddress; + return false; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private bool TryFindRecordForUpdate(ref TKey key, ref OperationStackContext stackCtx, long minAddress, out OperationStatus internalStatus) { @@ -204,14 +257,14 @@ private bool TryFindRecordForPendingOperation(ref TKe } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private bool TryFindRecordInMainLogForPendingOperation(ref TKey key, ref OperationStackContext stackCtx, long minAddress, out OperationStatus internalStatus) + private bool TryFindRecordInMainLogForPendingOperation(ref TKey key, ref OperationStackContext stackCtx, long minAddress, long maxAddress, out OperationStatus internalStatus) { // This overload is called when we do not have a PendingContext to get minAddress from, and we've skipped the readcache if present. // This routine returns true if we find the key, else false. 
internalStatus = OperationStatus.SUCCESS; - if (!TryFindRecordInMainLog(ref key, ref stackCtx, minAddress)) + if (!TryFindRecordInMainLog(ref key, ref stackCtx, minAddress, maxAddress)) return false; if (stackCtx.recSrc.GetInfo().IsClosed) internalStatus = OperationStatus.RETRY_LATER; diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Tsavorite.cs b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Tsavorite.cs index 4e106a6116..111cfbc7aa 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Tsavorite.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Tsavorite.cs @@ -188,26 +188,38 @@ public TsavoriteKV(KVSettings kvSettings, TStoreFunctions storeFun /// than current version. Actual new version may have version number greater than supplied number. If the supplied /// number is -1, checkpoint will unconditionally create a new version. /// + /// Iterator for streaming snapshot records /// /// Whether we successfully initiated the checkpoint (initiation may /// fail if we are already taking a checkpoint or performing some other /// operation such as growing the index). Use CompleteCheckpointAsync to wait completion. 
/// - public bool TryInitiateFullCheckpoint(out Guid token, CheckpointType checkpointType, long targetVersion = -1) + public bool TryInitiateFullCheckpoint(out Guid token, CheckpointType checkpointType, long targetVersion = -1, IStreamingSnapshotIteratorFunctions streamingSnapshotIteratorFunctions = null) { - ISynchronizationTask backend; + token = default; + bool result; if (checkpointType == CheckpointType.FoldOver) - backend = new FoldOverCheckpointTask(); + { + var backend = new FoldOverCheckpointTask(); + result = StartStateMachine(new FullCheckpointStateMachine(backend, targetVersion)); + } else if (checkpointType == CheckpointType.Snapshot) - backend = new SnapshotCheckpointTask(); + { + var backend = new SnapshotCheckpointTask(); + result = StartStateMachine(new FullCheckpointStateMachine(backend, targetVersion)); + } + else if (checkpointType == CheckpointType.StreamingSnapshot) + { + if (streamingSnapshotIteratorFunctions is null) + throw new TsavoriteException("StreamingSnapshot checkpoint requires a streaming snapshot iterator"); + this.streamingSnapshotIteratorFunctions = streamingSnapshotIteratorFunctions; + result = StartStateMachine(new StreamingSnapshotCheckpointStateMachine(targetVersion)); + } else throw new TsavoriteException("Unsupported full checkpoint type"); - var result = StartStateMachine(new FullCheckpointStateMachine(backend, targetVersion)); if (result) token = _hybridLogCheckpointToken; - else - token = default; return result; } @@ -221,6 +233,7 @@ public bool TryInitiateFullCheckpoint(out Guid token, CheckpointType checkpointT /// than current version. Actual new version may have version number greater than supplied number. If the supplied /// number is -1, checkpoint will unconditionally create a new version. 
/// + /// Iterator for streaming snapshot records /// /// (bool success, Guid token) /// success: Whether we successfully initiated the checkpoint (initiation may @@ -230,9 +243,9 @@ public bool TryInitiateFullCheckpoint(out Guid token, CheckpointType checkpointT /// Await task to complete checkpoint, if initiated successfully /// public async ValueTask<(bool success, Guid token)> TakeFullCheckpointAsync(CheckpointType checkpointType, - CancellationToken cancellationToken = default, long targetVersion = -1) + CancellationToken cancellationToken = default, long targetVersion = -1, IStreamingSnapshotIteratorFunctions streamingSnapshotIteratorFunctions = null) { - var success = TryInitiateFullCheckpoint(out Guid token, checkpointType, targetVersion); + var success = TryInitiateFullCheckpoint(out Guid token, checkpointType, targetVersion, streamingSnapshotIteratorFunctions); if (success) await CompleteCheckpointAsync(cancellationToken).ConfigureAwait(false); @@ -287,23 +300,36 @@ public bool TryInitiateIndexCheckpoint(out Guid token) /// /// Whether we could initiate the checkpoint. Use CompleteCheckpointAsync to wait completion. 
public bool TryInitiateHybridLogCheckpoint(out Guid token, CheckpointType checkpointType, bool tryIncremental = false, - long targetVersion = -1) + long targetVersion = -1, IStreamingSnapshotIteratorFunctions streamingSnapshotIteratorFunctions = null) { - ISynchronizationTask backend; + token = default; + bool result; if (checkpointType == CheckpointType.FoldOver) - backend = new FoldOverCheckpointTask(); + { + var backend = new FoldOverCheckpointTask(); + result = StartStateMachine(new HybridLogCheckpointStateMachine(backend, targetVersion)); + } else if (checkpointType == CheckpointType.Snapshot) { + ISynchronizationTask backend; if (tryIncremental && _lastSnapshotCheckpoint.info.guid != default && _lastSnapshotCheckpoint.info.finalLogicalAddress > hlogBase.FlushedUntilAddress && !hlog.HasObjectLog) backend = new IncrementalSnapshotCheckpointTask(); else backend = new SnapshotCheckpointTask(); + result = StartStateMachine(new HybridLogCheckpointStateMachine(backend, targetVersion)); + } + else if (checkpointType == CheckpointType.StreamingSnapshot) + { + if (streamingSnapshotIteratorFunctions is null) + throw new TsavoriteException("StreamingSnapshot checkpoint requires a streaming snapshot iterator"); + this.streamingSnapshotIteratorFunctions = streamingSnapshotIteratorFunctions; + result = StartStateMachine(new StreamingSnapshotCheckpointStateMachine(targetVersion)); } else - throw new TsavoriteException("Unsupported checkpoint type"); + throw new TsavoriteException("Unsupported hybrid log checkpoint type"); - var result = StartStateMachine(new HybridLogCheckpointStateMachine(backend, targetVersion)); - token = _hybridLogCheckpointToken; + if (result) + token = _hybridLogCheckpointToken; return result; } diff --git a/libs/storage/Tsavorite/cs/test/LargeObjectTests.cs b/libs/storage/Tsavorite/cs/test/LargeObjectTests.cs index c78b2f6fc4..952b582641 100644 --- a/libs/storage/Tsavorite/cs/test/LargeObjectTests.cs +++ 
b/libs/storage/Tsavorite/cs/test/LargeObjectTests.cs @@ -25,7 +25,9 @@ internal class LargeObjectTests [Test] [Category("TsavoriteKV")] - public async ValueTask LargeObjectTest([Values] CheckpointType checkpointType) + public async ValueTask LargeObjectTest( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType + ) { int maxSize = 100; int numOps = 5000; diff --git a/libs/storage/Tsavorite/cs/test/ObjectRecoveryTest2.cs b/libs/storage/Tsavorite/cs/test/ObjectRecoveryTest2.cs index 9dbdb222e5..70a2c06353 100644 --- a/libs/storage/Tsavorite/cs/test/ObjectRecoveryTest2.cs +++ b/libs/storage/Tsavorite/cs/test/ObjectRecoveryTest2.cs @@ -36,7 +36,7 @@ public void TearDown() [Category("Smoke")] public async ValueTask ObjectRecoveryTest2( - [Values] CheckpointType checkpointType, + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType, [Range(300, 700, 300)] int iterations, [Values] bool isAsync) { diff --git a/libs/storage/Tsavorite/cs/test/ObjectRecoveryTest3.cs b/libs/storage/Tsavorite/cs/test/ObjectRecoveryTest3.cs index 251a5bbc67..ef716e5972 100644 --- a/libs/storage/Tsavorite/cs/test/ObjectRecoveryTest3.cs +++ b/libs/storage/Tsavorite/cs/test/ObjectRecoveryTest3.cs @@ -35,7 +35,7 @@ public void TearDown() [Test] [Category("TsavoriteKV"), Category("CheckpointRestore")] public async ValueTask ObjectRecoveryTest3( - [Values] CheckpointType checkpointType, + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType, [Values(1000)] int iterations, [Values] bool isAsync) { diff --git a/libs/storage/Tsavorite/cs/test/RecoveryChecks.cs b/libs/storage/Tsavorite/cs/test/RecoveryChecks.cs index 6d51fc11df..6bfe90fde8 100644 --- a/libs/storage/Tsavorite/cs/test/RecoveryChecks.cs +++ b/libs/storage/Tsavorite/cs/test/RecoveryChecks.cs @@ -87,7 +87,9 @@ public class RecoveryCheck1Tests : RecoveryCheckBase [Category("CheckpointRestore")] [Category("Smoke")] - public async 
ValueTask RecoveryCheck1([Values] CheckpointType checkpointType, [Values] bool isAsync, [Values] bool useReadCache, [Values(1L << 13, 1L << 16)] long indexSize) + public async ValueTask RecoveryCheck1( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType, + [Values] bool isAsync, [Values] bool useReadCache, [Values(1L << 13, 1L << 16)] long indexSize) { using var store1 = new TsavoriteKV(new() { @@ -184,7 +186,9 @@ public class RecoveryCheck2Tests : RecoveryCheckBase [Test] [Category("TsavoriteKV"), Category("CheckpointRestore")] - public async ValueTask RecoveryCheck2([Values] CheckpointType checkpointType, [Values] bool isAsync, [Values] bool useReadCache, [Values(1L << 13, 1L << 16)] long indexSize) + public async ValueTask RecoveryCheck2( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType, + [Values] bool isAsync, [Values] bool useReadCache, [Values(1L << 13, 1L << 16)] long indexSize) { using var store1 = new TsavoriteKV(new() { @@ -273,7 +277,9 @@ public async ValueTask RecoveryCheck2([Values] CheckpointType checkpointType, [V [Test] [Category("TsavoriteKV"), Category("CheckpointRestore")] - public void RecoveryCheck2Repeated([Values] CheckpointType checkpointType) + public void RecoveryCheck2Repeated( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType + ) { Guid token = default; @@ -326,7 +332,9 @@ public void RecoveryCheck2Repeated([Values] CheckpointType checkpointType) [Test] [Category("TsavoriteKV"), Category("CheckpointRestore")] - public void RecoveryRollback([Values] CheckpointType checkpointType) + public void RecoveryRollback( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType + ) { using var store = new TsavoriteKV(new() { @@ -478,7 +486,9 @@ public class RecoveryCheck3Tests : RecoveryCheckBase [Test] [Category("TsavoriteKV"), Category("CheckpointRestore")] - public async ValueTask 
RecoveryCheck3([Values] CheckpointType checkpointType, [Values] bool isAsync, [Values] bool useReadCache, [Values(1L << 13, 1L << 16)] long indexSize) + public async ValueTask RecoveryCheck3( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType, + [Values] bool isAsync, [Values] bool useReadCache, [Values(1L << 13, 1L << 16)] long indexSize) { using var store1 = new TsavoriteKV(new() { @@ -578,7 +588,9 @@ public class RecoveryCheck4Tests : RecoveryCheckBase [Test] [Category("TsavoriteKV"), Category("CheckpointRestore")] - public async ValueTask RecoveryCheck4([Values] CheckpointType checkpointType, [Values] bool isAsync, [Values] bool useReadCache, [Values(1L << 13, 1L << 16)] long indexSize) + public async ValueTask RecoveryCheck4( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType, + [Values] bool isAsync, [Values] bool useReadCache, [Values(1L << 13, 1L << 16)] long indexSize) { using var store1 = new TsavoriteKV(new() { @@ -682,7 +694,9 @@ public class RecoveryCheck5Tests : RecoveryCheckBase [Test] [Category("TsavoriteKV")] [Category("CheckpointRestore")] - public async ValueTask RecoveryCheck5([Values] CheckpointType checkpointType, [Values] bool isAsync, [Values] bool useReadCache, [Values(1L << 13, 1L << 16)] long indexSize) + public async ValueTask RecoveryCheck5( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType, + [Values] bool isAsync, [Values] bool useReadCache, [Values(1L << 13, 1L << 16)] long indexSize) { using var store1 = new TsavoriteKV(new() { @@ -919,4 +933,153 @@ private async ValueTask IncrSnapshotRecoveryCheck(ICheckpointManager checkpointM _ = bc3.CompletePending(true); } } + + [TestFixture] + public class RecoveryCheckStreamingSnapshotTests : RecoveryCheckBase + { + [SetUp] + public void Setup() => BaseSetup(); + + [TearDown] + public void TearDown() => BaseTearDown(); + + public class SnapshotIterator : 
IStreamingSnapshotIteratorFunctions + { + readonly TsavoriteKV store2; + readonly long expectedCount; + + ClientSession session2; + BasicContext bc2; + + public SnapshotIterator(TsavoriteKV store2, long expectedCount) + { + this.store2 = store2; + this.expectedCount = expectedCount; + } + + public bool OnStart(Guid checkpointToken, long currentVersion, long targetVersion) + { + store2.SetVersion(targetVersion); + session2 = store2.NewSession(new MyFunctions()); + bc2 = session2.BasicContext; + return true; + } + + public bool Reader(ref long key, ref long value, RecordMetadata recordMetadata, long numberOfRecords) + { + _ = bc2.Upsert(ref key, ref value); + return true; + } + + public void OnException(Exception exception, long numberOfRecords) + => Assert.Fail(exception.Message); + + public void OnStop(bool completed, long numberOfRecords) + { + Assert.That(numberOfRecords, Is.EqualTo(expectedCount)); + session2.Dispose(); + } + } + + [Test] + [Category("TsavoriteKV")] + [Category("CheckpointRestore")] + [Category("Smoke")] + + public async ValueTask StreamingSnapshotBasicTest([Values] bool isAsync, [Values] bool useReadCache, [Values] bool reInsert, [Values(1L << 13, 1L << 16)] long indexSize) + { + using var store1 = new TsavoriteKV(new() + { + IndexSize = indexSize, + LogDevice = log, + MutableFraction = 1, + PageSize = 1L << 10, + MemorySize = 1L << 20, + ReadCacheEnabled = useReadCache, + CheckpointDir = TestUtils.MethodTestDir + }, StoreFunctions.Create(LongKeyComparer.Instance) + , (allocatorSettings, storeFunctions) => new(allocatorSettings, storeFunctions) + ); + + using var s1 = store1.NewSession(new MyFunctions()); + var bc1 = s1.BasicContext; + + for (long key = 0; key < (reInsert ? 800 : 1000); key++) + { + // If reInsert, we insert the wrong key during the first pass for the first 500 keys + long value = reInsert && key < 500 ? 
key + 1 : key; + _ = bc1.Upsert(ref key, ref value); + } + + if (reInsert) + { + store1.Log.FlushAndEvict(true); + for (long key = 0; key < 500; key++) + { + _ = bc1.Upsert(ref key, ref key); + } + for (long key = 800; key < 1000; key++) + { + _ = bc1.Upsert(ref key, ref key); + } + } + + if (useReadCache) + { + store1.Log.FlushAndEvict(true); + for (long key = 0; key < 1000; key++) + { + long output = default; + var status = bc1.Read(ref key, ref output); + if (!status.IsPending) + { + ClassicAssert.IsTrue(status.Found, $"status = {status}"); + ClassicAssert.AreEqual(key, output, $"output = {output}"); + } + } + _ = bc1.CompletePending(true); + } + + // First create the new store, we will insert into this store as part of the iterator functions on the old store + using var store2 = new TsavoriteKV(new() + { + IndexSize = indexSize, + LogDevice = log, + MutableFraction = 1, + PageSize = 1L << 10, + MemorySize = 1L << 20, + ReadCacheEnabled = useReadCache, + CheckpointDir = TestUtils.MethodTestDir + }, StoreFunctions.Create(LongKeyComparer.Instance) + , (allocatorSettings, storeFunctions) => new(allocatorSettings, storeFunctions) + ); + + // Take a streaming snapshot checkpoint of the old store + var iterator = new SnapshotIterator(store2, 1000); + var task = store1.TakeFullCheckpointAsync(CheckpointType.StreamingSnapshot, streamingSnapshotIteratorFunctions: iterator); + if (isAsync) + { + var (status, token) = await task; + } + else + { + var (status, token) = task.AsTask().GetAwaiter().GetResult(); + } + + // Verify that the new store has all the records + using var s2 = store2.NewSession(new MyFunctions()); + var bc2 = s2.BasicContext; + for (long key = 0; key < 1000; key++) + { + long output = default; + var status = bc2.Read(ref key, ref output); + if (!status.IsPending) + { + ClassicAssert.IsTrue(status.Found, $"status = {status}"); + ClassicAssert.AreEqual(key, output, $"output = {output}"); + } + } + _ = bc2.CompletePending(true); + } + } } \ No newline at 
end of file diff --git a/libs/storage/Tsavorite/cs/test/SimpleRecoveryTest.cs b/libs/storage/Tsavorite/cs/test/SimpleRecoveryTest.cs index 27c1886f99..0ca58a3bd5 100644 --- a/libs/storage/Tsavorite/cs/test/SimpleRecoveryTest.cs +++ b/libs/storage/Tsavorite/cs/test/SimpleRecoveryTest.cs @@ -59,7 +59,9 @@ public void TearDown() [Test] [Category("TsavoriteKV"), Category("CheckpointRestore")] - public async ValueTask PageBlobSimpleRecoveryTest([Values] CheckpointType checkpointType, [Values] CompletionSyncMode completionSyncMode, [Values] bool testCommitCookie) + public async ValueTask PageBlobSimpleRecoveryTest( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType, + [Values] CompletionSyncMode completionSyncMode, [Values] bool testCommitCookie) { IgnoreIfNotRunningAzureTests(); checkpointManager = new DeviceLogCommitCheckpointManager( @@ -74,7 +76,10 @@ public async ValueTask PageBlobSimpleRecoveryTest([Values] CheckpointType checkp [Category("CheckpointRestore")] [Category("Smoke")] - public async ValueTask LocalDeviceSimpleRecoveryTest([Values] CheckpointType checkpointType, [Values] CompletionSyncMode completionSyncMode, [Values] bool testCommitCookie) + public async ValueTask LocalDeviceSimpleRecoveryTest( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType, + [Values] CompletionSyncMode completionSyncMode, + [Values] bool testCommitCookie) { checkpointManager = new DeviceLogCommitCheckpointManager( new LocalStorageNamedDeviceFactory(), @@ -85,7 +90,10 @@ public async ValueTask LocalDeviceSimpleRecoveryTest([Values] CheckpointType che [Test] [Category("TsavoriteKV"), Category("CheckpointRestore")] - public async ValueTask SimpleRecoveryTest1([Values] CheckpointType checkpointType, [Values] CompletionSyncMode completionSyncMode, [Values] bool testCommitCookie) + public async ValueTask SimpleRecoveryTest1( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType 
checkpointType, + [Values] CompletionSyncMode completionSyncMode, + [Values] bool testCommitCookie) { await SimpleRecoveryTest1_Worker(checkpointType, completionSyncMode, testCommitCookie); } @@ -184,7 +192,9 @@ private async ValueTask SimpleRecoveryTest1_Worker(CheckpointType checkpointType [Test] [Category("TsavoriteKV"), Category("CheckpointRestore")] - public async ValueTask SimpleRecoveryTest2([Values] CheckpointType checkpointType, [Values] CompletionSyncMode completionSyncMode) + public async ValueTask SimpleRecoveryTest2( + [Values(CheckpointType.Snapshot, CheckpointType.FoldOver)] CheckpointType checkpointType, + [Values] CompletionSyncMode completionSyncMode) { checkpointManager = new DeviceLogCommitCheckpointManager(new LocalStorageNamedDeviceFactory(), new DefaultCheckpointNamingScheme(Path.Join(MethodTestDir, "checkpoints4")), false); log = Devices.CreateLogDevice(Path.Join(MethodTestDir, "SimpleRecoveryTest2.log"), deleteOnClose: true); diff --git a/playground/CommandInfoUpdater/SupportedCommand.cs b/playground/CommandInfoUpdater/SupportedCommand.cs index 5e79dbe4d0..c40c688f96 100644 --- a/playground/CommandInfoUpdater/SupportedCommand.cs +++ b/playground/CommandInfoUpdater/SupportedCommand.cs @@ -291,7 +291,12 @@ public class SupportedCommand new("ZSCORE", RespCommand.ZSCORE), new("EVAL", RespCommand.EVAL), new("EVALSHA", RespCommand.EVALSHA), - new("SCRIPT", RespCommand.SCRIPT), + new("SCRIPT", RespCommand.SCRIPT, + [ + new("SCRIPT|EXISTS", RespCommand.SCRIPT_EXISTS), + new("SCRIPT|FLUSH", RespCommand.SCRIPT_FLUSH), + new("SCRIPT|LOAD", RespCommand.SCRIPT_LOAD), + ]) ]; static readonly Lazy> LazySupportedCommandsMap = diff --git a/playground/Embedded.perftest/EmbeddedRespServer.cs b/playground/Embedded.perftest/EmbeddedRespServer.cs index 9a9f41aa49..2506fd9df7 100644 --- a/playground/Embedded.perftest/EmbeddedRespServer.cs +++ b/playground/Embedded.perftest/EmbeddedRespServer.cs @@ -36,7 +36,7 @@ public EmbeddedRespServer(GarnetServerOptions opts, 
ILoggerFactory loggerFactory /// A new RESP server session internal RespServerSession GetRespSession() { - return new RespServerSession(0, new DummyNetworkSender(), storeWrapper, null, null, false); + return new RespServerSession(0, new DummyNetworkSender(), storeWrapper, null, null, true); } } } \ No newline at end of file diff --git a/test/Garnet.test.cluster/RedirectTests/BaseCommand.cs b/test/Garnet.test.cluster/RedirectTests/BaseCommand.cs index e256ede56c..26e41a89ca 100644 --- a/test/Garnet.test.cluster/RedirectTests/BaseCommand.cs +++ b/test/Garnet.test.cluster/RedirectTests/BaseCommand.cs @@ -3,8 +3,11 @@ using System; using System.Collections.Generic; +using System.Linq; +using System.Security.Cryptography; using System.Text; using Garnet.common; +using StackExchange.Redis; namespace Garnet.test.cluster { @@ -91,6 +94,16 @@ public BaseCommand() /// public abstract ArraySegment[] SetupSingleSlotRequest(); + /// + /// Setup before command is run. + /// + /// Each segment is run via SE.Redis's . + /// + /// Runs once per test. 
+ /// + public virtual ArraySegment[] Initialize() + => []; + /// /// Generate a list of keys that hash to a single slot /// @@ -1658,6 +1671,45 @@ public override ArraySegment[] SetupSingleSlotRequest() return setup; } } + + internal class EVALSHA : BaseCommand + { + private const string SCRIPT = "return KEYS[1]"; + + public override bool IsArrayCommand => true; + public override bool ArrayResponse => false; + public override string Command => nameof(EVALSHA); + + private string hash; + + internal EVALSHA() + { + var hashBytes = SHA1.HashData(Encoding.UTF8.GetBytes(SCRIPT)); + hash = string.Join("", hashBytes.Select(static x => x.ToString("X2"))); + } + + public override string[] GetSingleSlotRequest() + { + var ssk = GetSingleSlotKeys; + return [hash, "3", ssk[0], ssk[1], ssk[2]]; + } + + public override string[] GetCrossSlotRequest() + { + var csk = GetCrossSlotKeys; + return [hash, "3", csk[0], csk[1], csk[2]]; + } + + public override ArraySegment[] Initialize() + => [new ArraySegment(["SCRIPT", "LOAD", SCRIPT])]; + + public override ArraySegment[] SetupSingleSlotRequest() + { + var ssk = GetSingleSlotKeys; + var setup = new ArraySegment[] { new ArraySegment(["EVALSHA", hash, "3", ssk[1], ssk[2], ssk[3]]) }; + return setup; + } + } #endregion #region GeoCommands diff --git a/test/Garnet.test.cluster/RedirectTests/ClusterSlotVerificationTests.cs b/test/Garnet.test.cluster/RedirectTests/ClusterSlotVerificationTests.cs index 59c6bb488a..f603176c78 100644 --- a/test/Garnet.test.cluster/RedirectTests/ClusterSlotVerificationTests.cs +++ b/test/Garnet.test.cluster/RedirectTests/ClusterSlotVerificationTests.cs @@ -67,6 +67,7 @@ public class ClusterSlotVerificationTests new SINTER(), new LMOVE(), new EVAL(), + new EVALSHA(), new LPUSH(), new LPOP(), new LMPOP(), @@ -311,12 +312,15 @@ public virtual void OneTimeTearDown() [TestCase("WATCHMS")] [TestCase("WATCHOS")] [TestCase("SINTERCARD")] + [TestCase("EVALSHA")] public void ClusterCLUSTERDOWNTest(string commandName) { 
var requestNodeIndex = otherIndex; var dummyCommand = new DummyCommand(commandName); ClassicAssert.IsTrue(TestCommands.TryGetValue(dummyCommand, out var command), "Command not found"); + Initialize(command); + for (var i = 0; i < iterations; i++) SERedisClusterDown(command); @@ -395,6 +399,7 @@ void GarnetClientSessionClusterDown(BaseCommand command) [TestCase("SINTER")] [TestCase("LMOVE")] [TestCase("EVAL")] + [TestCase("EVALSHA")] [TestCase("LPUSH")] [TestCase("LPOP")] [TestCase("LMPOP")] @@ -460,6 +465,8 @@ public void ClusterOKTest(string commandName) var dummyCommand = new DummyCommand(commandName); ClassicAssert.IsTrue(TestCommands.TryGetValue(dummyCommand, out var command), "Command not found"); + Initialize(command); + for (var i = 0; i < iterations; i++) SERedisOKTest(command); @@ -549,6 +556,7 @@ void GarnetClientSessionOK(BaseCommand command) [TestCase("SINTER")] [TestCase("LMOVE")] [TestCase("EVAL")] + [TestCase("EVALSHA")] [TestCase("LPUSH")] [TestCase("LPOP")] [TestCase("LMPOP")] @@ -614,6 +622,8 @@ public void ClusterCROSSSLOTTest(string commandName) var dummyCommand = new DummyCommand(commandName); ClassicAssert.IsTrue(TestCommands.TryGetValue(dummyCommand, out var command), "Command not found"); + Initialize(command); + for (var i = 0; i < iterations; i++) SERedisCrossslotTest(command); @@ -754,6 +764,7 @@ void GarnetClientSessionCrossslotTest(BaseCommand command) [TestCase("WATCHMS")] [TestCase("WATCHOS")] [TestCase("SINTERCARD")] + [TestCase("EVALSHA")] public void ClusterMOVEDTest(string commandName) { var requestNodeIndex = targetIndex; @@ -762,6 +773,8 @@ public void ClusterMOVEDTest(string commandName) var dummyCommand = new DummyCommand(commandName); ClassicAssert.IsTrue(TestCommands.TryGetValue(dummyCommand, out var command), "Command not found"); + Initialize(command); + for (var i = 0; i < iterations; i++) SERedisMOVEDTest(command); @@ -907,6 +920,7 @@ void GarnetClientSessionMOVEDTest(BaseCommand command) [TestCase("WATCHMS")] 
[TestCase("WATCHOS")] [TestCase("SINTERCARD")] + [TestCase("EVALSHA")] public void ClusterASKTest(string commandName) { var requestNodeIndex = sourceIndex; @@ -914,6 +928,9 @@ public void ClusterASKTest(string commandName) var port = context.clusterTestUtils.GetPortFromNodeIndex(targetIndex); var dummyCommand = new DummyCommand(commandName); ClassicAssert.IsTrue(TestCommands.TryGetValue(dummyCommand, out var command), "Command not found"); + + Initialize(command); + ConfigureSlotForMigration(); try @@ -1082,6 +1099,9 @@ public void ClusterTRYAGAINTest(string commandName) var requestNodeIndex = sourceIndex; var dummyCommand = new DummyCommand(commandName); ClassicAssert.IsTrue(TestCommands.TryGetValue(dummyCommand, out var command), "Command not found"); + + Initialize(command); + for (var i = 0; i < iterations; i++) SERedisTRYAGAINTest(command); @@ -1131,5 +1151,18 @@ void SERedisTRYAGAINTest(BaseCommand command) Assert.Fail($"Should not reach here. Command: {command.Command}"); } } + + private void Initialize(BaseCommand cmd) + { + var server = context.clusterTestUtils.GetServer(sourceIndex); + + foreach (var initCmd in cmd.Initialize()) + { + var c = initCmd.Array[initCmd.Offset]; + var rest = initCmd.Array.AsSpan().Slice(initCmd.Offset + 1, initCmd.Count - 1).ToArray(); + + _ = server.Execute(c, rest); + } + } } } \ No newline at end of file diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index a52dcbe158..3d326f2a34 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. 
using System; +using System.Linq; using System.Text; using System.Threading.Tasks; using NUnit.Framework; @@ -456,5 +457,86 @@ public void ComplexLuaTest3() } } } + + [Test] + public void ScriptExistsErrors() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var exc = ClassicAssert.Throws(() => db.Execute("SCRIPT", "EXISTS")); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'script|exists' command", exc.Message); + } + + [Test] + public void ScriptFlushErrors() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // > 1 args + { + var exc = ClassicAssert.Throws(() => db.Execute("SCRIPT", "FLUSH", "ASYNC", "BAR")); + ClassicAssert.AreEqual("ERR SCRIPT FLUSH only support SYNC|ASYNC option", exc.Message); + } + + // 1 arg, but not ASYNC or SYNC + { + var exc = ClassicAssert.Throws(() => db.Execute("SCRIPT", "FLUSH", "NOW")); + ClassicAssert.AreEqual("ERR SCRIPT FLUSH only support SYNC|ASYNC option", exc.Message); + } + } + + [Test] + public void ScriptLoadErrors() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // 0 args + { + var exc = ClassicAssert.Throws(() => db.Execute("SCRIPT", "LOAD")); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'script|load' command", exc.Message); + } + + // > 1 args + { + var exc = ClassicAssert.Throws(() => db.Execute("SCRIPT", "LOAD", "return 'foo'", "return 'bar'")); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'script|load' command", exc.Message); + } + } + + [Test] + public void ScriptExistsMultiple() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var server = redis.GetServers().First(); + + var hashBytes = server.ScriptLoad("return 'foo'"); + + // upper hash + { + var hash = string.Join("", hashBytes.Select(static x => x.ToString("X2"))); + + var exists = 
(RedisValue[])server.Execute("SCRIPT", "EXISTS", hash, "foo", "bar"); + + ClassicAssert.AreEqual(3, exists.Length); + ClassicAssert.AreEqual(1, (long)exists[0]); + ClassicAssert.AreEqual(0, (long)exists[1]); + ClassicAssert.AreEqual(0, (long)exists[2]); + } + + // lower hash + { + var hash = string.Join("", hashBytes.Select(static x => x.ToString("x2"))); + + var exists = (RedisValue[])server.Execute("SCRIPT", "EXISTS", hash, "foo", "bar"); + + ClassicAssert.AreEqual(3, exists.Length); + ClassicAssert.AreEqual(1, (long)exists[0]); + ClassicAssert.AreEqual(0, (long)exists[1]); + ClassicAssert.AreEqual(0, (long)exists[2]); + } + } } } \ No newline at end of file diff --git a/test/Garnet.test/Resp/ACL/RespCommandTests.cs b/test/Garnet.test/Resp/ACL/RespCommandTests.cs index 63f38dac7c..cda4ccae7e 100644 --- a/test/Garnet.test/Resp/ACL/RespCommandTests.cs +++ b/test/Garnet.test/Resp/ACL/RespCommandTests.cs @@ -83,7 +83,7 @@ public void AllCommandsCovered() ClassicAssert.IsTrue(RespCommandsInfo.TryGetRespCommandNames(out IReadOnlySet advertisedCommands), "Couldn't get advertised RESP commands"); // TODO: See if these commands could be identified programmatically - IEnumerable withOnlySubCommands = ["ACL", "CLIENT", "CLUSTER", "CONFIG", "LATENCY", "MEMORY", "MODULE", "PUBSUB"]; + IEnumerable withOnlySubCommands = ["ACL", "CLIENT", "CLUSTER", "CONFIG", "LATENCY", "MEMORY", "MODULE", "PUBSUB", "SCRIPT"]; IEnumerable notCoveredByACLs = allInfo.Where(static x => x.Value.Flags.HasFlag(RespCommandFlags.NoAuth)).Select(static kv => kv.Key); // Check tests against RespCommandsInfo @@ -2493,21 +2493,71 @@ async Task DoEvalShaAsync(GarnetClient client) } [Test] - public async Task ScriptACLsAsync() + public async Task ScriptLoadACLsAsync() { await CheckCommandsAsync( - "SCRIPT", - [DoScriptAsync], - knownCategories: ["slow"] + "SCRIPT LOAD", + [DoScriptLoadAsync] ); - async Task DoScriptAsync(GarnetClient client) + async Task DoScriptLoadAsync(GarnetClient client) { string res = 
await client.ExecuteForStringResultAsync("SCRIPT", ["LOAD", "return 'OK'"]); ClassicAssert.AreEqual("57ade87c8731f041ecac85aba56623f8af391fab", (string)res); } } + [Test] + public async Task ScriptExistsACLsAsync() + { + await CheckCommandsAsync( + "SCRIPT EXISTS", + [DoScriptExistsSingleAsync, DoScriptExistsMultiAsync] + ); + + async Task DoScriptExistsSingleAsync(GarnetClient client) + { + string[] res = await client.ExecuteForStringArrayResultAsync("SCRIPT", ["EXISTS", "57ade87c8731f041ecac85aba56623f8af391fab"]); + ClassicAssert.AreEqual(1, res.Length); + ClassicAssert.IsTrue(res[0] == "1" || res[0] == "0"); + } + + async Task DoScriptExistsMultiAsync(GarnetClient client) + { + string[] res = await client.ExecuteForStringArrayResultAsync("SCRIPT", ["EXISTS", "57ade87c8731f041ecac85aba56623f8af391fab", "57ade87c8731f041ecac85aba56623f8af391fab"]); + ClassicAssert.AreEqual(2, res.Length); + ClassicAssert.IsTrue(res[0] == "1" || res[0] == "0"); + ClassicAssert.AreEqual(res[0], res[1]); + } + } + + [Test] + public async Task ScriptFlushACLsAsync() + { + await CheckCommandsAsync( + "SCRIPT FLUSH", + [DoScriptFlushAsync, DoScriptFlushSyncAsync, DoScriptFlushAsyncAsync] + ); + + async Task DoScriptFlushAsync(GarnetClient client) + { + string res = await client.ExecuteForStringResultAsync("SCRIPT", ["FLUSH"]); + ClassicAssert.AreEqual("OK", res); + } + + async Task DoScriptFlushSyncAsync(GarnetClient client) + { + string res = await client.ExecuteForStringResultAsync("SCRIPT", ["FLUSH", "SYNC"]); + ClassicAssert.AreEqual("OK", res); + } + + async Task DoScriptFlushAsyncAsync(GarnetClient client) + { + string res = await client.ExecuteForStringResultAsync("SCRIPT", ["FLUSH", "ASYNC"]); + ClassicAssert.AreEqual("OK", res); + } + } + [Test] public async Task DBSizeACLsAsync() { diff --git a/test/Garnet.test/RespCustomCommandTests.cs b/test/Garnet.test/RespCustomCommandTests.cs index 5d238af79d..cef9353517 100644 --- a/test/Garnet.test/RespCustomCommandTests.cs +++ 
b/test/Garnet.test/RespCustomCommandTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Linq; @@ -1088,5 +1089,195 @@ public void CustomProcedureInvokingInvalidCommandTest() var result = db.Execute("PROCINVALIDCMD", "key"); ClassicAssert.AreEqual("OK", (string)result); } + + [Test] + [TestCase(true)] + [TestCase(false)] + public void MultiRegisterCommandTest(bool sync) + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var regCount = 24; + var regCmdTasks = new Task[regCount]; + for (var i = 0; i < regCount; i++) + { + var idx = i; + regCmdTasks[i] = new Task(() => server.Register.NewCommand($"SETIFPM{idx + 1}", CommandType.ReadModifyWrite, new SetIfPMCustomCommand(), + new RespCommandsInfo { Arity = 4 })); + } + + for (var i = 0; i < regCount; i++) + { + if (sync) + { + regCmdTasks[i].RunSynchronously(); + } + else + { + regCmdTasks[i].Start(); + } + } + + if (!sync) Task.WhenAll(regCmdTasks); + + for (var i = 0; i < regCount; i++) + { + var key = $"mykey{i + 1}"; + var origValue = "foovalue0"; + db.StringSet(key, origValue); + + var newValue1 = "foovalue1"; + db.Execute($"SETIFPM{i + 1}", key, newValue1, "foo"); + + // This conditional set should pass (prefix matches) + string retValue = db.StringGet(key); + ClassicAssert.AreEqual(newValue1, retValue); + } + } + + [Test] + [TestCase(true)] + [TestCase(false)] + public void MultiRegisterSubCommandTest(bool sync) + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var factory = new MyDictFactory(); + server.Register.NewCommand("MYDICTGET", CommandType.Read, factory, new MyDictGet(), new RespCommandsInfo { Arity = 3 }); + + // Only able to register 31 sub-commands, try to register 32 + var regCount = 32; + var failedTaskIdAndMessage = new ConcurrentBag<(int, string)>(); 
+ var regCmdTasks = new Task[regCount]; + for (var i = 0; i < regCount; i++) + { + var idx = i; + regCmdTasks[i] = new Task(() => + { + try + { + server.Register.NewCommand($"MYDICTSET{idx + 1}", + CommandType.ReadModifyWrite, factory, new MyDictSet(), new RespCommandsInfo { Arity = 4 }); + } + catch (Exception e) + { + failedTaskIdAndMessage.Add((idx, e.Message)); + } + }); + } + + for (var i = 0; i < regCount; i++) + { + if (sync) + { + regCmdTasks[i].RunSynchronously(); + } + else + { + regCmdTasks[i].Start(); + } + } + + if (!sync) Task.WaitAll(regCmdTasks); + + // Exactly one registration should fail + ClassicAssert.AreEqual(1, failedTaskIdAndMessage.Count); + failedTaskIdAndMessage.TryTake(out var failedTaskResult); + + var failedTaskId = failedTaskResult.Item1; + var failedTaskMessage = failedTaskResult.Item2; + ClassicAssert.AreEqual("Out of registration space", failedTaskMessage); + + var mainkey = "key"; + + // Check that all registrations worked except the failed one + for (var i = 0; i < regCount; i++) + { + if (i == failedTaskId) continue; + var key1 = $"mykey{i + 1}"; + var value1 = $"foovalue{i + 1}"; + db.Execute($"MYDICTSET{i + 1}", mainkey, key1, value1); + + var retValue = db.Execute("MYDICTGET", mainkey, key1); + ClassicAssert.AreEqual(value1, (string)retValue); + } + } + + [Test] + public void MultiRegisterTxnTest() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var regCount = byte.MaxValue + 1; + for (var i = 0; i < regCount; i++) + { + server.Register.NewTransactionProc($"GETTWOKEYSNOTXN{i + 1}", () => new GetTwoKeysNoTxn(), new RespCommandsInfo { Arity = 3 }); + } + + try + { + // This register should fail as there could only be byte.MaxValue + 1 transactions registered + server.Register.NewTransactionProc($"GETTWOKEYSNOTXN{byte.MaxValue + 3}", () => new GetTwoKeysNoTxn(), new RespCommandsInfo { Arity = 3 }); + Assert.Fail(); + } + catch (Exception e) + { + 
ClassicAssert.AreEqual("Out of registration space", e.Message); + } + + for (var i = 0; i < regCount; i++) + { + var readkey1 = $"readkey{i + 1}.1"; + var value1 = $"foovalue{i + 1}.1"; + db.StringSet(readkey1, value1); + + var readkey2 = $"readkey{i + 1}.2"; + var value2 = $"foovalue{i + 1}.2"; + db.StringSet(readkey2, value2); + + var result = db.Execute($"GETTWOKEYSNOTXN{i + 1}", readkey1, readkey2); + + ClassicAssert.AreEqual(value1, ((string[])result)?[0]); + ClassicAssert.AreEqual(value2, ((string[])result)?[1]); + } + } + + [Test] + public void MultiRegisterProcTest() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var regCount = byte.MaxValue + 1; + for (var i = 0; i < regCount; i++) + { + server.Register.NewProcedure($"SUM{i + 1}", () => new Sum()); + } + + try + { + // This register should fail as there could only be byte.MaxValue + 1 procedures registered + server.Register.NewProcedure($"SUM{byte.MaxValue + 3}", () => new Sum()); + Assert.Fail(); + } + catch (Exception e) + { + ClassicAssert.AreEqual("Out of registration space", e.Message); + } + + db.StringSet("key1", "10"); + db.StringSet("key2", "35"); + db.StringSet("key3", "20"); + + for (var i = 0; i < regCount; i++) + { + // Include non-existent and string keys as well + var retValue = db.Execute($"SUM{i + 1}", "key1", "key2", "key3", "key4"); + ClassicAssert.AreEqual("65", retValue.ToString()); + } + } } } \ No newline at end of file