From efe4725e569384d6080ade88a4eec689a75b8202 Mon Sep 17 00:00:00 2001 From: sir_Andros <andros.osipkov@gmail.com> Date: Mon, 18 Feb 2019 20:05:49 +0100 Subject: [PATCH] #12 Handle secure desktop. --- KeePassWinHelloExt.cs | 2 +- KeyManagement/KeyManager.cs | 73 ++++++++++++++++++++++++++++++------- KeyManagement/KeyStorage.cs | 7 ++++ 3 files changed, 68 insertions(+), 14 deletions(-) diff --git a/KeePassWinHelloExt.cs b/KeePassWinHelloExt.cs index 2fdc8f4..57419a8 100644 --- a/KeePassWinHelloExt.cs +++ b/KeePassWinHelloExt.cs @@ -53,7 +53,7 @@ private void OnWindowAdded(object sender, GwmWindowEventArgs e) var keyPromptForm = e.Form as KeyPromptForm; if (keyPromptForm != null) { - _keyManager.OnKeyPrompt(keyPromptForm); + _keyManager.OnKeyPrompt(keyPromptForm, _host.MainWindow); return; } diff --git a/KeyManagement/KeyManager.cs b/KeyManagement/KeyManager.cs index cad6ca5..b761823 100644 --- a/KeyManagement/KeyManager.cs +++ b/KeyManagement/KeyManager.cs @@ -2,6 +2,8 @@ using System.Diagnostics; using System.Linq; using System.Reflection; +using System.Threading; +using System.Threading.Tasks; using System.Windows.Forms; using KeePass.Forms; using KeePassLib.Keys; @@ -13,6 +15,7 @@ class KeyManager { private readonly KeyCipher _keyCipher; private readonly KeyStorage _keyStorage; + private bool _isSecureDesktopSettingChanged = false; public KeyManager(IntPtr windowHandle) { @@ -20,24 +23,46 @@ public KeyManager(IntPtr windowHandle) _keyCipher = new KeyCipher(Settings.ConfirmationMessage, windowHandle); } - public void OnKeyPrompt(KeyPromptForm keyPromptForm) + public void OnKeyPrompt(KeyPromptForm keyPromptForm, MainForm mainWindow) { - if (keyPromptForm.SecureDesktopMode) - return; - if (!Settings.Instance.Enabled) return; - CompositeKey compositeKey; - if (ExtractCompositeKey(GetDbPath(keyPromptForm), out compositeKey)) + string dbPath = GetDbPath(keyPromptForm); + if (keyPromptForm.SecureDesktopMode) { - SetCompositeKey(keyPromptForm, compositeKey); - // Remove flushing - keyPromptForm.Visible = false; - keyPromptForm.Opacity = 0; - - keyPromptForm.DialogResult = DialogResult.OK; - keyPromptForm.Close(); + if (IsKeyForDataBaseExist(dbPath)) + { + var dbFile = GetIoInfo(keyPromptForm); + CloseFormWithResult(keyPromptForm, DialogResult.Cancel); + Task.Factory.StartNew(() => + { + KeePass.Program.Config.Security.MasterKeyOnSecureDesktop = false; + _isSecureDesktopSettingChanged = true; + Thread.Yield(); + ReOpenKeyPromptForm(mainWindow, dbFile); + }) + .ContinueWith(_ => + { + KeePass.Program.Config.Security.MasterKeyOnSecureDesktop = true; + _isSecureDesktopSettingChanged = false; + }); + } + } + else + { + CompositeKey compositeKey; + if (ExtractCompositeKey(dbPath, out compositeKey)) + { + SetCompositeKey(keyPromptForm, compositeKey); + CloseFormWithResult(keyPromptForm, DialogResult.OK); + } + else if (_isSecureDesktopSettingChanged) + { + var dbFile = GetIoInfo(keyPromptForm); + CloseFormWithResult(keyPromptForm, DialogResult.Cancel); + Task.Factory.StartNew(() => ReOpenKeyPromptForm(mainWindow, dbFile)); + } } } @@ -68,6 +93,28 @@ public void OnDBClosing(object sender, FileClosingEventArgs e) } } + private static void CloseFormWithResult(KeyPromptForm keyPromptForm, DialogResult result) + { + // Remove flushing + keyPromptForm.Visible = false; + keyPromptForm.Opacity = 0; + + keyPromptForm.DialogResult = result; + keyPromptForm.Close(); + } + + private static void ReOpenKeyPromptForm(MainForm mainWindow, IOConnectionInfo dbFile) + { + Action action = () => mainWindow.OpenDatabase(dbFile, null, false); + mainWindow.Invoke(action); + } + + private bool IsKeyForDataBaseExist(string dbPath) + { + return !String.IsNullOrEmpty(dbPath) + && _keyStorage.ContainsKey(dbPath); + } + private bool ExtractCompositeKey(string dbPath, out CompositeKey compositeKey) { compositeKey = null; diff --git a/KeyManagement/KeyStorage.cs b/KeyManagement/KeyStorage.cs index a98a0c4..f4a2534 100644 --- a/KeyManagement/KeyStorage.cs +++ b/KeyManagement/KeyStorage.cs @@ -40,6 +40,13 @@ public void Remove(string dbPath) _keys.Remove(dbPath); } + public bool ContainsKey(string dbPath) + { + Data data; + return _keys.TryGetValue(dbPath, out data) + && data.IsValid(); + } + public bool TryGetValue(string dbPath, out ProtectedKey protectedKey) { Data data;