From eb4e319ebf968702ba05d650cac3295c054e3912 Mon Sep 17 00:00:00 2001 From: Patrick Hovsepian Date: Mon, 10 Jun 2024 21:27:33 -0700 Subject: [PATCH 01/11] Mmwip (#2) * seed manageR * model manager init * interface * test * tests --------- Co-authored-by: Pat Hov Co-authored-by: pat_hov --- LLama.Unittest/Constants.cs | 1 + LLama.Unittest/Model/ModelManagerTests.cs | 125 +++++++++ LLama/LLamaWeights.cs | 8 + LLama/Model/ModelManager.cs | 309 ++++++++++++++++++++++ 4 files changed, 443 insertions(+) create mode 100644 LLama.Unittest/Model/ModelManagerTests.cs create mode 100644 LLama/Model/ModelManager.cs diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs index 4852a335e..d344974dc 100644 --- a/LLama.Unittest/Constants.cs +++ b/LLama.Unittest/Constants.cs @@ -4,6 +4,7 @@ namespace LLama.Unittest { internal static class Constants { + public static readonly string ModelDirectory = "Models"; public static readonly string GenerativeModelPath = "Models/llama-2-7b-chat.Q3_K_S.gguf"; public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf"; diff --git a/LLama.Unittest/Model/ModelManagerTests.cs b/LLama.Unittest/Model/ModelManagerTests.cs new file mode 100644 index 000000000..3d80e2832 --- /dev/null +++ b/LLama.Unittest/Model/ModelManagerTests.cs @@ -0,0 +1,125 @@ +using LLama.Common; +using LLama.Model; + +namespace LLama.Unittest; + +public class ModelManagerTests +{ + private readonly ModelManager TestableModelManager; + + public ModelManagerTests() + { + TestableModelManager = new([Constants.ModelDirectory]); + } + + [Fact] + public void ModelDirectories_IsCorrect() + { + var dirs = TestableModelManager.ModelDirectories; + Assert.Single(dirs); + + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + } + + [Fact] + public void AddDirectory_DoesntDuplicate() + { + for (var i = 0; i < 10; i++) + { + TestableModelManager.AddDirectory(Constants.ModelDirectory); + TestableModelManager.AddDirectory(Path.GetFullPath(Constants.ModelDirectory)); + + var dirs = TestableModelManager.ModelDirectories; + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + } + } + + [Fact] + public void RemoveDirectory() + { + var dirs = TestableModelManager.ModelDirectories; + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + + Assert.True(TestableModelManager.RemoveDirectory(Constants.ModelDirectory)); + Assert.Empty(TestableModelManager.ModelDirectories); + Assert.Empty(TestableModelManager.ModelFileList); + } + + [Fact] + public void RemoveDirectory_DoesNotExist() + { + var dirs = TestableModelManager.ModelDirectories; + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + + Assert.False(TestableModelManager.RemoveDirectory("foo/boo/bar")); + Assert.Single(dirs); + } + + [Fact] + public void RemoveAllDirectories() + { + var dirs = TestableModelManager.ModelDirectories; + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + + TestableModelManager.RemoveAllDirectories(); + Assert.Empty(TestableModelManager.ModelDirectories); + Assert.Empty(TestableModelManager.ModelFileList); + } + + [Fact] + public void ModelFiles_IsCorrect() + { + var files = TestableModelManager.ModelFileList; + Assert.Equal(4, files.Count()); + } + + [Fact] + public void GetAvailableModelsFromDirectory() + { + var files = TestableModelManager.GetAvailableModelsFromDirectory(Constants.ModelDirectory); + Assert.Equal(4, files.Count()); + + files = TestableModelManager.ModelFileList; + Assert.Equal(4, files.Count()); + } + + [Fact] + public void TryGetModelFileMetadata_WhenExists() + { + var expectedFile = TestableModelManager.ModelFileList.First(); + var found = TestableModelManager.TryGetModelFileMetadata(expectedFile.FilePath, out var foundData); + + Assert.True(found); + Assert.Equal(expectedFile.FilePath, foundData.FilePath); + } + + [Fact] + public async void LoadModel_LoadsAndCaches() + { + var modelToLoad = TestableModelManager.ModelFileList + .First(f => f.FileName.Contains("llama-2-7b")); + + var model = await TestableModelManager.LoadModel(modelToLoad.FilePath, null!); + + Assert.Single(TestableModelManager.GetLoadedModels()); + + var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); + Assert.True(isLoaded); + + // unload + Assert.True(TestableModelManager.UnloadModel(model.ModelName)); + + Assert.Throws(() => { + _ = model.CreateContext(new ModelParams(modelToLoad.FilePath)); + }); + } +} diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 8646e4d93..809ca1208 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -22,6 +22,14 @@ public sealed class LLamaWeights /// Be careful how you use this! public SafeLlamaModelHandle NativeHandle { get; } + /// + /// The models name as specified in it's metadata + /// + /// + public string ModelName => Metadata.TryGetValue("general.name", out var name) + ? name + : string.Empty; + /// /// Total number of tokens in vocabulary of this model /// diff --git a/LLama/Model/ModelManager.cs b/LLama/Model/ModelManager.cs new file mode 100644 index 000000000..a0b19c388 --- /dev/null +++ b/LLama/Model/ModelManager.cs @@ -0,0 +1,309 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using LLama.Common; + +namespace LLama.Model; + +/// +/// Types of supported model files +/// +public enum ModelFileType +{ + GGUF +} + +/// +/// Metadata about available models +/// +public class ModelFileMetadata +{ +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member + public string FileName { get; init; } = string.Empty; + public string FilePath { get; init; } = string.Empty; + public ModelFileType ModelType { get; init; } + public long SizeInBytes { get; init; } = 0; +#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member +} + +/// +/// A class that helps organize and load local models +/// +public interface IModelManager +{ + // Model Directories + /// + /// Configured set of directories that are scanned to find local models + /// + /// + public IEnumerable ModelDirectories { get; } + + /// + /// Add a directory containing model files + /// + /// + public void AddDirectory(string directory); + + /// + /// Remove a directory from being scanned and having model files made available + /// + /// + /// + public bool RemoveDirectory(string directory); + + /// + /// Remove all model directories + /// + public void RemoveAllDirectories(); + + // Model Files + /// + /// Get all of the model files that are available to be loaded + /// + /// + public IEnumerable ModelFileList { get; } + + /// + /// Only get the models associated with a specific directory + /// + /// + /// The files, if any associated with a given directory + public IEnumerable GetAvailableModelsFromDirectory(string directory); + + /// + /// Get the file data for given model + /// + /// + /// + /// If a model with the given file name is present + public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta); + + // Model Load and Unload + /// + /// Load a model file to be used for infernce + /// + /// + /// + /// + /// + /// The loaded model on success + public Task LoadModel(string modelPath, + Action? modelConfigurator = null!, + string modelId = "", + CancellationToken cancellationToken = default); + + /// + /// Unload and dispose of a model with the given id + /// + /// + /// + public bool UnloadModel(string modelId); + + /// + /// Unload all currently loaded models + /// + public void UnloadAllModels(); + + /// + /// Attempt to get a model that's expected to be loaded + /// + /// + /// + /// + public bool TryGetLoadedModel(string modeId, out LLamaWeights model); + + /// + /// Currently loaded models + /// + /// + public IEnumerable GetLoadedModels(); +} + +/// +public class ModelManager : IModelManager +{ + /// + /// Support model type files + /// + public static string[] ExpectedModelFileTypes = [ + ".gguf" + ]; + + // keys are directories, values are applicable models + private readonly Dictionary> _availableModels = []; + + // model id/alias, to loaded model + private readonly Dictionary _loadedModelCache = []; + + /// + /// Create a new model manager that seeds available models from the given directory list + /// + /// + public ModelManager(string[] directories) + { + GetModelsFromDirectories(directories); + } + + private void GetModelsFromDirectories(params string[] dirs) + { + foreach (var dir in dirs) + { + var fullDirectoryPath = Path.GetFullPath(dir); + + if (!Directory.Exists(fullDirectoryPath)) + { + Trace.TraceError($"Model directory '{fullDirectoryPath}' does not exist"); + continue; + } + + if (_availableModels.ContainsKey(fullDirectoryPath)) + { + Trace.TraceWarning($"Model directory '{fullDirectoryPath}' already probed"); + continue; + } + + // find models in current dir that are of expected type + List directoryModelFiles = []; + foreach (var file in Directory.EnumerateFiles(fullDirectoryPath)) + { + if (!ExpectedModelFileTypes.Contains(Path.GetExtension(file))) + { + continue; + } + + // expected model file + var fi = new FileInfo(file); + directoryModelFiles.Add(new ModelFileMetadata + { + FileName = fi.Name, + FilePath = fi.FullName, + ModelType = ModelFileType.GGUF, + SizeInBytes = fi.Length, + }); + } + + _availableModels.Add(fullDirectoryPath, directoryModelFiles); + } + } + + /// + public IEnumerable ModelFileList + => _availableModels.SelectMany(x => x.Value); + /// + public IEnumerable ModelDirectories + => _availableModels.Keys; + + /// + public void AddDirectory(string directory) + { + GetModelsFromDirectories(directory); + } + + /// + public bool RemoveDirectory(string directory) + { + return _availableModels.Remove(Path.GetFullPath(directory)); + } + + /// + public void RemoveAllDirectories() + { + _availableModels.Clear(); + } + + /// + public IEnumerable GetAvailableModelsFromDirectory(string directory) + { + var dirPath = Path.GetFullPath(directory); + return _availableModels.TryGetValue(dirPath, out var dirModels) + ? dirModels + : []; + } + + /// + public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta) + { + var filePath = Path.GetFullPath(fileName); + modelMeta = ModelFileList.FirstOrDefault(f => f.FilePath == filePath)!; + return modelMeta != null; + } + + /// + public IEnumerable GetLoadedModels() + { + return _loadedModelCache.Values; + } + + /// + public bool TryGetLoadedModel(string modelId, out LLamaWeights model) + { + return _loadedModelCache.TryGetValue(modelId, out model!); + } + + /// + public async Task LoadModel(string modelPath, + Action? modelConfigurator = null!, + string modelId = "", + CancellationToken cancellationToken = default) + { + // Configure model params + var modelParams = new ModelParams(modelPath); + modelConfigurator ??= DefaultModelConfigurator; + modelConfigurator.Invoke(modelParams); + + // load and cache + var model = await LLamaWeights.LoadFromFileAsync(modelParams, cancellationToken); + if (string.IsNullOrWhiteSpace(modelId)) + { + modelId = model.ModelName; + } + _loadedModelCache.Add(modelId, model); + return model; + } + + /// + /// Updates the passed in model params with default values + /// + /// + public static void DefaultModelConfigurator(ModelParams modelParams) + { + // Reasonable defaults + modelParams.GpuLayerCount = 12; + modelParams.Seed = 1337u; + modelParams.ContextSize = 2048; + modelParams.Encoding = Encoding.UTF8; + + modelParams.UseMemoryLock = true; + modelParams.UseMemorymap = true; + + // auto detect + modelParams.Threads = null!; + modelParams.BatchThreads = null!; + } + + /// + public bool UnloadModel(string modelId) + { + if (TryGetLoadedModel(modelId, out var model)) + { + model.Dispose(); + return _loadedModelCache.Remove(modelId); + } + return false; + } + + /// + public void UnloadAllModels() + { + foreach (var model in _loadedModelCache.Values) + { + model.Dispose(); + } + _loadedModelCache.Clear(); + } +} From 4b5a9661d2d9d5ff898cb1b8ce57bcb18879c5c0 Mon Sep 17 00:00:00 2001 From: Patrick Hovsepian Date: Tue, 11 Jun 2024 19:25:35 -0700 Subject: [PATCH 02/11] Pr feedback (#3) * seed manageR * model manager init * interface * test * tests * no default configurator * Rename class * handle already disposed --------- Co-authored-by: Pat Hov Co-authored-by: pat_hov --- LLama.Unittest/Model/ModelCacheTests.cs | 168 +++++++++++++ LLama/Model/ModelCache.cs | 313 ++++++++++++++++++++++++ 2 files changed, 481 insertions(+) create mode 100644 LLama.Unittest/Model/ModelCacheTests.cs create mode 100644 LLama/Model/ModelCache.cs diff --git a/LLama.Unittest/Model/ModelCacheTests.cs b/LLama.Unittest/Model/ModelCacheTests.cs new file mode 100644 index 000000000..936eff1dd --- /dev/null +++ b/LLama.Unittest/Model/ModelCacheTests.cs @@ -0,0 +1,168 @@ +using LLama.Common; +using LLama.Model; + +namespace LLama.Unittest; + +public class ModelManagerTests +{ + private readonly ModelCache TestableModelManager; + + public ModelManagerTests() + { + TestableModelManager = new([Constants.ModelDirectory]); + } + + [Fact] + public void ModelDirectories_IsCorrect() + { + var dirs = TestableModelManager.ModelDirectories; + Assert.Single(dirs); + + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + } + + [Fact] + public void AddDirectory_DoesntDuplicate() + { + for (var i = 0; i < 10; i++) + { + TestableModelManager.AddDirectory(Constants.ModelDirectory); + TestableModelManager.AddDirectory(Path.GetFullPath(Constants.ModelDirectory)); + + var dirs = TestableModelManager.ModelDirectories; + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + } + } + + [Fact] + public void RemoveDirectory() + { + var dirs = TestableModelManager.ModelDirectories; + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + + Assert.True(TestableModelManager.RemoveDirectory(Constants.ModelDirectory)); + Assert.Empty(TestableModelManager.ModelDirectories); + Assert.Empty(TestableModelManager.ModelFileList); + } + + [Fact] + public void RemoveDirectory_DoesNotExist() + { + var dirs = TestableModelManager.ModelDirectories; + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + + Assert.False(TestableModelManager.RemoveDirectory("foo/boo/bar")); + Assert.Single(dirs); + } + + [Fact] + public void RemoveAllDirectories() + { + var dirs = TestableModelManager.ModelDirectories; + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + + TestableModelManager.RemoveAllDirectories(); + Assert.Empty(TestableModelManager.ModelDirectories); + Assert.Empty(TestableModelManager.ModelFileList); + } + + [Fact] + public void ModelFiles_IsCorrect() + { + var files = TestableModelManager.ModelFileList; + Assert.Equal(4, files.Count()); + } + + [Fact] + public void GetAvailableModelsFromDirectory() + { + var files = TestableModelManager.GetAvailableModelsFromDirectory(Constants.ModelDirectory); + Assert.Equal(4, files.Count()); + + files = TestableModelManager.ModelFileList; + Assert.Equal(4, files.Count()); + } + + [Fact] + public void TryGetModelFileMetadata_WhenExists() + { + var expectedFile = TestableModelManager.ModelFileList.First(); + var found = TestableModelManager.TryGetModelFileMetadata(expectedFile.FilePath, out var foundData); + + Assert.True(found); + Assert.Equal(expectedFile.FilePath, foundData.FilePath); + } + + [Fact] + public async void LoadModel_LoadsAndCaches() + { + var modelToLoad = TestableModelManager.ModelFileList + .First(f => f.FileName.Contains("llama-2-7b")); + + var model = await TestableModelManager.LoadModel(modelToLoad.FilePath, null!); + + Assert.Single(TestableModelManager.GetLoadedModels()); + + var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); + Assert.True(isLoaded); + + // unload + Assert.True(TestableModelManager.UnloadModel(model.ModelName)); + + Assert.Throws(() => + { + _ = model.CreateContext(new ModelParams(modelToLoad.FilePath)); + }); + } + + [Fact] + public async void LoadModel_AlreadyLoaded_ReturnsFromCache() + { + var modelToLoad = TestableModelManager.ModelFileList + .First(f => f.FileName.Contains("llama-2-7b")); + + for (var i = 0; i < 5; i++) + { + var model = await TestableModelManager.LoadModel(modelToLoad.FilePath); + Assert.NotNull(model); + Assert.Equal("LLaMA v2", model.ModelName); + Assert.Single(TestableModelManager.GetLoadedModels()); + var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); + Assert.True(isLoaded); + Assert.NotNull(cachedModel); + Assert.Equal("LLaMA v2", cachedModel.ModelName); + } + } + + [Fact] + public async void TryGetLoadedModel_AlreadyDisposed_ReturnsFalse() + { + var modelToLoad = TestableModelManager.ModelFileList + .First(f => f.FileName.Contains("llama-2-7b")); + + using (var model = await TestableModelManager.LoadModel(modelToLoad.FilePath)) + { + Assert.NotNull(model); + Assert.Equal("LLaMA v2", model.ModelName); + Assert.Single(TestableModelManager.GetLoadedModels()); + var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); + Assert.True(isLoaded); + Assert.NotNull(cachedModel); + Assert.Equal("LLaMA v2", cachedModel.ModelName); + } // end scope, dispose model + + // Model is now disposed + var isDispoedLoaded = TestableModelManager.TryGetLoadedModel("LLaMA v2", out var disposedModel); + Assert.False(isDispoedLoaded); + Assert.Null(disposedModel); + } +} diff --git a/LLama/Model/ModelCache.cs b/LLama/Model/ModelCache.cs new file mode 100644 index 000000000..5cdcdde1c --- /dev/null +++ b/LLama/Model/ModelCache.cs @@ -0,0 +1,313 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using LLama.Common; + +namespace LLama.Model; + +/// +/// Types of supported model files +/// +public enum ModelFileType +{ + GGUF +} + +/// +/// Metadata about available models +/// +public class ModelFileMetadata +{ +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member + public string FileName { get; init; } = string.Empty; + public string FilePath { get; init; } = string.Empty; + public ModelFileType ModelType { get; init; } + public long SizeInBytes { get; init; } = 0; +#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member +} + +/// +/// A class that helps organize and load local models +/// +public interface IModelCache +{ + // Model Directories + /// + /// Configured set of directories that are scanned to find local models + /// + /// + public IEnumerable ModelDirectories { get; } + + /// + /// Add a directory containing model files + /// + /// + public void AddDirectory(string directory); + + /// + /// Remove a directory from being scanned and having model files made available + /// + /// + /// + public bool RemoveDirectory(string directory); + + /// + /// Remove all model directories + /// + public void RemoveAllDirectories(); + + // Model Files + /// + /// Get all of the model files that are available to be loaded + /// + /// + public IEnumerable ModelFileList { get; } + + /// + /// Only get the models associated with a specific directory + /// + /// + /// The files, if any associated with a given directory + public IEnumerable GetAvailableModelsFromDirectory(string directory); + + /// + /// Get the file data for given model + /// + /// + /// + /// If a model with the given file name is present + public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta); + + // Model Load and Unload + /// + /// Load a model file to be used for infernce + /// + /// + /// + /// + /// + /// The loaded model on success + public Task LoadModel(string modelPath, + Action? modelConfigurator = null!, + string modelId = "", + CancellationToken cancellationToken = default); + + /// + /// Unload and dispose of a model with the given id + /// + /// + /// + public bool UnloadModel(string modelId); + + /// + /// Unload all currently loaded models + /// + public void UnloadAllModels(); + + /// + /// Attempt to get a model that's expected to be loaded + /// + /// + /// + /// + public bool TryGetLoadedModel(string modeId, out LLamaWeights model); + + /// + /// Currently loaded models + /// + /// + public IEnumerable GetLoadedModels(); +} + +/// +public class ModelCache : IModelCache +{ + /// + /// Support model type files + /// + public static readonly string[] ExpectedModelFileTypes = [ + ".gguf" + ]; + + // keys are directories, values are applicable models + private readonly Dictionary> _availableModels = []; + + // model id/alias, to loaded model + private readonly Dictionary _loadedModelCache = []; + + /// + /// Create a new model manager that seeds available models from the given directory list + /// + /// + public ModelCache(string[] directories) + { + GetModelsFromDirectories(directories); + } + + private void GetModelsFromDirectories(params string[] dirs) + { + foreach (var dir in dirs) + { + var fullDirectoryPath = Path.GetFullPath(dir); + + if (!Directory.Exists(fullDirectoryPath)) + { + Trace.TraceError($"Model directory '{fullDirectoryPath}' does not exist"); + continue; + } + + if (_availableModels.ContainsKey(fullDirectoryPath)) + { + Trace.TraceWarning($"Model directory '{fullDirectoryPath}' already probed"); + continue; + } + + // find models in current dir that are of expected type + List directoryModelFiles = []; + foreach (var file in Directory.EnumerateFiles(fullDirectoryPath)) + { + if (!ExpectedModelFileTypes.Contains(Path.GetExtension(file))) + { + continue; + } + + // expected model file + var fi = new FileInfo(file); + directoryModelFiles.Add(new ModelFileMetadata + { + FileName = fi.Name, + FilePath = fi.FullName, + ModelType = ModelFileType.GGUF, + SizeInBytes = fi.Length, + }); + } + + _availableModels.Add(fullDirectoryPath, directoryModelFiles); + } + } + + /// + public IEnumerable ModelFileList + => _availableModels.SelectMany(x => x.Value); + /// + public IEnumerable ModelDirectories + => _availableModels.Keys; + + /// + public void AddDirectory(string directory) + { + GetModelsFromDirectories(directory); + } + + /// + public bool RemoveDirectory(string directory) + { + return _availableModels.Remove(Path.GetFullPath(directory)); + } + + /// + public void RemoveAllDirectories() + { + _availableModels.Clear(); + } + + /// + public IEnumerable GetAvailableModelsFromDirectory(string directory) + { + var dirPath = Path.GetFullPath(directory); + return _availableModels.TryGetValue(dirPath, out var dirModels) + ? dirModels + : []; + } + + /// + public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta) + { + var filePath = Path.GetFullPath(fileName); + modelMeta = ModelFileList.FirstOrDefault(f => f.FilePath == filePath)!; + return modelMeta != null; + } + + /// + public IEnumerable GetLoadedModels() + { + return _loadedModelCache.Values; + } + + /// + public bool TryGetLoadedModel(string modelId, out LLamaWeights model) + { + var isCached = _loadedModelCache.TryGetValue(modelId, out model!); + + // Externall disposed, act like it's not in here + if (isCached && model.NativeHandle.IsClosed) + { + _ = _loadedModelCache.Remove(modelId); + isCached = false; + model = null!; + } + + return isCached; + } + + /// + public async Task LoadModel(string modelPath, + Action? modelConfigurator = null!, + string modelId = "", + CancellationToken cancellationToken = default) + { + // is the model already loaded? alias could be different but it's up to the caller to be consistent + if (!string.IsNullOrEmpty(modelId) + && TryGetLoadedModel(modelId, out var loadedModel)) + { + Trace.TraceWarning($"Model {modelId} already loaded"); + return loadedModel; + } + + // Configure model params + var modelParams = new ModelParams(modelPath); + modelConfigurator?.Invoke(modelParams); + + // load and cache + var model = await LLamaWeights.LoadFromFileAsync(modelParams, cancellationToken); + if (string.IsNullOrWhiteSpace(modelId)) + { + modelId = model.ModelName; + + // Check if cached again with alias + // TODO: Consider the case where the alias is different but the underlying model file is the same + if (TryGetLoadedModel(modelId, out loadedModel)) + { + model.Dispose(); + return loadedModel; + } + } + _loadedModelCache.Add(modelId, model); + return model; + } + + /// + public bool UnloadModel(string modelId) + { + if (TryGetLoadedModel(modelId, out var model)) + { + model.Dispose(); + return _loadedModelCache.Remove(modelId); + } + return false; + } + + /// + public void UnloadAllModels() + { + foreach (var model in _loadedModelCache.Values) + { + model.Dispose(); + } + _loadedModelCache.Clear(); + } +} From fb46f57f149d13edcaa91c68621224dbf033fd2a Mon Sep 17 00:00:00 2001 From: pat_hov Date: Tue, 11 Jun 2024 19:34:49 -0700 Subject: [PATCH 03/11] merge fix --- LLama.Unittest/Model/ModelManagerTests.cs | 125 --------- LLama/Model/ModelManager.cs | 309 ---------------------- 2 files changed, 434 deletions(-) delete mode 100644 LLama.Unittest/Model/ModelManagerTests.cs delete mode 100644 LLama/Model/ModelManager.cs diff --git a/LLama.Unittest/Model/ModelManagerTests.cs b/LLama.Unittest/Model/ModelManagerTests.cs deleted file mode 100644 index 3d80e2832..000000000 --- a/LLama.Unittest/Model/ModelManagerTests.cs +++ /dev/null @@ -1,125 +0,0 @@ -using LLama.Common; -using LLama.Model; - -namespace LLama.Unittest; - -public class ModelManagerTests -{ - private readonly ModelManager TestableModelManager; - - public ModelManagerTests() - { - TestableModelManager = new([Constants.ModelDirectory]); - } - - [Fact] - public void ModelDirectories_IsCorrect() - { - var dirs = TestableModelManager.ModelDirectories; - Assert.Single(dirs); - - var expected = dirs.First()!.Contains(Constants.ModelDirectory); - Assert.True(expected); - } - - [Fact] - public void AddDirectory_DoesntDuplicate() - { - for (var i = 0; i < 10; i++) - { - TestableModelManager.AddDirectory(Constants.ModelDirectory); - TestableModelManager.AddDirectory(Path.GetFullPath(Constants.ModelDirectory)); - - var dirs = TestableModelManager.ModelDirectories; - Assert.Single(dirs); - var expected = dirs.First()!.Contains(Constants.ModelDirectory); - Assert.True(expected); - } - } - - [Fact] - public void RemoveDirectory() - { - var dirs = TestableModelManager.ModelDirectories; - Assert.Single(dirs); - var expected = dirs.First()!.Contains(Constants.ModelDirectory); - Assert.True(expected); - - Assert.True(TestableModelManager.RemoveDirectory(Constants.ModelDirectory)); - Assert.Empty(TestableModelManager.ModelDirectories); - Assert.Empty(TestableModelManager.ModelFileList); - } - - [Fact] - public void RemoveDirectory_DoesNotExist() - { - var dirs = TestableModelManager.ModelDirectories; - Assert.Single(dirs); - var expected = dirs.First()!.Contains(Constants.ModelDirectory); - Assert.True(expected); - - Assert.False(TestableModelManager.RemoveDirectory("foo/boo/bar")); - Assert.Single(dirs); - } - - [Fact] - public void RemoveAllDirectories() - { - var dirs = TestableModelManager.ModelDirectories; - Assert.Single(dirs); - var expected = dirs.First()!.Contains(Constants.ModelDirectory); - Assert.True(expected); - - TestableModelManager.RemoveAllDirectories(); - Assert.Empty(TestableModelManager.ModelDirectories); - Assert.Empty(TestableModelManager.ModelFileList); - } - - [Fact] - public void ModelFiles_IsCorrect() - { - var files = TestableModelManager.ModelFileList; - Assert.Equal(4, files.Count()); - } - - [Fact] - public void GetAvailableModelsFromDirectory() - { - var files = TestableModelManager.GetAvailableModelsFromDirectory(Constants.ModelDirectory); - Assert.Equal(4, files.Count()); - - files = TestableModelManager.ModelFileList; - Assert.Equal(4, files.Count()); - } - - [Fact] - public void TryGetModelFileMetadata_WhenExists() - { - var expectedFile = TestableModelManager.ModelFileList.First(); - var found = TestableModelManager.TryGetModelFileMetadata(expectedFile.FilePath, out var foundData); - - Assert.True(found); - Assert.Equal(expectedFile.FilePath, foundData.FilePath); - } - - [Fact] - public async void LoadModel_LoadsAndCaches() - { - var modelToLoad = TestableModelManager.ModelFileList - .First(f => f.FileName.Contains("llama-2-7b")); - - var model = await TestableModelManager.LoadModel(modelToLoad.FilePath, null!); - - Assert.Single(TestableModelManager.GetLoadedModels()); - - var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); - Assert.True(isLoaded); - - // unload - Assert.True(TestableModelManager.UnloadModel(model.ModelName)); - - Assert.Throws(() => { - _ = model.CreateContext(new ModelParams(modelToLoad.FilePath)); - }); - } -} diff --git a/LLama/Model/ModelManager.cs b/LLama/Model/ModelManager.cs deleted file mode 100644 index a0b19c388..000000000 --- a/LLama/Model/ModelManager.cs +++ /dev/null @@ -1,309 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using LLama.Common; - -namespace LLama.Model; - -/// -/// Types of supported model files -/// -public enum ModelFileType -{ - GGUF -} - -/// -/// Metadata about available models -/// -public class ModelFileMetadata -{ -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public string FileName { get; init; } = string.Empty; - public string FilePath { get; init; } = string.Empty; - public ModelFileType ModelType { get; init; } - public long SizeInBytes { get; init; } = 0; -#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member -} - -/// -/// A class that helps organize and load local models -/// -public interface IModelManager -{ - // Model Directories - /// - /// Configured set of directories that are scanned to find local models - /// - /// - public IEnumerable ModelDirectories { get; } - - /// - /// Add a directory containing model files - /// - /// - public void AddDirectory(string directory); - - /// - /// Remove a directory from being scanned and having model files made available - /// - /// - /// - public bool RemoveDirectory(string directory); - - /// - /// Remove all model directories - /// - public void RemoveAllDirectories(); - - // Model Files - /// - /// Get all of the model files that are available to be loaded - /// - /// - public IEnumerable ModelFileList { get; } - - /// - /// Only get the models associated with a specific directory - /// - /// - /// The files, if any associated with a given directory - public IEnumerable GetAvailableModelsFromDirectory(string directory); - - /// - /// Get the file data for given model - /// - /// - /// - /// If a model with the given file name is present - public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta); - - // Model Load and Unload - /// - /// Load a model file to be used for infernce - /// - /// - /// - /// - /// - /// The loaded model on success - public Task LoadModel(string modelPath, - Action? modelConfigurator = null!, - string modelId = "", - CancellationToken cancellationToken = default); - - /// - /// Unload and dispose of a model with the given id - /// - /// - /// - public bool UnloadModel(string modelId); - - /// - /// Unload all currently loaded models - /// - public void UnloadAllModels(); - - /// - /// Attempt to get a model that's expected to be loaded - /// - /// - /// - /// - public bool TryGetLoadedModel(string modeId, out LLamaWeights model); - - /// - /// Currently loaded models - /// - /// - public IEnumerable GetLoadedModels(); -} - -/// -public class ModelManager : IModelManager -{ - /// - /// Support model type files - /// - public static string[] ExpectedModelFileTypes = [ - ".gguf" - ]; - - // keys are directories, values are applicable models - private readonly Dictionary> _availableModels = []; - - // model id/alias, to loaded model - private readonly Dictionary _loadedModelCache = []; - - /// - /// Create a new model manager that seeds available models from the given directory list - /// - /// - public ModelManager(string[] directories) - { - GetModelsFromDirectories(directories); - } - - private void GetModelsFromDirectories(params string[] dirs) - { - foreach (var dir in dirs) - { - var fullDirectoryPath = Path.GetFullPath(dir); - - if (!Directory.Exists(fullDirectoryPath)) - { - Trace.TraceError($"Model directory '{fullDirectoryPath}' does not exist"); - continue; - } - - if (_availableModels.ContainsKey(fullDirectoryPath)) - { - Trace.TraceWarning($"Model directory '{fullDirectoryPath}' already probed"); - continue; - } - - // find models in current dir that are of expected type - List directoryModelFiles = []; - foreach (var file in Directory.EnumerateFiles(fullDirectoryPath)) - { - if (!ExpectedModelFileTypes.Contains(Path.GetExtension(file))) - { - continue; - } - - // expected model file - var fi = new FileInfo(file); - directoryModelFiles.Add(new ModelFileMetadata - { - FileName = fi.Name, - FilePath = fi.FullName, - ModelType = ModelFileType.GGUF, - SizeInBytes = fi.Length, - }); - } - - _availableModels.Add(fullDirectoryPath, directoryModelFiles); - } - } - - /// - public IEnumerable ModelFileList - => _availableModels.SelectMany(x => x.Value); - /// - public IEnumerable ModelDirectories - => _availableModels.Keys; - - /// - public void AddDirectory(string directory) - { - GetModelsFromDirectories(directory); - } - - /// - public bool RemoveDirectory(string directory) - { - return _availableModels.Remove(Path.GetFullPath(directory)); - } - - /// - public void RemoveAllDirectories() - { - _availableModels.Clear(); - } - - /// - public IEnumerable GetAvailableModelsFromDirectory(string directory) - { - var dirPath = Path.GetFullPath(directory); - return _availableModels.TryGetValue(dirPath, out var dirModels) - ? dirModels - : []; - } - - /// - public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta) - { - var filePath = Path.GetFullPath(fileName); - modelMeta = ModelFileList.FirstOrDefault(f => f.FilePath == filePath)!; - return modelMeta != null; - } - - /// - public IEnumerable GetLoadedModels() - { - return _loadedModelCache.Values; - } - - /// - public bool TryGetLoadedModel(string modelId, out LLamaWeights model) - { - return _loadedModelCache.TryGetValue(modelId, out model!); - } - - /// - public async Task LoadModel(string modelPath, - Action? modelConfigurator = null!, - string modelId = "", - CancellationToken cancellationToken = default) - { - // Configure model params - var modelParams = new ModelParams(modelPath); - modelConfigurator ??= DefaultModelConfigurator; - modelConfigurator.Invoke(modelParams); - - // load and cache - var model = await LLamaWeights.LoadFromFileAsync(modelParams, cancellationToken); - if (string.IsNullOrWhiteSpace(modelId)) - { - modelId = model.ModelName; - } - _loadedModelCache.Add(modelId, model); - return model; - } - - /// - /// Updates the passed in model params with default values - /// - /// - public static void DefaultModelConfigurator(ModelParams modelParams) - { - // Reasonable defaults - modelParams.GpuLayerCount = 12; - modelParams.Seed = 1337u; - modelParams.ContextSize = 2048; - modelParams.Encoding = Encoding.UTF8; - - modelParams.UseMemoryLock = true; - modelParams.UseMemorymap = true; - - // auto detect - modelParams.Threads = null!; - modelParams.BatchThreads = null!; - } - - /// - public bool UnloadModel(string modelId) - { - if (TryGetLoadedModel(modelId, out var model)) - { - model.Dispose(); - return _loadedModelCache.Remove(modelId); - } - return false; - } - - /// - public void UnloadAllModels() - { - foreach (var model in _loadedModelCache.Values) - { - model.Dispose(); - } - _loadedModelCache.Clear(); - } -} From 6b2d71df9c55494f1061073198888eb8ae2c3a15 Mon Sep 17 00:00:00 2001 From: Patrick Hovsepian Date: Wed, 12 Jun 2024 12:48:32 -0700 Subject: [PATCH 04/11] organization (#4) Co-authored-by: pat_hov --- LLama/Model/IModelCache.cs | 100 +++++++++++++++++++++ LLama/Model/ModelCache.cs | 146 +++++++------------------------ LLama/Model/ModelFileMetadata.cs | 23 +++++ 3 files changed, 154 insertions(+), 115 deletions(-) create mode 100644 LLama/Model/IModelCache.cs create mode 100644 LLama/Model/ModelFileMetadata.cs diff --git a/LLama/Model/IModelCache.cs b/LLama/Model/IModelCache.cs new file mode 100644 index 000000000..7d70d3262 --- /dev/null +++ b/LLama/Model/IModelCache.cs @@ -0,0 +1,100 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using LLama.Common; + +namespace LLama.Model; + +/// +/// A class that helps organize and load local models +/// +public interface IModelCache : IDisposable +{ + // Model Directories + /// + /// Configured set of directories that are scanned to find local models + /// + /// + public IEnumerable ModelDirectories { get; } + + /// + /// Add a directory containing model files + /// + /// + public void AddDirectory(string directory); + + /// + /// Remove a directory from being scanned and having model files made available + /// + /// + /// + public bool RemoveDirectory(string directory); + + /// + /// Remove all model directories + /// + public void RemoveAllDirectories(); + + // Model Files + /// + /// Get all of the model files that are available to be loaded + /// + /// + public IEnumerable ModelFileList { get; } + + /// + /// Only get the models associated with a specific directory + /// + /// + /// The files, if any associated with a given directory + public IEnumerable GetAvailableModelsFromDirectory(string directory); + + /// + /// Get the file data for given model + /// + /// + /// + /// If a model with the given file name is present + public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta); + + // Model Load and Unload + /// + /// Load a model file to be used for infernce + /// + /// + /// + /// + /// + /// The loaded model on success + public Task LoadModel(string modelPath, + Action? modelConfigurator = null!, + string modelId = "", + CancellationToken cancellationToken = default); + + /// + /// Unload and dispose of a model with the given id + /// + /// + /// + public bool UnloadModel(string modelId); + + /// + /// Unload all currently loaded models + /// + public void UnloadAllModels(); + + /// + /// Attempt to get a model that's expected to be loaded + /// + /// + /// + /// + public bool TryGetLoadedModel(string modeId, out LLamaWeights model); + + /// + /// Currently loaded models + /// + /// + public IEnumerable GetLoadedModels(); +} diff --git a/LLama/Model/ModelCache.cs b/LLama/Model/ModelCache.cs index 5cdcdde1c..5ddd21a0f 100644 --- a/LLama/Model/ModelCache.cs +++ b/LLama/Model/ModelCache.cs @@ -9,123 +9,11 @@ namespace LLama.Model; -/// -/// Types of supported model files -/// -public enum ModelFileType -{ - GGUF -} - -/// -/// Metadata about available models -/// -public class ModelFileMetadata -{ -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public string FileName { get; init; } = string.Empty; - public string FilePath { get; init; } = string.Empty; - public ModelFileType ModelType { get; init; } - public long SizeInBytes { get; init; } = 0; -#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member -} - -/// -/// A class that helps organize and load local models -/// -public interface IModelCache -{ - // Model Directories - /// - /// Configured set of directories that are scanned to find local models - /// - /// - public IEnumerable ModelDirectories { get; } - - /// - /// Add a directory containing model files - /// - /// - public void AddDirectory(string directory); - - /// - /// Remove a directory from being scanned and having model files made available - /// - /// - /// - public bool RemoveDirectory(string directory); - - /// - /// Remove all model directories - /// - public void RemoveAllDirectories(); - - // Model Files - /// - /// Get all of the model files that are available to be loaded - /// - /// - public IEnumerable ModelFileList { get; } - - /// - /// Only get the models associated with a specific directory - /// - /// - /// The files, if any associated with a given directory - public IEnumerable GetAvailableModelsFromDirectory(string directory); - - /// - /// Get the file data for given model - /// - /// - /// - /// If a model with the given file name is present - public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta); - - // Model Load and Unload - /// - /// Load a model file to be used for infernce - /// - /// - /// - /// - /// - /// The loaded model on success - public Task LoadModel(string modelPath, - Action? modelConfigurator = null!, - string modelId = "", - CancellationToken cancellationToken = default); - - /// - /// Unload and dispose of a model with the given id - /// - /// - /// - public bool UnloadModel(string modelId); - - /// - /// Unload all currently loaded models - /// - public void UnloadAllModels(); - - /// - /// Attempt to get a model that's expected to be loaded - /// - /// - /// - /// - public bool TryGetLoadedModel(string modeId, out LLamaWeights model); - - /// - /// Currently loaded models - /// - /// - public IEnumerable GetLoadedModels(); -} - /// public class ModelCache : IModelCache { + private bool _disposed = false; + /// /// Support model type files /// @@ -277,7 +165,7 @@ public async Task LoadModel(string modelPath, if (string.IsNullOrWhiteSpace(modelId)) { modelId = model.ModelName; - + // Check if cached again with alias // TODO: Consider the case where the alias is different but the underlying model file is the same if (TryGetLoadedModel(modelId, out loadedModel)) @@ -310,4 +198,32 @@ public void UnloadAllModels() } _loadedModelCache.Clear(); } + + #region Dispose + /// + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Unload all models when called explicity via dispose + /// + /// Whether or not this call is made explicitly(true) or via GC + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (disposing) + { + UnloadAllModels(); + } + + _disposed = true; + } + #endregion } diff --git a/LLama/Model/ModelFileMetadata.cs b/LLama/Model/ModelFileMetadata.cs new file mode 100644 index 000000000..4b674d887 --- /dev/null +++ b/LLama/Model/ModelFileMetadata.cs @@ -0,0 +1,23 @@ +namespace LLama.Model; + +/// +/// Types of supported model files +/// +public enum ModelFileType +{ +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or membe + GGUF +} + +/// +/// Metadata about available models +/// +public class ModelFileMetadata +{ + public string FileName { get; init; } = string.Empty; + public string FilePath { get; init; } = string.Empty; + public ModelFileType ModelType { get; init; } + public long SizeInBytes { get; init; } = 0; +#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member +} + From 852e35d0942c1f9e3ce6f53d240cc0780587578f Mon Sep 17 00:00:00 2001 From: Patrick Hovsepian Date: Wed, 12 Jun 2024 19:53:02 -0700 Subject: [PATCH 05/11] Mmwip disp (#5) * organization * disposable and ref counter --------- Co-authored-by: pat_hov --- LLama.Unittest/Model/ModelCacheTests.cs | 29 ++++++---- LLama/LLamaWeights.cs | 35 +++++++---- LLama/Model/IModelCache.cs | 10 +--- LLama/Model/ModelCache.cs | 77 ++++++++++++++----------- LLama/Native/SafeLlamaModelHandle.cs | 38 ++++++++---- 5 files changed, 113 insertions(+), 76 deletions(-) diff --git a/LLama.Unittest/Model/ModelCacheTests.cs b/LLama.Unittest/Model/ModelCacheTests.cs index 936eff1dd..375c2d6cf 100644 --- a/LLama.Unittest/Model/ModelCacheTests.cs +++ b/LLama.Unittest/Model/ModelCacheTests.cs @@ -1,7 +1,7 @@ using LLama.Common; using LLama.Model; -namespace LLama.Unittest; +namespace LLama.Unittest.Model; public class ModelManagerTests { @@ -108,14 +108,16 @@ public async void LoadModel_LoadsAndCaches() var modelToLoad = TestableModelManager.ModelFileList .First(f => f.FileName.Contains("llama-2-7b")); - var model = await TestableModelManager.LoadModel(modelToLoad.FilePath, null!); - - Assert.Single(TestableModelManager.GetLoadedModels()); - + var model = await TestableModelManager.LoadModel(modelToLoad.FilePath); var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); Assert.True(isLoaded); - // unload + // unload the newly acquired model even though it was cached + Assert.True(TestableModelManager.UnloadModel(model.ModelName)); + //cachedModel.Dispose(); // this does effectively nothing + + // unload "original" + //model.Dispose(); Assert.True(TestableModelManager.UnloadModel(model.ModelName)); Assert.Throws(() => @@ -135,7 +137,6 @@ public async void LoadModel_AlreadyLoaded_ReturnsFromCache() var model = await TestableModelManager.LoadModel(modelToLoad.FilePath); Assert.NotNull(model); Assert.Equal("LLaMA v2", model.ModelName); - Assert.Single(TestableModelManager.GetLoadedModels()); var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); Assert.True(isLoaded); Assert.NotNull(cachedModel); @@ -153,16 +154,20 @@ public async void TryGetLoadedModel_AlreadyDisposed_ReturnsFalse() { Assert.NotNull(model); Assert.Equal("LLaMA v2", model.ModelName); - Assert.Single(TestableModelManager.GetLoadedModels()); var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); Assert.True(isLoaded); Assert.NotNull(cachedModel); Assert.Equal("LLaMA v2", cachedModel.ModelName); - } // end scope, dispose model - // Model is now disposed - var isDispoedLoaded = TestableModelManager.TryGetLoadedModel("LLaMA v2", out var disposedModel); - Assert.False(isDispoedLoaded); + // unload from the last check + Assert.True(TestableModelManager.UnloadModel("LLaMA v2")); + + } // end scope, dispose is called on the model but since we have the model cache it should stick around until unloaded + Assert.True(TestableModelManager.UnloadModel("LLaMA v2")); + + // Model is still loaded due to cache + var isDisposedLoaded = TestableModelManager.TryGetLoadedModel("LLaMA v2", out var disposedModel); + Assert.False(isDisposedLoaded); Assert.Null(disposedModel); } } diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 809ca1208..250b82de4 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -71,6 +71,23 @@ private LLamaWeights(SafeLlamaModelHandle weights) Metadata = weights.ReadMetadata(); } + /// + /// Create from a "shared" handle + /// + /// + /// + public static LLamaWeights FromSafeModelHandle(SafeLlamaModelHandle handle) + { + var model = new LLamaWeights(handle); + + // Increment the model reference count while this weight exists. + // DangerousAddRef throws if it fails, so there is no need to check "success" + var success = false; + handle.DangerousAddRef(ref success); + + return model; + } + /// /// Load weights into memory /// @@ -79,19 +96,19 @@ private LLamaWeights(SafeLlamaModelHandle weights) public static LLamaWeights LoadFromFile(IModelParams @params) { using var pin = @params.ToLlamaModelParams(out var lparams); - var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); + var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); foreach (var adapter in @params.LoraAdapters) { - if (string.IsNullOrEmpty(adapter.Path)) - continue; - if (adapter.Scale <= 0) + if (string.IsNullOrEmpty(adapter.Path) || adapter.Scale <= 0) + { continue; + } - weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase); + model.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase); } - return new LLamaWeights(weights); + return new LLamaWeights(model); } /// @@ -133,11 +150,7 @@ public static async Task LoadFromFileAsync(IModelParams @params, C if (internalCallback != null && !internalCallback(progress, ctx)) return false; - // Check the cancellation token - if (token.IsCancellationRequested) - return false; - - return true; + return token.IsCancellationRequested; }; } #endif diff --git a/LLama/Model/IModelCache.cs b/LLama/Model/IModelCache.cs index 7d70d3262..7c5ec84b4 100644 --- a/LLama/Model/IModelCache.cs +++ b/LLama/Model/IModelCache.cs @@ -60,7 +60,8 @@ public interface IModelCache : IDisposable // Model Load and Unload /// - /// Load a model file to be used for infernce + /// Load a model file to be used for inference + /// The caller assumes responsible for disposing this model /// /// /// @@ -86,15 +87,10 @@ public Task LoadModel(string modelPath, /// /// Attempt to get a model that's expected to be loaded + /// The callers assumes responsiblilty for the lifetime of the model at this point if it exists in the cache /// /// /// /// public bool TryGetLoadedModel(string modeId, out LLamaWeights model); - - /// - /// Currently loaded models - /// - /// - public IEnumerable GetLoadedModels(); } diff --git a/LLama/Model/ModelCache.cs b/LLama/Model/ModelCache.cs index 5ddd21a0f..9c743eb4e 100644 --- a/LLama/Model/ModelCache.cs +++ b/LLama/Model/ModelCache.cs @@ -6,6 +6,7 @@ using System.Threading; using System.Threading.Tasks; using LLama.Common; +using LLama.Native; namespace LLama.Model; @@ -25,7 +26,7 @@ public class ModelCache : IModelCache private readonly Dictionary> _availableModels = []; // model id/alias, to loaded model - private readonly Dictionary _loadedModelCache = []; + private readonly Dictionary _loadedModelCache = []; /// /// Create a new model manager that seeds available models from the given directory list @@ -36,6 +37,15 @@ public ModelCache(string[] directories) GetModelsFromDirectories(directories); } + /// + public IEnumerable ModelFileList + => _availableModels.SelectMany(x => x.Value); + + /// + public IEnumerable ModelDirectories + => _availableModels.Keys; + + #region Directories private void GetModelsFromDirectories(params string[] dirs) { foreach (var dir in dirs) @@ -78,13 +88,6 @@ private void GetModelsFromDirectories(params string[] dirs) } } - /// - public IEnumerable ModelFileList - => _availableModels.SelectMany(x => x.Value); - /// - public IEnumerable ModelDirectories - => _availableModels.Keys; - /// public void AddDirectory(string directory) { @@ -111,6 +114,7 @@ public IEnumerable GetAvailableModelsFromDirectory(string dir ? dirModels : []; } + #endregion Directories /// public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta) @@ -120,25 +124,13 @@ public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata model return modelMeta != null; } - /// - public IEnumerable GetLoadedModels() - { - return _loadedModelCache.Values; - } - /// public bool TryGetLoadedModel(string modelId, out LLamaWeights model) { - var isCached = _loadedModelCache.TryGetValue(modelId, out model!); - - // Externall disposed, act like it's not in here - if (isCached && model.NativeHandle.IsClosed) - { - _ = _loadedModelCache.Remove(modelId); - isCached = false; - model = null!; - } - + var isCached = _loadedModelCache.TryGetValue(modelId, out var handle); + model = isCached + ? LLamaWeights.FromSafeModelHandle(handle) + : null!; return isCached; } @@ -152,7 +144,6 @@ public async Task LoadModel(string modelPath, if (!string.IsNullOrEmpty(modelId) && TryGetLoadedModel(modelId, out var loadedModel)) { - Trace.TraceWarning($"Model {modelId} already loaded"); return loadedModel; } @@ -162,29 +153,44 @@ public async Task LoadModel(string modelPath, // load and cache var model = await LLamaWeights.LoadFromFileAsync(modelParams, cancellationToken); + + // Check if it's already cached, if so use that and dispose of this + // TODO: Consider the case where the alias is different but the underlying model file is the same if (string.IsNullOrWhiteSpace(modelId)) { modelId = model.ModelName; - // Check if cached again with alias - // TODO: Consider the case where the alias is different but the underlying model file is the same if (TryGetLoadedModel(modelId, out loadedModel)) { model.Dispose(); return loadedModel; } } - _loadedModelCache.Add(modelId, model); + + // Increment the model reference count while this model exists (newly created) + // DangerousAddRef throws if it fails, so there is no need to check "success" + // Do this here since we're passing this to the caller to own and it's not done as part of the normal weight creation + var refSuccess = false; + model.NativeHandle.DangerousAddRef(ref refSuccess); + + _loadedModelCache.Add(modelId, model.NativeHandle); return model; } + #region Unload /// public bool UnloadModel(string modelId) { - if (TryGetLoadedModel(modelId, out var model)) + if (_loadedModelCache.TryGetValue(modelId, out var handle)) { - model.Dispose(); - return _loadedModelCache.Remove(modelId); + // Decrement refcount on model + handle.DangerousRelease(); + handle.Dispose(); + if (handle.IsClosed || handle.IsInvalid) + { + return _loadedModelCache.Remove(modelId); + } + return true; } return false; } @@ -192,13 +198,16 @@ public bool UnloadModel(string modelId) /// public void UnloadAllModels() { - foreach (var model in _loadedModelCache.Values) + foreach (var handle in _loadedModelCache.Values) { - model.Dispose(); + handle.DangerousRelease(); + handle.Dispose(); } _loadedModelCache.Clear(); } + #endregion + #region Dispose /// public void Dispose() @@ -208,7 +217,7 @@ public void Dispose() } /// - /// Unload all models when called explicity via dispose + /// Unload all models when called explicitly via dispose /// /// Whether or not this call is made explicitly(true) or via GC protected virtual void Dispose(bool disposing) diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 1597908e3..5e66729da 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -16,6 +16,7 @@ namespace LLama.Native public sealed class SafeLlamaModelHandle : SafeLLamaHandleBase { + #region Properties /// /// Total number of tokens in vocabulary of this model /// @@ -61,6 +62,7 @@ public sealed class SafeLlamaModelHandle /// public int LayerCount => llama_n_embd(this); + private string _modelDescription = null!; /// /// Get a description of this model /// @@ -68,17 +70,22 @@ public string Description { get { - unsafe + if (_modelDescription is null) { - // Get description length - var size = llama_model_desc(this, null, 0); - var buf = new byte[size + 1]; - fixed (byte* bufPtr = buf) + unsafe { - size = llama_model_desc(this, bufPtr, buf.Length); - return Encoding.UTF8.GetString(buf, 0, size); + // Get description length + var size = llama_model_desc(this, null, 0); + var buf = new byte[size + 1]; + fixed (byte* bufPtr = buf) + { + size = llama_model_desc(this, bufPtr, buf.Length); + _modelDescription = Encoding.UTF8.GetString(buf, 0, size) ?? string.Empty; + } } } + + return _modelDescription; } } @@ -94,6 +101,7 @@ public string Description /// Get the special tokens of this model /// public ModelTokens Tokens => _tokens ??= new ModelTokens(this); + #endregion /// protected override bool ReleaseHandle() @@ -102,6 +110,7 @@ protected override bool ReleaseHandle() return true; } + // TODO: Move this to the model manager? /// /// Load a model from the given file path into memory /// @@ -116,12 +125,18 @@ public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelPara // - File is readable (explicit check) // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases. using (var fs = new FileStream(modelPath, FileMode.Open)) + { if (!fs.CanRead) + { throw new InvalidOperationException($"Model file '{modelPath}' is not readable"); + } + } var handle = llama_load_model_from_file(modelPath, lparams); if (handle.IsInvalid) + { throw new LoadWeightsFailedException(modelPath); + } return handle; } @@ -244,7 +259,6 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k static extern unsafe int llama_model_meta_val_str_native(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size); } - /// /// Get the number of tokens in the model vocabulary /// @@ -545,7 +559,7 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params) keyLength = llama_model_meta_val_str(this, key, buffer); Debug.Assert(keyLength >= 0); - return buffer.AsMemory().Slice(0,keyLength); + return buffer.AsMemory().Slice(0, keyLength); } /// @@ -632,12 +646,12 @@ internal ModelTokens(SafeLlamaModelHandle model) const int buffSize = 32; Span buff = stackalloc byte[buffSize]; var tokenLength = _model.TokenToSpan(token ?? LLamaToken.InvalidToken, buff, special: isSpecialToken); - + if (tokenLength <= 0) { return null; } - + // if the original buffer wasn't large enough, create a new one if (tokenLength > buffSize) { @@ -663,7 +677,7 @@ internal ModelTokens(SafeLlamaModelHandle model) /// Get the End of Sentence token for this model /// public LLamaToken? EOS => Normalize(llama_token_eos(_model)); - + /// /// The textual representation of the end of speech special token for this model /// From a1155de6e76c4ec183809adc303d3ecd7d79d4d1 Mon Sep 17 00:00:00 2001 From: Patrick Hovsepian Date: Fri, 14 Jun 2024 10:35:24 -0700 Subject: [PATCH 06/11] Separate Interfaces (#6) * organization * disposable and ref counter * separate concerns a bit more * check * tweak --------- Co-authored-by: pat_hov --- LLama.Unittest/LLama.Unittest.csproj | 2 - .../Model/FileSystemModelRepoTests.cs | 104 +++++++++++++++ LLama.Unittest/Model/ModelCacheTests.cs | 118 +++-------------- LLama/LLamaWeights.cs | 69 +++++++--- LLama/Model/FileSystemModelRepo.cs | 119 ++++++++++++++++++ LLama/Model/HuggingFaceModelRepo.cs | 57 +++++++++ LLama/Model/IModelCache.cs | 67 ++-------- LLama/Model/IModelSourceRepo.cs | 58 +++++++++ LLama/Model/ModelCache.cs | 111 +--------------- LLama/Model/ModelFileMetadata.cs | 10 +- 10 files changed, 421 insertions(+), 294 deletions(-) create mode 100644 LLama.Unittest/Model/FileSystemModelRepoTests.cs create mode 100644 LLama/Model/FileSystemModelRepo.cs create mode 100644 LLama/Model/HuggingFaceModelRepo.cs create mode 100644 LLama/Model/IModelSourceRepo.cs diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index 5c29a8513..087b12903 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -32,8 +32,6 @@ - - diff --git a/LLama.Unittest/Model/FileSystemModelRepoTests.cs b/LLama.Unittest/Model/FileSystemModelRepoTests.cs new file mode 100644 index 000000000..c867ae623 --- /dev/null +++ b/LLama.Unittest/Model/FileSystemModelRepoTests.cs @@ -0,0 +1,104 @@ +using LLama.Model; + +namespace LLama.Unittest.Model; + +public class FileSystemModelRepoTests +{ + private readonly FileSystemModelRepo TestableRepo; + + public FileSystemModelRepoTests() + { + TestableRepo = new([Constants.ModelDirectory]); + } + + [Fact] + public void ModelDirectories_IsCorrect() + { + var dirs = TestableRepo.ListSources(); + Assert.Single(dirs); + + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + } + + [Fact] + public void AddDirectory_DoesntDuplicate() + { + for (var i = 0; i < 10; i++) + { + TestableRepo.AddSource(Constants.ModelDirectory); + TestableRepo.AddSource(Path.GetFullPath(Constants.ModelDirectory)); + + var dirs = TestableRepo.ListSources(); + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + } + } + + [Fact] + public void RemoveDirectory() + { + var dirs = TestableRepo.ListSources(); + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + + Assert.True(TestableRepo.RemoveSource(Constants.ModelDirectory)); + Assert.Empty(TestableRepo.ListSources()); + Assert.Empty(TestableRepo.GetAvailableModels()); + } + + [Fact] + public void RemoveDirectory_DoesNotExist() + { + var dirs = TestableRepo.ListSources(); + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + + Assert.False(TestableRepo.RemoveSource("foo/boo/bar")); + Assert.Single(dirs); + } + + [Fact] + public void RemoveAllDirectories() + { + var dirs = TestableRepo.ListSources(); + Assert.Single(dirs); + var expected = dirs.First()!.Contains(Constants.ModelDirectory); + Assert.True(expected); + + TestableRepo.RemoveAllSources(); + Assert.Empty(TestableRepo.ListSources()); + Assert.Empty(TestableRepo.GetAvailableModels()); + } + + [Fact] + public void ModelFiles_IsCorrect() + { + var files = TestableRepo.GetAvailableModels(); + Assert.Equal(4, files.Count()); + } + + [Fact] + public void GetAvailableModelsFromDirectory() + { + var files = TestableRepo.GetAvailableModelsFromSource(Constants.ModelDirectory); + Assert.Equal(4, files.Count()); + + files = TestableRepo.GetAvailableModels(); + Assert.Equal(4, files.Count()); + } + + [Fact] + public void TryGetModelFileMetadata_WhenExists() + { + var expectedFile = TestableRepo.GetAvailableModels().First(); + var found = TestableRepo.TryGetModelFileMetadata(expectedFile.ModelFileUri, out var foundData); + + Assert.True(found); + Assert.Equal(expectedFile.ModelFileUri, foundData.ModelFileUri); + } + +} diff --git a/LLama.Unittest/Model/ModelCacheTests.cs b/LLama.Unittest/Model/ModelCacheTests.cs index 375c2d6cf..62d8e17ef 100644 --- a/LLama.Unittest/Model/ModelCacheTests.cs +++ b/LLama.Unittest/Model/ModelCacheTests.cs @@ -5,110 +5,22 @@ namespace LLama.Unittest.Model; public class ModelManagerTests { + private readonly IModelSourceRepo _testRepo = new FileSystemModelRepo([Constants.ModelDirectory]); + private readonly ModelCache TestableModelManager; public ModelManagerTests() { - TestableModelManager = new([Constants.ModelDirectory]); - } - - [Fact] - public void ModelDirectories_IsCorrect() - { - var dirs = TestableModelManager.ModelDirectories; - Assert.Single(dirs); - - var expected = dirs.First()!.Contains(Constants.ModelDirectory); - Assert.True(expected); - } - - [Fact] - public void AddDirectory_DoesntDuplicate() - { - for (var i = 0; i < 10; i++) - { - TestableModelManager.AddDirectory(Constants.ModelDirectory); - TestableModelManager.AddDirectory(Path.GetFullPath(Constants.ModelDirectory)); - - var dirs = TestableModelManager.ModelDirectories; - Assert.Single(dirs); - var expected = dirs.First()!.Contains(Constants.ModelDirectory); - Assert.True(expected); - } - } - - [Fact] - public void RemoveDirectory() - { - var dirs = TestableModelManager.ModelDirectories; - Assert.Single(dirs); - var expected = dirs.First()!.Contains(Constants.ModelDirectory); - Assert.True(expected); - - Assert.True(TestableModelManager.RemoveDirectory(Constants.ModelDirectory)); - Assert.Empty(TestableModelManager.ModelDirectories); - Assert.Empty(TestableModelManager.ModelFileList); - } - - [Fact] - public void RemoveDirectory_DoesNotExist() - { - var dirs = TestableModelManager.ModelDirectories; - Assert.Single(dirs); - var expected = dirs.First()!.Contains(Constants.ModelDirectory); - Assert.True(expected); - - Assert.False(TestableModelManager.RemoveDirectory("foo/boo/bar")); - Assert.Single(dirs); - } - - [Fact] - public void RemoveAllDirectories() - { - var dirs = TestableModelManager.ModelDirectories; - Assert.Single(dirs); - var expected = dirs.First()!.Contains(Constants.ModelDirectory); - Assert.True(expected); - - TestableModelManager.RemoveAllDirectories(); - Assert.Empty(TestableModelManager.ModelDirectories); - Assert.Empty(TestableModelManager.ModelFileList); - } - - [Fact] - public void ModelFiles_IsCorrect() - { - var files = TestableModelManager.ModelFileList; - Assert.Equal(4, files.Count()); - } - - [Fact] - public void GetAvailableModelsFromDirectory() - { - var files = TestableModelManager.GetAvailableModelsFromDirectory(Constants.ModelDirectory); - Assert.Equal(4, files.Count()); - - files = TestableModelManager.ModelFileList; - Assert.Equal(4, files.Count()); - } - - [Fact] - public void TryGetModelFileMetadata_WhenExists() - { - var expectedFile = TestableModelManager.ModelFileList.First(); - var found = TestableModelManager.TryGetModelFileMetadata(expectedFile.FilePath, out var foundData); - - Assert.True(found); - Assert.Equal(expectedFile.FilePath, foundData.FilePath); + TestableModelManager = new(); } [Fact] public async void LoadModel_LoadsAndCaches() { - var modelToLoad = TestableModelManager.ModelFileList - .First(f => f.FileName.Contains("llama-2-7b")); + var modelToLoad = _testRepo.GetAvailableModels() + .First(f => f.ModelFileName.Contains("llama-2-7b")); - var model = await TestableModelManager.LoadModel(modelToLoad.FilePath); + var model = await TestableModelManager.LoadModelAsync(modelToLoad); var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); Assert.True(isLoaded); @@ -117,24 +29,26 @@ public async void LoadModel_LoadsAndCaches() //cachedModel.Dispose(); // this does effectively nothing // unload "original" - //model.Dispose(); + model.Dispose(); // need to explicitly dispose the model that the caller (us) owns Assert.True(TestableModelManager.UnloadModel(model.ModelName)); + Assert.False(TestableModelManager.UnloadModel(model.ModelName)); + Assert.Throws(() => { - _ = model.CreateContext(new ModelParams(modelToLoad.FilePath)); + _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri)); }); } [Fact] public async void LoadModel_AlreadyLoaded_ReturnsFromCache() { - var modelToLoad = TestableModelManager.ModelFileList - .First(f => f.FileName.Contains("llama-2-7b")); + var modelToLoad = _testRepo.GetAvailableModels() + .First(f => f.ModelFileName.Contains("llama-2-7b")); for (var i = 0; i < 5; i++) { - var model = await TestableModelManager.LoadModel(modelToLoad.FilePath); + var model = await TestableModelManager.LoadModelAsync(modelToLoad); Assert.NotNull(model); Assert.Equal("LLaMA v2", model.ModelName); var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); @@ -147,10 +61,10 @@ public async void LoadModel_AlreadyLoaded_ReturnsFromCache() [Fact] public async void TryGetLoadedModel_AlreadyDisposed_ReturnsFalse() { - var modelToLoad = TestableModelManager.ModelFileList - .First(f => f.FileName.Contains("llama-2-7b")); + var modelToLoad = _testRepo.GetAvailableModels() + .First(f => f.ModelFileName.Contains("llama-2-7b")); - using (var model = await TestableModelManager.LoadModel(modelToLoad.FilePath)) + using (var model = await TestableModelManager.LoadModelAsync(modelToLoad)) { Assert.NotNull(model); Assert.Equal("LLaMA v2", model.ModelName); diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 250b82de4..10d1a92e5 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -16,12 +16,21 @@ namespace LLama public sealed class LLamaWeights : IDisposable { + private bool _disposed = false; + + /// + ~LLamaWeights() + { + Dispose(false); + } + /// /// The native handle, which is used in the native APIs /// /// Be careful how you use this! public SafeLlamaModelHandle NativeHandle { get; } + #region Properties /// /// The models name as specified in it's metadata /// @@ -64,28 +73,28 @@ public sealed class LLamaWeights /// All metadata keys in this model /// public IReadOnlyDictionary Metadata { get; set; } + #endregion - private LLamaWeights(SafeLlamaModelHandle weights) + private LLamaWeights(SafeLlamaModelHandle handle) { - NativeHandle = weights; - Metadata = weights.ReadMetadata(); + NativeHandle = handle; + Metadata = handle.ReadMetadata(); + + // Increment the model reference count while this weight exists. + // DangerousAddRef throws if it fails, so there is no need to check "success" + var success = false; + NativeHandle.DangerousAddRef(ref success); } + #region Load /// - /// Create from a "shared" handle + /// Create from a "shared" handle. The `SafeLlamaModelHandle` will not be disposed and the model will not be unloaded until all such handles have been disposed. /// /// /// public static LLamaWeights FromSafeModelHandle(SafeLlamaModelHandle handle) { - var model = new LLamaWeights(handle); - - // Increment the model reference count while this weight exists. - // DangerousAddRef throws if it fails, so there is no need to check "success" - var success = false; - handle.DangerousAddRef(ref success); - - return model; + return new LLamaWeights(handle); } /// @@ -128,15 +137,15 @@ public static async Task LoadFromFileAsync(IModelParams @params, C var loraBase = @params.LoraBase; var loraAdapters = @params.LoraAdapters.ToArray(); - // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a - // slightly smaller range to allow some space for reporting LoRA loading too. - var modelLoadProgressRange = 1f; - if (loraAdapters.Length > 0) - modelLoadProgressRange = 0.9f; - using (@params.ToLlamaModelParams(out var lparams)) { #if !NETSTANDARD2_0 + // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a + // slightly smaller range to allow some space for reporting LoRA loading too. + var modelLoadProgressRange = 1f; + if (loraAdapters.Length > 0) + modelLoadProgressRange = 0.9f; + // Overwrite the progress callback with one which polls the cancellation token and updates the progress object if (token.CanBeCanceled || progressReporter != null) { @@ -204,11 +213,33 @@ public static async Task LoadFromFileAsync(IModelParams @params, C return model; } } + #endregion /// public void Dispose() { - NativeHandle.Dispose(); + Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Unload all models when called explicitly via dispose + /// + /// Whether or not this call is made explicitly(true) or via GC + internal void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (disposing) + { + NativeHandle.DangerousRelease(); + NativeHandle.Dispose(); + } + + _disposed = true; } /// diff --git a/LLama/Model/FileSystemModelRepo.cs b/LLama/Model/FileSystemModelRepo.cs new file mode 100644 index 000000000..70987cb4b --- /dev/null +++ b/LLama/Model/FileSystemModelRepo.cs @@ -0,0 +1,119 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; + +namespace LLama.Model; + +/// +/// A model repository that uses a file system to search for available models +/// +public class FileSystemModelRepo : IModelSourceRepo +{ + /// + /// Support model type files + /// + public static readonly string[] ExpectedModelFileTypes = [ + ".gguf" + ]; + + // keys are directories, values are applicable models + private readonly Dictionary> _availableModels = []; + + /// + /// Create a model repo that scans the filesystem to find models + /// + /// + public FileSystemModelRepo(string[] directories) + { + GetModelsFromDirectories(directories); + } + + #region Sources + /// + public IEnumerable ListSources() => _availableModels.Keys; + + private void GetModelsFromDirectories(params string[] dirs) + { + foreach (var dir in dirs) + { + var fullDirectoryPath = Path.GetFullPath(dir); + + if (!Directory.Exists(fullDirectoryPath)) + { + Trace.TraceError($"Model directory '{fullDirectoryPath}' does not exist"); + continue; + } + + if (_availableModels.ContainsKey(fullDirectoryPath)) + { + Trace.TraceWarning($"Model directory '{fullDirectoryPath}' already probed"); + continue; + } + + // find models in current dir that are of expected type + List directoryModelFiles = []; + foreach (var file in Directory.EnumerateFiles(fullDirectoryPath)) + { + if (!ExpectedModelFileTypes.Contains(Path.GetExtension(file))) + { + continue; + } + + // expected model file + // TODO: handle symbolic links + var fi = new FileInfo(file); + directoryModelFiles.Add(new ModelFileMetadata + { + ModelFileName = fi.Name, + ModelFileUri = fi.FullName, + ModelType = ModelFileType.GGUF, + ModelFileSizeInBytes = fi.Length, + }); + } + + _availableModels.Add(fullDirectoryPath, directoryModelFiles); + } + } + + /// + public void AddSource(string directory) + { + GetModelsFromDirectories(directory); + } + + /// + public bool RemoveSource(string directory) + { + return _availableModels.Remove(Path.GetFullPath(directory)); + } + + /// + public void RemoveAllSources() + { + _availableModels.Clear(); + } + #endregion Sources + + /// + public IEnumerable GetAvailableModels() + => _availableModels.SelectMany(x => x.Value); + + /// + public IEnumerable GetAvailableModelsFromSource(string directory) + { + var dirPath = Path.GetFullPath(directory); + return _availableModels.TryGetValue(dirPath, out var dirModels) + ? dirModels + : []; + } + + /// + public bool TryGetModelFileMetadata(string modelFileName, out ModelFileMetadata modelMeta) + { + var filePath = Path.GetFullPath(modelFileName); + modelMeta = GetAvailableModels().FirstOrDefault(f => f.ModelFileUri == filePath)!; + return modelMeta != null; + } +} diff --git a/LLama/Model/HuggingFaceModelRepo.cs b/LLama/Model/HuggingFaceModelRepo.cs new file mode 100644 index 000000000..23a99408b --- /dev/null +++ b/LLama/Model/HuggingFaceModelRepo.cs @@ -0,0 +1,57 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Http; +using Microsoft.Extensions.Logging; + +namespace LLama.Model; + +// This is for demo purposes - it can be finalized later +internal class HuggingFaceModelRepo(ILogger logger, + HttpClient hfClient) : IModelSourceRepo +{ + private readonly ILogger _logger = logger; + private readonly HttpClient _hfClient = hfClient; + + // https://huggingface.co/leliuga/all-MiniLM-L12-v2-GGUF/resolve/main/all-MiniLM-L12-v2.Q8_0.gguf + private readonly HashSet _hfModelUri = []; + + public void AddSource(string source) + { + if (!Uri.IsWellFormedUriString(source, UriKind.Absolute)) + { + Trace.TraceWarning("URI is not a valid HuggingFace URL"); + } + + // TODO: call HF to check model exists + // TODO: Get metadta about model an + _hfModelUri.Add(source); + } + + public IEnumerable ListSources() => _hfModelUri; + + public void RemoveAllSources() + { + _hfModelUri.Clear(); + } + + public bool RemoveSource(string source) + { + return _hfModelUri.Remove(source); + } + + public bool TryGetModelFileMetadata(string modelFileName, out ModelFileMetadata modelMeta) + { + throw new NotImplementedException(); + } + + public IEnumerable GetAvailableModels() + { + throw new NotImplementedException(); + } + + public IEnumerable GetAvailableModelsFromSource(string source) + { + throw new NotImplementedException(); + } +} diff --git a/LLama/Model/IModelCache.cs b/LLama/Model/IModelCache.cs index 7c5ec84b4..abf6eacb3 100644 --- a/LLama/Model/IModelCache.cs +++ b/LLama/Model/IModelCache.cs @@ -11,64 +11,22 @@ namespace LLama.Model; /// public interface IModelCache : IDisposable { - // Model Directories /// - /// Configured set of directories that are scanned to find local models + /// The current number of file handles in cache. /// - /// - public IEnumerable ModelDirectories { get; } + /// Number of cached models + public int ModelsCached(); /// - /// Add a directory containing model files + /// Load a model file to be used for inference. + /// The caller assumes responsibility for disposing this model and MUST call Unload /// - /// - public void AddDirectory(string directory); - - /// - /// Remove a directory from being scanned and having model files made available - /// - /// - /// - public bool RemoveDirectory(string directory); - - /// - /// Remove all model directories - /// - public void RemoveAllDirectories(); - - // Model Files - /// - /// Get all of the model files that are available to be loaded - /// - /// - public IEnumerable ModelFileList { get; } - - /// - /// Only get the models associated with a specific directory - /// - /// - /// The files, if any associated with a given directory - public IEnumerable GetAvailableModelsFromDirectory(string directory); - - /// - /// Get the file data for given model - /// - /// - /// - /// If a model with the given file name is present - public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta); - - // Model Load and Unload - /// - /// Load a model file to be used for inference - /// The caller assumes responsible for disposing this model - /// - /// + /// /// - /// + /// An alias to uniquely identify this model's underyling handle. If none is supplied, the model's name is used.' /// /// The loaded model on success - public Task LoadModel(string modelPath, + public Task LoadModelAsync(ModelFileMetadata metadata, Action? modelConfigurator = null!, string modelId = "", CancellationToken cancellationToken = default); @@ -84,13 +42,4 @@ public Task LoadModel(string modelPath, /// Unload all currently loaded models /// public void UnloadAllModels(); - - /// - /// Attempt to get a model that's expected to be loaded - /// The callers assumes responsiblilty for the lifetime of the model at this point if it exists in the cache - /// - /// - /// - /// - public bool TryGetLoadedModel(string modeId, out LLamaWeights model); } diff --git a/LLama/Model/IModelSourceRepo.cs b/LLama/Model/IModelSourceRepo.cs new file mode 100644 index 000000000..c502cbe8c --- /dev/null +++ b/LLama/Model/IModelSourceRepo.cs @@ -0,0 +1,58 @@ +using System.Collections.Generic; + +namespace LLama.Model; + +/// +/// A source for models +/// +public interface IModelSourceRepo +{ + #region Source + /// + /// Configured set of sources that are scanned to find models + /// + /// + public IEnumerable ListSources(); + + /// + /// Add a source containing one or more files + /// + /// + public void AddSource(string source); + + /// + /// Remove a source from being scanned and having model files made available + /// + /// + /// + public bool RemoveSource(string source); + + /// + /// Remove all model directories + /// + public void RemoveAllSources(); + #endregion + + #region AvailableModels + /// + /// Get all of the model files that are available to be loaded + /// + /// + public IEnumerable GetAvailableModels(); + + /// + /// Only get the models associated with a specific source + /// + /// + /// The files, if any associated with a given source + public IEnumerable GetAvailableModelsFromSource(string source); + + /// + /// Get the file data for given model + /// + /// + /// + /// If a model with the given file name is present + public bool TryGetModelFileMetadata(string modelFileName, out ModelFileMetadata modelMeta); + #endregion +} diff --git a/LLama/Model/ModelCache.cs b/LLama/Model/ModelCache.cs index 9c743eb4e..444cc00a8 100644 --- a/LLama/Model/ModelCache.cs +++ b/LLama/Model/ModelCache.cs @@ -15,114 +15,12 @@ public class ModelCache : IModelCache { private bool _disposed = false; - /// - /// Support model type files - /// - public static readonly string[] ExpectedModelFileTypes = [ - ".gguf" - ]; - - // keys are directories, values are applicable models - private readonly Dictionary> _availableModels = []; - // model id/alias, to loaded model private readonly Dictionary _loadedModelCache = []; - /// - /// Create a new model manager that seeds available models from the given directory list - /// - /// - public ModelCache(string[] directories) - { - GetModelsFromDirectories(directories); - } - /// - public IEnumerable ModelFileList - => _availableModels.SelectMany(x => x.Value); - - /// - public IEnumerable ModelDirectories - => _availableModels.Keys; - - #region Directories - private void GetModelsFromDirectories(params string[] dirs) - { - foreach (var dir in dirs) - { - var fullDirectoryPath = Path.GetFullPath(dir); - - if (!Directory.Exists(fullDirectoryPath)) - { - Trace.TraceError($"Model directory '{fullDirectoryPath}' does not exist"); - continue; - } - - if (_availableModels.ContainsKey(fullDirectoryPath)) - { - Trace.TraceWarning($"Model directory '{fullDirectoryPath}' already probed"); - continue; - } - - // find models in current dir that are of expected type - List directoryModelFiles = []; - foreach (var file in Directory.EnumerateFiles(fullDirectoryPath)) - { - if (!ExpectedModelFileTypes.Contains(Path.GetExtension(file))) - { - continue; - } - - // expected model file - var fi = new FileInfo(file); - directoryModelFiles.Add(new ModelFileMetadata - { - FileName = fi.Name, - FilePath = fi.FullName, - ModelType = ModelFileType.GGUF, - SizeInBytes = fi.Length, - }); - } - - _availableModels.Add(fullDirectoryPath, directoryModelFiles); - } - } - - /// - public void AddDirectory(string directory) - { - GetModelsFromDirectories(directory); - } - - /// - public bool RemoveDirectory(string directory) - { - return _availableModels.Remove(Path.GetFullPath(directory)); - } - - /// - public void RemoveAllDirectories() - { - _availableModels.Clear(); - } - - /// - public IEnumerable GetAvailableModelsFromDirectory(string directory) - { - var dirPath = Path.GetFullPath(directory); - return _availableModels.TryGetValue(dirPath, out var dirModels) - ? dirModels - : []; - } - #endregion Directories - - /// - public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta) - { - var filePath = Path.GetFullPath(fileName); - modelMeta = ModelFileList.FirstOrDefault(f => f.FilePath == filePath)!; - return modelMeta != null; - } + public int ModelsCached() + => _loadedModelCache.Count; /// public bool TryGetLoadedModel(string modelId, out LLamaWeights model) @@ -135,7 +33,7 @@ public bool TryGetLoadedModel(string modelId, out LLamaWeights model) } /// - public async Task LoadModel(string modelPath, + public async Task LoadModelAsync(ModelFileMetadata metadata, Action? modelConfigurator = null!, string modelId = "", CancellationToken cancellationToken = default) @@ -148,7 +46,7 @@ public async Task LoadModel(string modelPath, } // Configure model params - var modelParams = new ModelParams(modelPath); + var modelParams = new ModelParams(metadata.ModelFileUri); modelConfigurator?.Invoke(modelParams); // load and cache @@ -205,7 +103,6 @@ public void UnloadAllModels() } _loadedModelCache.Clear(); } - #endregion #region Dispose diff --git a/LLama/Model/ModelFileMetadata.cs b/LLama/Model/ModelFileMetadata.cs index 4b674d887..3b2c18145 100644 --- a/LLama/Model/ModelFileMetadata.cs +++ b/LLama/Model/ModelFileMetadata.cs @@ -1,11 +1,11 @@ namespace LLama.Model; +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member /// /// Types of supported model files /// public enum ModelFileType { -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or membe GGUF } @@ -14,10 +14,10 @@ public enum ModelFileType /// public class ModelFileMetadata { - public string FileName { get; init; } = string.Empty; - public string FilePath { get; init; } = string.Empty; + public string ModelFileName { get; init; } = string.Empty; + public string ModelFileUri { get; init; } = string.Empty; public ModelFileType ModelType { get; init; } - public long SizeInBytes { get; init; } = 0; -#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member + public long ModelFileSizeInBytes { get; init; } = 0; } +#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member From 262a8ce17f6038aaa2d360558688eeef5d2bd5fc Mon Sep 17 00:00:00 2001 From: Patrick Hovsepian Date: Mon, 17 Jun 2024 09:23:56 -0700 Subject: [PATCH 07/11] Dispose and Clone Semantics (#7) * organization * disposable and ref counter * separate concerns a bit more * check * tweak * stash * note --------- Co-authored-by: pat_hov --- LLama.Unittest/Model/ModelCacheTests.cs | 52 +++++++++++++++----- LLama/LLamaWeights.cs | 63 ++++++++++++++----------- LLama/Model/HuggingFaceModelRepo.cs | 2 +- LLama/Model/IModelCache.cs | 2 +- LLama/Model/ModelCache.cs | 54 ++++++++++++--------- 5 files changed, 109 insertions(+), 64 deletions(-) diff --git a/LLama.Unittest/Model/ModelCacheTests.cs b/LLama.Unittest/Model/ModelCacheTests.cs index 62d8e17ef..80dd18deb 100644 --- a/LLama.Unittest/Model/ModelCacheTests.cs +++ b/LLama.Unittest/Model/ModelCacheTests.cs @@ -14,26 +14,52 @@ public ModelManagerTests() TestableModelManager = new(); } + [Fact] + public async void LoadModel_DisposesOnUnload() + { + var modelToLoad = _testRepo.GetAvailableModels() + .First(f => f.ModelFileName.Contains("llama-2-7b")); + + var model = await TestableModelManager.LoadModelAsync(modelToLoad); + Assert.NotNull(model); + + // unloaded and disposed` + Assert.True(TestableModelManager.UnloadModel(model.ModelName)); + Assert.Throws(() => + { + _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri)); + }); + + // wont unload and already + Assert.False(TestableModelManager.UnloadModel(model.ModelName)); + Assert.Throws(() => + { + _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri)); + }); + } + [Fact] public async void LoadModel_LoadsAndCaches() { var modelToLoad = _testRepo.GetAvailableModels() .First(f => f.ModelFileName.Contains("llama-2-7b")); + // Create Model -- Ref 1 var model = await TestableModelManager.LoadModelAsync(modelToLoad); - var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); - Assert.True(isLoaded); + Assert.NotNull(model); + + // clone it -- Ref 2 + var isCachedAndCloned = TestableModelManager.TryCloneLoadedModel(model.ModelName, out var cachedModel); + Assert.True(isCachedAndCloned); + Assert.NotNull(cachedModel); - // unload the newly acquired model even though it was cached + cachedModel.Dispose(); //-- ref 1 Assert.True(TestableModelManager.UnloadModel(model.ModelName)); - //cachedModel.Dispose(); // this does effectively nothing - // unload "original" - model.Dispose(); // need to explicitly dispose the model that the caller (us) owns + // unloaded and disposed` -- ref 2 Assert.True(TestableModelManager.UnloadModel(model.ModelName)); Assert.False(TestableModelManager.UnloadModel(model.ModelName)); - Assert.Throws(() => { _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri)); @@ -51,7 +77,7 @@ public async void LoadModel_AlreadyLoaded_ReturnsFromCache() var model = await TestableModelManager.LoadModelAsync(modelToLoad); Assert.NotNull(model); Assert.Equal("LLaMA v2", model.ModelName); - var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); + var isLoaded = TestableModelManager.TryCloneLoadedModel(model.ModelName, out var cachedModel); Assert.True(isLoaded); Assert.NotNull(cachedModel); Assert.Equal("LLaMA v2", cachedModel.ModelName); @@ -67,20 +93,20 @@ public async void TryGetLoadedModel_AlreadyDisposed_ReturnsFalse() using (var model = await TestableModelManager.LoadModelAsync(modelToLoad)) { Assert.NotNull(model); - Assert.Equal("LLaMA v2", model.ModelName); - var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel); + Assert.Equal(model.ModelName, model.ModelName); + var isLoaded = TestableModelManager.TryCloneLoadedModel(model.ModelName, out var cachedModel); Assert.True(isLoaded); Assert.NotNull(cachedModel); - Assert.Equal("LLaMA v2", cachedModel.ModelName); + Assert.Equal(model.ModelName, cachedModel.ModelName); // unload from the last check - Assert.True(TestableModelManager.UnloadModel("LLaMA v2")); + Assert.True(TestableModelManager.UnloadModel(model.ModelName)); } // end scope, dispose is called on the model but since we have the model cache it should stick around until unloaded Assert.True(TestableModelManager.UnloadModel("LLaMA v2")); // Model is still loaded due to cache - var isDisposedLoaded = TestableModelManager.TryGetLoadedModel("LLaMA v2", out var disposedModel); + var isDisposedLoaded = TestableModelManager.TryCloneLoadedModel("LLaMA v2", out var disposedModel); Assert.False(isDisposedLoaded); Assert.Null(disposedModel); } diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 10d1a92e5..13cd31d34 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -18,12 +19,6 @@ public sealed class LLamaWeights { private bool _disposed = false; - /// - ~LLamaWeights() - { - Dispose(false); - } - /// /// The native handle, which is used in the native APIs /// @@ -86,15 +81,43 @@ private LLamaWeights(SafeLlamaModelHandle handle) NativeHandle.DangerousAddRef(ref success); } - #region Load /// - /// Create from a "shared" handle. The `SafeLlamaModelHandle` will not be disposed and the model will not be unloaded until all such handles have been disposed. + /// Create an instance of the model using the supplied handle and metadata. + /// Metadata will not be re-read from the handle. /// /// + /// + private LLamaWeights(SafeLlamaModelHandle handle, IReadOnlyDictionary metadata) + { + NativeHandle = handle; + Metadata = metadata; + + // Increment the model reference count while this weight exists. + // DangerousAddRef throws if it fails, so there is no need to check "success" + var success = false; + NativeHandle.DangerousAddRef(ref success); + } + + /// + ~LLamaWeights() + { + // Ensure the handle is released even if user's don't explicitly call Dispose + Dispose(); + } + + #region Load + /// + /// Create a new instance of the model using same NativeHandle as this model. + /// Metadata is also copied from the existing model rather than read from the hanlde directly + /// The `SafeLlamaModelHandle` will not be disposed and the model will not be unloaded until ALL such handles have been disposed. + /// /// - public static LLamaWeights FromSafeModelHandle(SafeLlamaModelHandle handle) + public LLamaWeights CloneFromHandleWithMetadata() { - return new LLamaWeights(handle); + var metadataClone = Metadata + .Select(x => x) + .ToDictionary(x => x.Key, x => x.Value); + return new LLamaWeights(NativeHandle, metadataClone); } /// @@ -218,28 +241,14 @@ public static async Task LoadFromFileAsync(IModelParams @params, C /// public void Dispose() { - Dispose(true); - GC.SuppressFinalize(this); - } - - /// - /// Unload all models when called explicitly via dispose - /// - /// Whether or not this call is made explicitly(true) or via GC - internal void Dispose(bool disposing) - { - if (_disposed) - { - return; - } - - if (disposing) + if (!_disposed) { NativeHandle.DangerousRelease(); NativeHandle.Dispose(); + _disposed = true; } - _disposed = true; + GC.SuppressFinalize(this); } /// diff --git a/LLama/Model/HuggingFaceModelRepo.cs b/LLama/Model/HuggingFaceModelRepo.cs index 23a99408b..072e87e66 100644 --- a/LLama/Model/HuggingFaceModelRepo.cs +++ b/LLama/Model/HuggingFaceModelRepo.cs @@ -24,7 +24,7 @@ public void AddSource(string source) } // TODO: call HF to check model exists - // TODO: Get metadta about model an + // TODO: Get metadata about model _hfModelUri.Add(source); } diff --git a/LLama/Model/IModelCache.cs b/LLama/Model/IModelCache.cs index abf6eacb3..df9e518fd 100644 --- a/LLama/Model/IModelCache.cs +++ b/LLama/Model/IModelCache.cs @@ -23,7 +23,7 @@ public interface IModelCache : IDisposable /// /// /// - /// An alias to uniquely identify this model's underyling handle. If none is supplied, the model's name is used.' + /// An alias to uniquely identify this model's underlying handle. If none is supplied, the model's name is used.' /// /// The loaded model on success public Task LoadModelAsync(ModelFileMetadata metadata, diff --git a/LLama/Model/ModelCache.cs b/LLama/Model/ModelCache.cs index 444cc00a8..1d7eb9435 100644 --- a/LLama/Model/ModelCache.cs +++ b/LLama/Model/ModelCache.cs @@ -10,25 +10,35 @@ namespace LLama.Model; +internal class CachedModelReference +{ + public LLamaWeights Model { get; init; } = null!; + public int RefCount { get; set; } = 0; +} + /// public class ModelCache : IModelCache { private bool _disposed = false; // model id/alias, to loaded model - private readonly Dictionary _loadedModelCache = []; + private readonly Dictionary _loadedModelCache = []; /// - public int ModelsCached() + public int ModelsCached() => _loadedModelCache.Count; /// - public bool TryGetLoadedModel(string modelId, out LLamaWeights model) + public bool TryCloneLoadedModel(string modelId, out LLamaWeights model) { - var isCached = _loadedModelCache.TryGetValue(modelId, out var handle); - model = isCached - ? LLamaWeights.FromSafeModelHandle(handle) - : null!; + var isCached = _loadedModelCache.TryGetValue(modelId, out var cachedModel); + + model = null!; + if (isCached) + { + model = cachedModel.Model.CloneFromHandleWithMetadata(); + cachedModel.RefCount++; + } return isCached; } @@ -40,7 +50,7 @@ public async Task LoadModelAsync(ModelFileMetadata metadata, { // is the model already loaded? alias could be different but it's up to the caller to be consistent if (!string.IsNullOrEmpty(modelId) - && TryGetLoadedModel(modelId, out var loadedModel)) + && TryCloneLoadedModel(modelId, out var loadedModel)) { return loadedModel; } @@ -58,20 +68,18 @@ public async Task LoadModelAsync(ModelFileMetadata metadata, { modelId = model.ModelName; - if (TryGetLoadedModel(modelId, out loadedModel)) + if (TryCloneLoadedModel(modelId, out loadedModel)) { model.Dispose(); return loadedModel; } } - // Increment the model reference count while this model exists (newly created) - // DangerousAddRef throws if it fails, so there is no need to check "success" - // Do this here since we're passing this to the caller to own and it's not done as part of the normal weight creation - var refSuccess = false; - model.NativeHandle.DangerousAddRef(ref refSuccess); - - _loadedModelCache.Add(modelId, model.NativeHandle); + _loadedModelCache.Add(modelId, new CachedModelReference + { + Model = model, + RefCount = 1 + }); return model; } @@ -79,12 +87,12 @@ public async Task LoadModelAsync(ModelFileMetadata metadata, /// public bool UnloadModel(string modelId) { - if (_loadedModelCache.TryGetValue(modelId, out var handle)) + if (_loadedModelCache.TryGetValue(modelId, out var cachedModel)) { // Decrement refcount on model - handle.DangerousRelease(); - handle.Dispose(); - if (handle.IsClosed || handle.IsInvalid) + cachedModel.Model.Dispose(); // this only disposes the original model... + cachedModel.RefCount--; + if (cachedModel.RefCount == 0) { return _loadedModelCache.Remove(modelId); } @@ -98,8 +106,10 @@ public void UnloadAllModels() { foreach (var handle in _loadedModelCache.Values) { - handle.DangerousRelease(); - handle.Dispose(); + for (var i = 0; i < handle.RefCount; i++) + { + handle.Model.Dispose(); + } } _loadedModelCache.Clear(); } From 76412ee7c30fde3b1147adba67dab78da59070fa Mon Sep 17 00:00:00 2001 From: Pat Hov Date: Mon, 17 Jun 2024 09:37:28 -0700 Subject: [PATCH 08/11] typpo --- LLama/LLamaWeights.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 13cd31d34..b5fcb8e2a 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -108,7 +108,7 @@ private LLamaWeights(SafeLlamaModelHandle handle, IReadOnlyDictionary /// Create a new instance of the model using same NativeHandle as this model. - /// Metadata is also copied from the existing model rather than read from the hanlde directly + /// Metadata is also copied from the existing model rather than read from the handle directly /// The `SafeLlamaModelHandle` will not be disposed and the model will not be unloaded until ALL such handles have been disposed. /// /// From a428a7d91bfa73380280043ee9438b58799f89fe Mon Sep 17 00:00:00 2001 From: Patrick Hovsepian Date: Mon, 17 Jun 2024 14:27:50 -0700 Subject: [PATCH 09/11] Better Cache Semantics (#8) cache semantics --- LLama.Unittest/Model/ModelCacheTests.cs | 114 +++++++++++++----------- LLama/LLamaWeights.cs | 7 +- LLama/Model/IModelCache.cs | 29 ++++-- LLama/Model/ModelCache.cs | 100 +++++++++++---------- 4 files changed, 141 insertions(+), 109 deletions(-) diff --git a/LLama.Unittest/Model/ModelCacheTests.cs b/LLama.Unittest/Model/ModelCacheTests.cs index 80dd18deb..af9275d9b 100644 --- a/LLama.Unittest/Model/ModelCacheTests.cs +++ b/LLama.Unittest/Model/ModelCacheTests.cs @@ -6,7 +6,7 @@ namespace LLama.Unittest.Model; public class ModelManagerTests { private readonly IModelSourceRepo _testRepo = new FileSystemModelRepo([Constants.ModelDirectory]); - + private readonly ModelCache TestableModelManager; public ModelManagerTests() @@ -17,97 +17,109 @@ public ModelManagerTests() [Fact] public async void LoadModel_DisposesOnUnload() { + const string modelId = "llama-2-7b"; var modelToLoad = _testRepo.GetAvailableModels() - .First(f => f.ModelFileName.Contains("llama-2-7b")); + .First(f => f.ModelFileName.Contains(modelId)); - var model = await TestableModelManager.LoadModelAsync(modelToLoad); + // Load success + var model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId); Assert.NotNull(model); + Assert.Equal(1, TestableModelManager.ModelsCached()); - // unloaded and disposed` - Assert.True(TestableModelManager.UnloadModel(model.ModelName)); + // Load with same Id throws + await Assert.ThrowsAsync(async () => + { + await TestableModelManager.LoadModelAsync(modelToLoad, modelId); + return; + }); + Assert.Equal(1, TestableModelManager.ModelsCached()); + + // unloaded and disposed + Assert.True(TestableModelManager.UnloadModel(modelId)); Assert.Throws(() => { _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri)); }); + Assert.Equal(0, TestableModelManager.ModelsCached()); - // wont unload and already - Assert.False(TestableModelManager.UnloadModel(model.ModelName)); + // already unloaded and disposed + Assert.False(TestableModelManager.UnloadModel(modelId)); Assert.Throws(() => { _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri)); }); + + // Can be reloaded after unload + model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId); + Assert.NotNull(model); + Assert.Equal(1, TestableModelManager.ModelsCached()); + Assert.True(TestableModelManager.UnloadModel(modelId)); + Assert.Equal(0, TestableModelManager.ModelsCached()); } [Fact] - public async void LoadModel_LoadsAndCaches() + public async void TryCloneLoadedModel_ClonesAndCaches() { + const string modelId = "llama-2-7b"; var modelToLoad = _testRepo.GetAvailableModels() - .First(f => f.ModelFileName.Contains("llama-2-7b")); + .First(f => f.ModelFileName.Contains(modelId)); - // Create Model -- Ref 1 - var model = await TestableModelManager.LoadModelAsync(modelToLoad); + var model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId); Assert.NotNull(model); + Assert.Equal(1, TestableModelManager.ModelsCached()); // clone it -- Ref 2 - var isCachedAndCloned = TestableModelManager.TryCloneLoadedModel(model.ModelName, out var cachedModel); + const string cloneId = nameof(cloneId); + var isCachedAndCloned = TestableModelManager.TryCloneLoadedModel(modelId, cloneId, out var cachedModel); Assert.True(isCachedAndCloned); Assert.NotNull(cachedModel); + Assert.Equal(2, TestableModelManager.ModelsCached()); cachedModel.Dispose(); //-- ref 1 - Assert.True(TestableModelManager.UnloadModel(model.ModelName)); + Assert.True(TestableModelManager.UnloadModel(modelId)); + Assert.Equal(1, TestableModelManager.ModelsCached()); // unloaded and disposed` -- ref 2 - Assert.True(TestableModelManager.UnloadModel(model.ModelName)); + Assert.True(TestableModelManager.UnloadModel(cloneId)); + Assert.Equal(0, TestableModelManager.ModelsCached()); - Assert.False(TestableModelManager.UnloadModel(model.ModelName)); + Assert.False(TestableModelManager.UnloadModel(modelId)); + Assert.False(TestableModelManager.UnloadModel(cloneId)); Assert.Throws(() => { _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri)); }); + Assert.Throws(() => + { + _ = cachedModel.CreateContext(new ModelParams(modelToLoad.ModelFileUri)); + }); } [Fact] - public async void LoadModel_AlreadyLoaded_ReturnsFromCache() + public async void TryCloneLoadedModel_SameId_Throws() { + const string modelId = "llama-2-7b"; var modelToLoad = _testRepo.GetAvailableModels() - .First(f => f.ModelFileName.Contains("llama-2-7b")); + .First(f => f.ModelFileName.Contains(modelId)); - for (var i = 0; i < 5; i++) - { - var model = await TestableModelManager.LoadModelAsync(modelToLoad); - Assert.NotNull(model); - Assert.Equal("LLaMA v2", model.ModelName); - var isLoaded = TestableModelManager.TryCloneLoadedModel(model.ModelName, out var cachedModel); - Assert.True(isLoaded); - Assert.NotNull(cachedModel); - Assert.Equal("LLaMA v2", cachedModel.ModelName); - } - } + var model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId); + Assert.NotNull(model); + Assert.Equal(1, TestableModelManager.ModelsCached()); - [Fact] - public async void TryGetLoadedModel_AlreadyDisposed_ReturnsFalse() - { - var modelToLoad = _testRepo.GetAvailableModels() - .First(f => f.ModelFileName.Contains("llama-2-7b")); + // Same Id clone fails + Assert.Throws(() => + { + TestableModelManager.TryCloneLoadedModel(modelId, modelId, out var cachedModel); + }); + Assert.Equal(1, TestableModelManager.ModelsCached()); - using (var model = await TestableModelManager.LoadModelAsync(modelToLoad)) + // Unload and dispose + Assert.True(TestableModelManager.UnloadModel(modelId)); + Assert.Equal(0, TestableModelManager.ModelsCached()); + Assert.False(TestableModelManager.UnloadModel(modelId)); + Assert.Throws(() => { - Assert.NotNull(model); - Assert.Equal(model.ModelName, model.ModelName); - var isLoaded = TestableModelManager.TryCloneLoadedModel(model.ModelName, out var cachedModel); - Assert.True(isLoaded); - Assert.NotNull(cachedModel); - Assert.Equal(model.ModelName, cachedModel.ModelName); - - // unload from the last check - Assert.True(TestableModelManager.UnloadModel(model.ModelName)); - - } // end scope, dispose is called on the model but since we have the model cache it should stick around until unloaded - Assert.True(TestableModelManager.UnloadModel("LLaMA v2")); - - // Model is still loaded due to cache - var isDisposedLoaded = TestableModelManager.TryCloneLoadedModel("LLaMA v2", out var disposedModel); - Assert.False(isDisposedLoaded); - Assert.Null(disposedModel); + _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri)); + }); } } diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index b5fcb8e2a..b131732dc 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -108,16 +108,13 @@ private LLamaWeights(SafeLlamaModelHandle handle, IReadOnlyDictionary /// Create a new instance of the model using same NativeHandle as this model. - /// Metadata is also copied from the existing model rather than read from the handle directly + /// Metadata is also copied from the existing model rather than read from the hanlde directly /// The `SafeLlamaModelHandle` will not be disposed and the model will not be unloaded until ALL such handles have been disposed. /// /// public LLamaWeights CloneFromHandleWithMetadata() { - var metadataClone = Metadata - .Select(x => x) - .ToDictionary(x => x.Key, x => x.Value); - return new LLamaWeights(NativeHandle, metadataClone); + return new LLamaWeights(NativeHandle, Metadata); } /// diff --git a/LLama/Model/IModelCache.cs b/LLama/Model/IModelCache.cs index df9e518fd..078586a7e 100644 --- a/LLama/Model/IModelCache.cs +++ b/LLama/Model/IModelCache.cs @@ -19,18 +19,35 @@ public interface IModelCache : IDisposable /// /// Load a model file to be used for inference. - /// The caller assumes responsibility for disposing this model and MUST call Unload + /// The caller assumes responsibility for disposing this model and MUST call UnloadModel /// - /// - /// - /// An alias to uniquely identify this model's underlying handle. If none is supplied, the model's name is used.' + /// The metadata about the model file to be loaded + /// A required alias to uniquely identify this model' + /// An optional function to further configure the model parameters beyond default /// - /// The loaded model on success + /// An instance of the newly loaded model. This MUST be disposed or Unload public Task LoadModelAsync(ModelFileMetadata metadata, + string modelId, Action? modelConfigurator = null!, - string modelId = "", CancellationToken cancellationToken = default); + /// + /// Attempt to get a reference to a model that's already loaded + /// + /// Identifier of the loaded model + /// Will be populated with the reference if the model is cached + /// A SHARED instance to a model that's already loaded. Disposing or Unloading this model will affect all references + public bool TryGetLoadedModel(string modelId, out LLamaWeights cachedModel); + + /// + /// Attempt to clone and cache a new unique model instance + /// + /// Model that's expected to be loaded and cloned + /// Identifier for the newly cloned model + /// If cloning is successful, this model will be available for use + /// True if cloning is successful + public bool TryCloneLoadedModel(string loadedModelId, string cloneId, out LLamaWeights model); + /// /// Unload and dispose of a model with the given id /// diff --git a/LLama/Model/ModelCache.cs b/LLama/Model/ModelCache.cs index 1d7eb9435..8d683326b 100644 --- a/LLama/Model/ModelCache.cs +++ b/LLama/Model/ModelCache.cs @@ -22,65 +22,80 @@ public class ModelCache : IModelCache private bool _disposed = false; // model id/alias, to loaded model - private readonly Dictionary _loadedModelCache = []; + private readonly Dictionary _loadedModelCache = []; /// public int ModelsCached() => _loadedModelCache.Count; /// - public bool TryCloneLoadedModel(string modelId, out LLamaWeights model) + public bool TryCloneLoadedModel(string loadedModelId, + string cloneId, + out LLamaWeights model) { - var isCached = _loadedModelCache.TryGetValue(modelId, out var cachedModel); + var isCached = _loadedModelCache.TryGetValue(loadedModelId, out var cachedModel); model = null!; if (isCached) { - model = cachedModel.Model.CloneFromHandleWithMetadata(); - cachedModel.RefCount++; + TryAddModel(cloneId, cachedModel.CloneFromHandleWithMetadata); + model = _loadedModelCache[loadedModelId]; + return true; } - return isCached; + return false; } /// - public async Task LoadModelAsync(ModelFileMetadata metadata, - Action? modelConfigurator = null!, - string modelId = "", - CancellationToken cancellationToken = default) + public bool TryGetLoadedModel(string modelId, out LLamaWeights cachedModel) { - // is the model already loaded? alias could be different but it's up to the caller to be consistent - if (!string.IsNullOrEmpty(modelId) - && TryCloneLoadedModel(modelId, out var loadedModel)) + return _loadedModelCache.TryGetValue(modelId, out cachedModel); + } + + private void TryAddModel(string modelId, Func modelCreator) + { + if (IsModelIdInvalid(modelId)) { - return loadedModel; + throw new ArgumentException("Model identifier is not unique"); } - // Configure model params - var modelParams = new ModelParams(metadata.ModelFileUri); - modelConfigurator?.Invoke(modelParams); - - // load and cache - var model = await LLamaWeights.LoadFromFileAsync(modelParams, cancellationToken); + _loadedModelCache.Add(modelId, modelCreator()); + } - // Check if it's already cached, if so use that and dispose of this - // TODO: Consider the case where the alias is different but the underlying model file is the same - if (string.IsNullOrWhiteSpace(modelId)) + private async Task TryAddModelAsync(string modelId, Func> modelCreator) + { + if (IsModelIdInvalid(modelId)) { - modelId = model.ModelName; - - if (TryCloneLoadedModel(modelId, out loadedModel)) - { - model.Dispose(); - return loadedModel; - } + throw new ArgumentException("Model identifier is not unique"); } - _loadedModelCache.Add(modelId, new CachedModelReference + _loadedModelCache.Add(modelId, await modelCreator()); + } + + private bool IsModelIdInvalid(string modelId) => + string.IsNullOrWhiteSpace(modelId) || _loadedModelCache.ContainsKey(modelId); + + /// + public async Task LoadModelAsync(ModelFileMetadata metadata, + string modelId, + Action? modelConfigurator = null!, + CancellationToken cancellationToken = default) + { + await TryAddModelAsync(modelId, async () => { - Model = model, - RefCount = 1 + return await ModelCreator(metadata.ModelFileUri, modelConfigurator, cancellationToken); }); - return model; + return _loadedModelCache[modelId]; + + // Helper to create the model + static async Task ModelCreator(string fileUri, + Action? modelConfigurator, + CancellationToken cancellationToken) + { + var modelParams = new ModelParams(fileUri); + modelConfigurator?.Invoke(modelParams); + + return await LLamaWeights.LoadFromFileAsync(modelParams, cancellationToken); + } } #region Unload @@ -89,14 +104,8 @@ public bool UnloadModel(string modelId) { if (_loadedModelCache.TryGetValue(modelId, out var cachedModel)) { - // Decrement refcount on model - cachedModel.Model.Dispose(); // this only disposes the original model... - cachedModel.RefCount--; - if (cachedModel.RefCount == 0) - { - return _loadedModelCache.Remove(modelId); - } - return true; + cachedModel.Dispose(); + return _loadedModelCache.Remove(modelId); } return false; } @@ -104,12 +113,9 @@ public bool UnloadModel(string modelId) /// public void UnloadAllModels() { - foreach (var handle in _loadedModelCache.Values) + foreach (var model in _loadedModelCache.Values) { - for (var i = 0; i < handle.RefCount; i++) - { - handle.Model.Dispose(); - } + model.Dispose(); } _loadedModelCache.Clear(); } From 0f352f99611aa80cad8b9ff95e5606dc9609837e Mon Sep 17 00:00:00 2001 From: pat_hov Date: Mon, 17 Jun 2024 14:29:53 -0700 Subject: [PATCH 10/11] typo fix --- LLama/LLamaWeights.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index b131732dc..d18c4d59e 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -108,7 +108,7 @@ private LLamaWeights(SafeLlamaModelHandle handle, IReadOnlyDictionary /// Create a new instance of the model using same NativeHandle as this model. - /// Metadata is also copied from the existing model rather than read from the hanlde directly + /// Metadata is also copied from the existing model rather than read from the handle directly /// The `SafeLlamaModelHandle` will not be disposed and the model will not be unloaded until ALL such handles have been disposed. /// /// From 396d3a32b390d50531a973902670816764b03ece Mon Sep 17 00:00:00 2001 From: Patrick Hovsepian Date: Wed, 19 Jun 2024 22:11:21 -0700 Subject: [PATCH 11/11] Remove comment, cleanup (#9) cleanup --- LLama/LLamaWeights.cs | 1 - LLama/Model/IModelCache.cs | 1 - LLama/Model/ModelCache.cs | 4 ---- LLama/Native/SafeLlamaModelHandle.cs | 3 +-- 4 files changed, 1 insertion(+), 8 deletions(-) diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index d18c4d59e..aeaf3f5f1 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; diff --git a/LLama/Model/IModelCache.cs b/LLama/Model/IModelCache.cs index 078586a7e..093272227 100644 --- a/LLama/Model/IModelCache.cs +++ b/LLama/Model/IModelCache.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using LLama.Common; diff --git a/LLama/Model/ModelCache.cs b/LLama/Model/ModelCache.cs index 8d683326b..933156a35 100644 --- a/LLama/Model/ModelCache.cs +++ b/LLama/Model/ModelCache.cs @@ -1,12 +1,8 @@ using System; using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; using System.Threading; using System.Threading.Tasks; using LLama.Common; -using LLama.Native; namespace LLama.Model; diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 5e66729da..c25f0b4f9 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -109,8 +109,7 @@ protected override bool ReleaseHandle() llama_free_model(handle); return true; } - - // TODO: Move this to the model manager? + /// /// Load a model from the given file path into memory ///