From dd99467e0d5eeab4b296ba79e94e34d8bfb39e89 Mon Sep 17 00:00:00 2001 From: ShikiSuen Date: Sun, 26 Jan 2025 14:40:07 +0800 Subject: [PATCH] Use Dijkstra in lieu of DAG as the default walking algorithm. --- Megrez/src/0_CSharpExtensions.cs | 167 +++++++++++++++++++++++ Megrez/src/1_Compositor.cs | 4 +- Megrez/src/2_Walker.cs | 222 ++++++++++++++----------------- Megrez/src/5_Node.cs | 26 ++-- 4 files changed, 282 insertions(+), 137 deletions(-) diff --git a/Megrez/src/0_CSharpExtensions.cs b/Megrez/src/0_CSharpExtensions.cs index c73d780..1e03d40 100644 --- a/Megrez/src/0_CSharpExtensions.cs +++ b/Megrez/src/0_CSharpExtensions.cs @@ -11,6 +11,7 @@ using System.Collections.Generic; using System.Globalization; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; namespace Megrez { @@ -118,4 +119,170 @@ public EnumeratedItem(int offset, T value) { Value = value; } } + +// MARK: - HybridPriorityQueue + +/// +/// 針對 Sandy Bridge 架構最佳化的混合優先佇列實作。 +/// +public class HybridPriorityQueue + where T : IComparable { + // 考慮 Sandy Bridge 的 L1 快取大小,調整閾值以符合 32KB L1D 快取行為。 + private const int ArrayThreshold = 12; // 增加閾值以更好地利用快取行。 + private const int InitialCapacity = 16; // 預設容量設為 2 的冪次以優化記憶體對齊。 + private T[] _storage; // 改用陣列以減少記憶體間接引用。 + private int _count; // 追蹤實際元素數量。 + private readonly bool _isReversed; + private bool _usingArray; + + public HybridPriorityQueue(bool reversed = false) { + _isReversed = reversed; + _storage = new T[InitialCapacity]; + _count = 0; + _usingArray = true; + } + + /// + /// 取得佇列中的元素數量。 + /// + public int Count => _count; + + /// + /// 檢查佇列是否為空。 + /// + public bool IsEmpty => _count == 0; + + public void Enqueue(T element) { + // 確保容量足夠 + if (_count == _storage.Length) { + Array.Resize(ref _storage, _storage.Length * 2); + } + + if (_usingArray) { + if (_count >= ArrayThreshold) { + SwitchToHeap(); + _storage[_count++] = element; + HeapifyUp(_count - 1); + return; + } + + // 使用二分搜尋找到插入點。 + int insertIndex = FindInsertionPoint(element); + // 手動移動元素以避免使用 Array.Copy(減少函數呼叫開銷)。 + for (int i = _count; i > insertIndex; i--) { + _storage[i] = _storage[i - 1]; + } + _storage[insertIndex] = element; + _count++; + } else { + _storage[_count] = element; + HeapifyUp(_count++); + } + } + + public T? Dequeue() { + if (_count == 0) return default; + + T result = _storage[0]; + _count--; + + if (_usingArray) { + // 手動移動元素以避免使用 Array.Copy。 + for (int i = 0; i < _count; i++) { + _storage[i] = _storage[i + 1]; + } + return result; + } + + // 堆積模式。 + _storage[0] = _storage[_count]; + if (_count > 0) HeapifyDown(0); + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int FindInsertionPoint(T element) { + int left = 0; + int right = _count; + + // 展開循環以提高分支預測效率。 + while (right - left > 1) { + int mid = (left + right) >> 1; + int midStorage = element.CompareTo(_storage[mid]); + if (_isReversed ? midStorage > 0 : midStorage < 0) { + right = mid; + } else { + left = mid; + } + } + + // 處理邊界情況。 + int leftStorage = element.CompareTo(_storage[left]); + bool marginCondition = _isReversed ? leftStorage <= 0 : leftStorage >= 0; + return left < _count && marginCondition ? left + 1 : left; + } + + private void SwitchToHeap() { + _usingArray = false; + // 就地轉換為堆積,使用更有效率的方式。 + for (int i = (_count >> 1) - 1; i >= 0; i--) { + HeapifyDown(i); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void HeapifyUp(int index) { + T item = _storage[index]; + while (index > 0) { + int parentIndex = (index - 1) >> 1; + T parent = _storage[parentIndex]; + if (Compare(item, parent) >= 0) break; + _storage[index] = parent; + index = parentIndex; + } + _storage[index] = item; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void HeapifyDown(int index) { + T item = _storage[index]; + int half = _count >> 1; + + while (index < half) { + int leftChild = (index << 1) + 1; + int rightChild = leftChild + 1; + int bestChild = leftChild; + + T leftChildItem = _storage[leftChild]; + + if (rightChild < _count) { + T rightChildItem = _storage[rightChild]; + if (Compare(rightChildItem, leftChildItem) < 0) { + bestChild = rightChild; + leftChildItem = rightChildItem; + } + } + + if (Compare(item, leftChildItem) <= 0) break; + + _storage[index] = leftChildItem; + index = bestChild; + } + _storage[index] = item; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int Compare(T a, T b) => _isReversed ? b.CompareTo(a) : a.CompareTo(b); + + /// + /// 清空佇列。 + /// + public void Clear() { + _count = 0; + _usingArray = true; + if (_storage.Length > InitialCapacity) { + _storage = new T[InitialCapacity]; + } + } +} } // namespace Megrez diff --git a/Megrez/src/1_Compositor.cs b/Megrez/src/1_Compositor.cs index 30a7d68..35e1a70 100644 --- a/Megrez/src/1_Compositor.cs +++ b/Megrez/src/1_Compositor.cs @@ -427,8 +427,8 @@ private List GetJoinedKeyArray(BRange range) => /// 拿取的節點。拿不到的話就會是 null。 private Node? GetNodeAt(int location, int length, List keyArray) { location = Math.Max(Math.Min(location, Spans.Count - 1), 0); // 防呆。 - if (Spans[location].NodeOf(length) is not {} node) return null; - return (node.KeyArray.SequenceEqual(keyArray)) ? node : null; + return Spans[location].NodeOf(length) is not {} + node ? null : (node.KeyArray.SequenceEqual(keyArray)) ? node : null; } /// diff --git a/Megrez/src/2_Walker.cs b/Megrez/src/2_Walker.cs index ab11b6f..3a03153 100644 --- a/Megrez/src/2_Walker.cs +++ b/Megrez/src/2_Walker.cs @@ -4,151 +4,127 @@ // ==================== // This code is released under the MIT license (SPDX-License-Identifier: MIT) -#nullable enable +using System; using System.Collections.Generic; using System.Linq; namespace Megrez { public partial struct Compositor { /// - /// 爬軌函式,會更新當前組字器的 - /// 找到軌格陣圖內權重最大的路徑。該路徑代表了可被觀測到的最可能的隱藏事件鏈。 - /// 這裡使用 Cormen 在 2001 年出版的教材當中提出的「有向無環圖的最短路徑」的 - /// 算法來計算這種路徑。不過,這裡不是要計算距離最短的路徑,而是計算距離最長 - /// 的路徑(所以要找最大的權重),因為在對數概率下,較大的數值意味著較大的概率。 - /// 對於 G = (V, E),該算法的運行次數為 O(|V|+|E|),其中 G - /// 是一個有向無環圖。這意味著,即使軌格很大,也可以用很少的算力就可以爬軌。 - /// - /// 利用該數學方法進行輸入法智能組句的(已知可考的)最開始的案例是郭家寶(ByVoid) - /// 的《基於統計語言模型的拼音輸入法》; - /// 再後來則是 2022 年中時期劉燈的 Gramambular 2 組字引擎。 - /// + /// 爬軌函式,會以 Dijkstra 算法更新當前組字器的 walkedNodes。 + /// 該算法會在圖中尋找具有最高分數的路徑,即最可能的字詞組合。 + /// 該算法所依賴的 HybridPriorityQueue 針對 Sandy Bridge 經過最佳化處理, + /// 使得該算法在 Sandy Bridge CPU 的電腦上比 DAG 算法擁有更優的效能。 /// - /// 爬軌結果+該過程是否順利執行。 + /// 爬軌結果(已選字詞陣列)。 public List Walk() { - List result = new(); - try { - WalkedNodes.Clear(); - SortAndRelax(); - if (Spans.IsEmpty()) return result; - Node iterated = Node.LeadingNode; - while (iterated.Prev is {} itPrev) { - WalkedNodes.Insert(0, itPrev.Copy()); - iterated = itPrev; + WalkedNodes.Clear(); + if (!Spans.Any()) return new(); + + // 初期化資料結構。 + HybridPriorityQueue openSet = new(reversed: true); + HashSet visited = new(); + Dictionary bestScore = new(); + + // 初期化起始狀態。 + Node leadingNode = new(new() { "$LEADING" }, spanLength: 0, unigrams: new()); + SearchState start = new(node: leadingNode, position: 0, prev: null, distance: 0); + openSet.Enqueue(new(state: start)); + bestScore[0] = 0; + + // 追蹤最佳結果。 + SearchState? bestFinalState = null; + double bestFinalScore = double.MinValue; + + // 主要 Dijkstra 迴圈。 + while (!openSet.IsEmpty) { + if (openSet.Dequeue() is not {} currentPState) break; + + // 如果已經造訪過具有更好分數的狀態,則跳過。 + if (!visited.Add(currentPState.State)) continue; + + // 檢查是否已到達終點。 + if (currentPState.State.Position >= Keys.Count) { + if (currentPState.State.Distance > bestFinalScore) { + bestFinalScore = currentPState.State.Distance; + bestFinalState = currentPState.State; + } + continue; + } + + // 處理下一個可能的節點。 + SpanUnit currentSpan = Spans[currentPState.State.Position]; + foreach (KeyValuePair spanNeta in currentSpan.Nodes) { + int length = spanNeta.Key; + Node nextNode = spanNeta.Value; + int nextPos = currentPState.State.Position + length; + + // 計算新的權重分數。 + double newScore = currentPState.State.Distance + nextNode.Score; + + // 如果該位置已有更優的權重分數,則跳過。 + if (bestScore.TryGetValue(nextPos, out double existingScore) && existingScore >= newScore) continue; + + SearchState nextState = new(node: nextNode, position: nextPos, prev: currentPState.State, distance: newScore); + + bestScore[nextPos] = newScore; + openSet.Enqueue(new(state: nextState)); } - iterated.DestroyVertex(); - WalkedNodes.RemoveAt(0); - return WalkedNodes; - } finally { - ReinitVertexNetwork(); } - } - /// 先進行位相幾何排序、再卸勁。 - internal void SortAndRelax() { - ReinitVertexNetwork(); - List theSpans = Spans; - if (theSpans.IsEmpty()) return; - Node.TrailingNode.Distance = 0; - theSpans.Enumerated().ToList().ForEach(spanNeta => { - int location = spanNeta.Offset; - SpanUnit vertexSpan = spanNeta.Value; - vertexSpan.Nodes.Values.ToList().ForEach(node => { - int nextVertexPosition = location + node.SpanLength; - if (nextVertexPosition == theSpans.Count) { - node.Edges.Add(Node.LeadingNode); - return; - } - theSpans[nextVertexPosition].Nodes.Values.ToList().ForEach(nextVertex => { node.Edges.Add(nextVertex); }); - }); - }); + // 從最佳終止狀態重建路徑。 + if (bestFinalState == null) return new(); - Node.TrailingNode.Edges.AddRange(Spans.First().Nodes.Values); + List pathNodes = new(); + SearchState? currentState = bestFinalState; - TopoSort().Reversed().ForEach(neta => { - neta.Edges.Enumerated().ToList().ForEach(edge => { - if (neta.Edges[edge.Offset] is {} relaxV) Relax(neta, ref relaxV); - }); - }); - } + while (currentState != null) { + // 排除起始和結束的虛擬節點。 + if (!ReferenceEquals(currentState.Node, leadingNode)) { + pathNodes.Insert(0, currentState.Node); + } + currentState = currentState.Prev; + // 備註:此處不需要手動 ASAN,因為沒有參據循環(Retain Cycle)。 + } - /// - /// 摧毀所有與共用起始虛擬節點有牽涉的節點自身的 Vertex 特性資料。 - /// - internal static void ReinitVertexNetwork() { - Node.TrailingNode.DestroyVertex(); - Node.LeadingNode.DestroyVertex(); + WalkedNodes = pathNodes.Select(n => n.Copy()).ToList(); + return WalkedNodes; } - /// - /// 位相幾何排序處理時的處理狀態。 - /// - private class TopoSortState { - public int IterIndex { get; set; } + /// 用於追蹤搜尋過程中的狀態。 + private class SearchState : IEquatable { public Node Node { get; } - public TopoSortState(Node node, int iterIndex = 0) { + public int Position { get; } + public SearchState? Prev { get; } + public double Distance { get; } + + public SearchState(Node node, int position, SearchState? prev, double distance) { Node = node; - IterIndex = iterIndex; + Position = position; + Prev = prev; + Distance = distance; } - } - /// - /// 對持有單個根頂點的有向無環圖進行位相幾何排序(topological - /// sort)、且將排序結果以頂點陣列的形式給出。 - /// 這裡使用我們自己的堆棧和狀態定義實現了一個非遞迴版本, - /// 這樣我們就不會受到當前線程的堆棧大小的限制。以下是等價的原始算法。 - /// - /// void TopoSort(vertex: Vertex) { - /// vertex.Edges.ForEach ((x) => { - /// if (!vertexNode.TopoSorted) { - /// DFS(vertexNode, result); - /// vertexNode.TopoSorted = true; - /// } - /// result.Add(vertexNode); - /// }); - /// } - /// - /// 至於其遞迴版本,則類似於 Cormen 在 2001 年的著作「Introduction to Algorithms」當中的樣子。 - /// - /// 排序結果(頂點陣列)。 - private List TopoSort() { - List result = new(); - List stack = new(); - stack.Add(new(Node.TrailingNode)); - while (!stack.IsEmpty()) { - TopoSortState state = stack.Last(); - Node theNode = state.Node; - if (state.IterIndex < state.Node.Edges.Count) { - Node newNode = state.Node.Edges[state.IterIndex]; - state.IterIndex += 1; - if (!newNode.TopoSorted) { - stack.Add(new(newNode)); - continue; - } - } - theNode.TopoSorted = true; - result.Add(theNode); - stack.Remove(stack.Last()); + public bool Equals(SearchState? other) { + return other != null && ReferenceEquals(Node, other.Node) && Position == other.Position; + } + + public override bool Equals(object? obj) => Equals(obj as SearchState); + + public override int GetHashCode() { + int[] x = { Node.GetHashCode(), Position.GetHashCode() }; + return x.GetHashCode(); } - return result; } - /// - /// 卸勁函式。 - /// 「卸勁 (relax)」一詞出自 Cormen 在 2001 年的著作「Introduction to Algorithms」的 585 頁。 - /// - /// 自己就是參照頂點 (u),會在必要時成為 v (v) 的前述頂點。 - /// 基準頂點。 - /// 要影響的頂點。 - private static void Relax(Node u, ref Node v) { - // 從 u 到 w 的距離,也就是 v 的權重。 - double w = v.Score; - // 這裡計算最大權重: - // 如果 v 目前的距離值小於「u 的距離值+w(w 是 u 到 w 的距離,也就是 v 的權重)」, - // 我們就更新 v 的距離及其前述頂點。 - if (v.Distance >= u.Distance + w) return; - v.Distance = u.Distance + w; - v.Prev = u; + private record PrioritizedState : IComparable { + public SearchState State { get; } + + public PrioritizedState(SearchState state) => State = state; + + public int CompareTo(PrioritizedState? other) { + return other == null ? 1 : State.Distance.CompareTo(other.State.Distance); + } } } -} \ No newline at end of file +} diff --git a/Megrez/src/5_Node.cs b/Megrez/src/5_Node.cs index ed0bfd5..e6f2d35 100644 --- a/Megrez/src/5_Node.cs +++ b/Megrez/src/5_Node.cs @@ -141,10 +141,11 @@ public Node(Node node) { /// /// public override bool Equals(object obj) { - if (obj is not Node node) return false; - return OverridingScore == node.OverridingScore && KeyArray.SequenceEqual(node.KeyArray) && - SpanLength == node.SpanLength && Unigrams == node.Unigrams && - CurrentOverrideType == node.CurrentOverrideType && CurrentUnigramIndex == node.CurrentUnigramIndex; + return obj is not Node node + ? false + : OverridingScore == node.OverridingScore && KeyArray.SequenceEqual(node.KeyArray) && + SpanLength == node.SpanLength && Unigrams == node.Unigrams && + CurrentOverrideType == node.CurrentOverrideType && CurrentUnigramIndex == node.CurrentUnigramIndex; } /// @@ -180,10 +181,10 @@ public override int GetHashCode() { /// public double Score { get { - if (Unigrams.IsEmpty()) return 0; - return CurrentOverrideType switch { OverrideType.HighScore => OverridingScore, - OverrideType.TopUnigramScore => Unigrams.First().Score, - _ => CurrentUnigram.Score }; + return Unigrams.IsEmpty() ? 0 + : CurrentOverrideType switch { OverrideType.HighScore => OverridingScore, + OverrideType.TopUnigramScore => Unigrams.First().Score, + _ => CurrentUnigram.Score }; } } @@ -352,8 +353,8 @@ public static BRange ContextRange(this List self, int givenCursor) { // 下文按道理來講不應該會出現 nilReturn。 if (!dictPair.CursorRegionMap.TryGetValue(cursor, out int rearNodeID)) return nilReturn; if (!dictPair.RegionCursorMap.TryGetValue(rearNodeID, out int rearIndex)) return nilReturn; - if (!dictPair.RegionCursorMap.TryGetValue(rearNodeID + 1, out int frontIndex)) return nilReturn; - return new(rearIndex, frontIndex); + return !dictPair.RegionCursorMap.TryGetValue(rearNodeID + 1, out int frontIndex) ? nilReturn + : new(rearIndex, frontIndex); } /// /// 在陣列內以給定游標位置找出對應的節點。 @@ -368,8 +369,9 @@ public static BRange ContextRange(this List self, int givenCursor) { BRange range = self.ContextRange(givenCursor: cursor); outCursorPastNode = range.Upperbound; CursorMapPair dictPair = self.NodeBorderPointDictPair(); - if (!dictPair.CursorRegionMap.TryGetValue(cursor + 1, out int rearNodeID)) return null; - return self.Count - 1 >= rearNodeID ? self[rearNodeID] : null; + return !dictPair.CursorRegionMap.TryGetValue(cursor + 1, out int rearNodeID) ? null + : self.Count - 1 >= rearNodeID ? self[rearNodeID] + : null; } /// /// 在陣列內以給定游標位置找出對應的節點。