From b6bbb71773da6696fc521a0cf05ab1b2ebb1cbf6 Mon Sep 17 00:00:00 2001 From: Guerin Maxime Date: Fri, 14 Apr 2023 14:03:36 +0200 Subject: [PATCH] add stop button --- lib/lib.dart | 78 ++++++++++++++++++++++++++++++++------------------- lib/main.dart | 38 +++++++++++++++---------- 2 files changed, 72 insertions(+), 44 deletions(-) diff --git a/lib/lib.dart b/lib/lib.dart index bd9f7cea..cfcf46f6 100644 --- a/lib/lib.dart +++ b/lib/lib.dart @@ -214,7 +214,7 @@ class Lib { ReceivePort isolateReceivePort = ReceivePort(); SendPort isolateSendPort = isolateReceivePort.sendPort; mainSendPort.send(isolateSendPort); - + var completer = Completer(); try { isolateReceivePort.listen((message) async { if (message is MessageNewPrompt) { @@ -222,17 +222,20 @@ class Lib { } if (message is ParsingDemand) { // mainSendPort.send(ParsingResult(fileSaved.path)); - binaryIsolate( - parsingDemand: message, - stopToken: message.stopToken, - mainSendPort: mainSendPort, - ); + completer.complete(message); } if (message is MessageStopGeneration) { log("[isolate] Stopping generation"); stopGeneration = true; } }); + + var parsingDemand = await completer.future; + Future.sync(() => Lib().binaryIsolate( + parsingDemand: parsingDemand, + stopToken: parsingDemand.stopToken, + mainSendPort: mainSendPort, + ) as FutureOr); } catch (e) { mainSendPort.send("[isolate] ERROR : $e"); } @@ -252,6 +255,7 @@ class Lib { required void Function(String log) printLog, required String promptPassed, required void Function() done, + required void Function() canStop, required String stopToken, required ParamsLlamaValuesOnly paramsLlamaValuesOnly, }) async { @@ -296,6 +300,8 @@ class Lib { printLnLog(message.message); } else if (message is MessageCanPrompt) { done(); + } else if (message is MessageCanStop) { + canStop(); } else { print(message); } @@ -328,7 +334,7 @@ class Lib { static Completer interaction = Completer(); - static binaryIsolate({ + binaryIsolate({ required ParsingDemand parsingDemand, required SendPort mainSendPort, required String stopToken, @@ -480,6 +486,7 @@ class Lib { var inp_pfx = tokenize(llamaBinded, ctx, "\n\n### Instruction:\n\n", true); var inp_sfx = tokenize(llamaBinded, ctx, "\n\n### Response:\n\n", false); + var user_token = tokenize(llamaBinded, ctx, "\n$stopTokenTrimed", true); var llama_token_newline = tokenize(llamaBinded, ctx, "\n", false); var embd = Vector(nullptr, 0); @@ -496,6 +503,8 @@ class Lib { int remaining_tokens = gptParams.ref.n_predict; int n_past = 0; log('before while loop'); + mainSendPort.send(MessageCanStop()); + while ((remaining_tokens > 0 || gptParams.ref.interactive)) { log('remaining tokens : $remaining_tokens'); log('stopGeneration : $stopGeneration'); @@ -506,15 +515,17 @@ class Lib { log("error llama_eval"); return; } + await Future.delayed(Duration(milliseconds: 1)); } n_past += embd.length; embd.clear(); if (stopGeneration) { - log('stop Generation initiated by user'); interaction = Completer(); stopGeneration = false; + embd.insertVectorAtEnd(user_token); } - if (embd_inp.length <= input_consumed) { + if (embd_inp.length <= input_consumed && + interaction.isCompleted == true) { // out of user input, sample next token var top_k = gptParams.ref.top_k; var top_p = gptParams.ref.top_p; @@ -571,27 +582,32 @@ class Lib { log('input_noecho = $input_noecho embd.length = ${embd.length}'); if (!input_noecho) { for (int i = 0; i < embd.length; ++i) { - int id = embd.pointer[i]; - var str = llamaBinded - .llama_token_to_str(ctx, id) - .cast() - .toDartString(); - logInline(str); - ttlString += str; - if (ttlString.length >= stopTokenLength && - ttlString.length > prompt.length && - stopTokenLength > 0) { - var lastPartTtlString = ttlString - .trim() - .substring(ttlString.trim().length - stopTokenLength - 1) - .toLowerCase() - .replaceAll(' ', ''); - log('lastPartTtlString = $lastPartTtlString , stopTokenTrimed = $stopTokenTrimed'); - if (lastPartTtlString == stopTokenTrimed.toLowerCase()) { - log('is_interacting = true'); - interaction = Completer(); - break; + try { + int id = embd.pointer[i]; + var str = llamaBinded + .llama_token_to_str(ctx, id) + .cast() + .toDartString(); + logInline(str); + ttlString += str; + if (ttlString.length >= stopTokenLength && + ttlString.length > prompt.length && + stopTokenLength > 0) { + var lastPartTtlString = ttlString + .trim() + .substring(ttlString.trim().length - stopTokenLength - 1) + .toLowerCase() + .replaceAll(' ', '') + .trim(); + log('lastPartTtlString = $lastPartTtlString , stopTokenTrimed = ${stopTokenTrimed.toLowerCase()}, equal = ${lastPartTtlString == stopTokenTrimed.toLowerCase()}'); + if (lastPartTtlString == stopTokenTrimed.toLowerCase()) { + log('is_interacting = true'); + interaction = Completer(); + break; + } } + } catch (e) { + interaction = Completer(); } } } @@ -684,6 +700,10 @@ class MessageCanPrompt { MessageCanPrompt(); } +class MessageCanStop { + MessageCanStop(); +} + class MessageNewPrompt { final String prompt; diff --git a/lib/main.dart b/lib/main.dart index 37820ce1..320a4de7 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -131,12 +131,19 @@ class _MyHomePageState extends State { scrollDown(); } + bool canStop = false; void done() { setState(() { inProgress = false; }); } + void canUseStop() { + setState(() { + canStop = true; + }); + } + void _exec() { //close the keyboard if on mobile if (Platform.isAndroid || Platform.isIOS) { @@ -220,6 +227,7 @@ class _MyHomePageState extends State { promptController.text.trim() + (promptController.text.isEmpty ? "" : "\n"), done: done, + canStop: canUseStop, stopToken: reversePromptController.text, ); } else { @@ -1434,9 +1442,9 @@ class _MyHomePageState extends State { children: [ //top right button to copy the result ConstrainedBox( - constraints: const BoxConstraints( - maxHeight: 200, - ), + constraints: BoxConstraints( + maxHeight: + MediaQuery.of(context).size.height - 200), child: Padding( padding: const EdgeInsets.all(8.0), child: SingleChildScrollView( @@ -1553,18 +1561,18 @@ class _MyHomePageState extends State { ), ), const SizedBox(width: 5), - // if (inProgress) - // ElevatedButton( - // onPressed: _cancel, - // style: ElevatedButton.styleFrom( - // padding: const EdgeInsets.symmetric( - // horizontal: 5, vertical: 5), - // ), - // child: const Icon( - // Icons.stop, - // color: Colors.white, - // ), - // ), + if (canStop && inProgress) + ElevatedButton( + onPressed: _cancel, + style: ElevatedButton.styleFrom( + padding: const EdgeInsets.symmetric( + horizontal: 5, vertical: 5), + ), + child: const Icon( + Icons.stop, + color: Colors.white, + ), + ), ], ), ),