From efe4725e569384d6080ade88a4eec689a75b8202 Mon Sep 17 00:00:00 2001
From: sir_Andros <andros.osipkov@gmail.com>
Date: Mon, 18 Feb 2019 20:05:49 +0100
Subject: [PATCH] #12 Handle secure desktop.

---
 KeePassWinHelloExt.cs       |  2 +-
 KeyManagement/KeyManager.cs | 73 ++++++++++++++++++++++++++++++-------
 KeyManagement/KeyStorage.cs |  7 ++++
 3 files changed, 68 insertions(+), 14 deletions(-)

diff --git a/KeePassWinHelloExt.cs b/KeePassWinHelloExt.cs
index 2fdc8f4..57419a8 100644
--- a/KeePassWinHelloExt.cs
+++ b/KeePassWinHelloExt.cs
@@ -53,7 +53,7 @@ private void OnWindowAdded(object sender, GwmWindowEventArgs e)
             var keyPromptForm = e.Form as KeyPromptForm;
             if (keyPromptForm != null)
             {
-                _keyManager.OnKeyPrompt(keyPromptForm);
+                _keyManager.OnKeyPrompt(keyPromptForm, _host.MainWindow);
                 return;
             }
 
diff --git a/KeyManagement/KeyManager.cs b/KeyManagement/KeyManager.cs
index cad6ca5..b761823 100644
--- a/KeyManagement/KeyManager.cs
+++ b/KeyManagement/KeyManager.cs
@@ -2,6 +2,8 @@
 using System.Diagnostics;
 using System.Linq;
 using System.Reflection;
+using System.Threading;
+using System.Threading.Tasks;
 using System.Windows.Forms;
 using KeePass.Forms;
 using KeePassLib.Keys;
@@ -13,6 +15,7 @@ class KeyManager
     {
         private readonly KeyCipher _keyCipher;
         private readonly KeyStorage _keyStorage;
+        private bool _isSecureDesktopSettingChanged = false;
 
         public KeyManager(IntPtr windowHandle)
         {
@@ -20,24 +23,46 @@ public KeyManager(IntPtr windowHandle)
             _keyCipher = new KeyCipher(Settings.ConfirmationMessage, windowHandle);
         }
 
-        public void OnKeyPrompt(KeyPromptForm keyPromptForm)
+        public void OnKeyPrompt(KeyPromptForm keyPromptForm, MainForm mainWindow)
         {
-            if (keyPromptForm.SecureDesktopMode)
-                return;
-
             if (!Settings.Instance.Enabled)
                 return;
 
-            CompositeKey compositeKey;
-            if (ExtractCompositeKey(GetDbPath(keyPromptForm), out compositeKey))
+            string dbPath = GetDbPath(keyPromptForm);
+            if (keyPromptForm.SecureDesktopMode)
             {
-                SetCompositeKey(keyPromptForm, compositeKey);
-                // Remove flushing
-                keyPromptForm.Visible = false;
-                keyPromptForm.Opacity = 0;
-
-                keyPromptForm.DialogResult = DialogResult.OK;
-                keyPromptForm.Close();
+                if (IsKeyForDataBaseExist(dbPath))
+                {
+                    var dbFile = GetIoInfo(keyPromptForm);
+                    CloseFormWithResult(keyPromptForm, DialogResult.Cancel);
+                    Task.Factory.StartNew(() =>
+                    {
+                        KeePass.Program.Config.Security.MasterKeyOnSecureDesktop = false;
+                        _isSecureDesktopSettingChanged = true;
+                        Thread.Yield();
+                        ReOpenKeyPromptForm(mainWindow, dbFile);
+                    })
+                    .ContinueWith(_ =>
+                    {
+                        KeePass.Program.Config.Security.MasterKeyOnSecureDesktop = true;
+                        _isSecureDesktopSettingChanged = false;
+                    });
+                }
+            }
+            else
+            {
+                CompositeKey compositeKey;
+                if (ExtractCompositeKey(dbPath, out compositeKey))
+                {
+                    SetCompositeKey(keyPromptForm, compositeKey);
+                    CloseFormWithResult(keyPromptForm, DialogResult.OK);
+                }
+                else if (_isSecureDesktopSettingChanged)
+                {
+                    var dbFile = GetIoInfo(keyPromptForm);
+                    CloseFormWithResult(keyPromptForm, DialogResult.Cancel);
+                    Task.Factory.StartNew(() => ReOpenKeyPromptForm(mainWindow, dbFile));
+                }
             }
         }
 
@@ -68,6 +93,28 @@ public void OnDBClosing(object sender, FileClosingEventArgs e)
             }
         }
 
+        private static void CloseFormWithResult(KeyPromptForm keyPromptForm, DialogResult result)
+        {
+            // Remove flushing
+            keyPromptForm.Visible = false;
+            keyPromptForm.Opacity = 0;
+
+            keyPromptForm.DialogResult = result;
+            keyPromptForm.Close();
+        }
+
+        private static void ReOpenKeyPromptForm(MainForm mainWindow, IOConnectionInfo dbFile)
+        {
+            Action action = () => mainWindow.OpenDatabase(dbFile, null, false);
+            mainWindow.Invoke(action);
+        }
+
+        private bool IsKeyForDataBaseExist(string dbPath)
+        {
+            return !String.IsNullOrEmpty(dbPath)
+                && _keyStorage.ContainsKey(dbPath);
+        }
+
         private bool ExtractCompositeKey(string dbPath, out CompositeKey compositeKey)
         {
             compositeKey = null;
diff --git a/KeyManagement/KeyStorage.cs b/KeyManagement/KeyStorage.cs
index a98a0c4..f4a2534 100644
--- a/KeyManagement/KeyStorage.cs
+++ b/KeyManagement/KeyStorage.cs
@@ -40,6 +40,13 @@ public void Remove(string dbPath)
             _keys.Remove(dbPath);
         }
 
+        public bool ContainsKey(string dbPath)
+        {
+            Data data;
+            return _keys.TryGetValue(dbPath, out data)
+                && data.IsValid();
+        }
+
         public bool TryGetValue(string dbPath, out ProtectedKey protectedKey)
         {
             Data data;