Skip to content

Commit

Permalink
add stop button
Browse files Browse the repository at this point in the history
  • Loading branch information
Natakout committed Apr 14, 2023
1 parent a8921d7 commit b6bbb71
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 44 deletions.
78 changes: 49 additions & 29 deletions lib/lib.dart
Original file line number Diff line number Diff line change
Expand Up @@ -214,25 +214,28 @@ class Lib {
ReceivePort isolateReceivePort = ReceivePort();
SendPort isolateSendPort = isolateReceivePort.sendPort;
mainSendPort.send(isolateSendPort);

var completer = Completer<ParsingDemand>();
try {
isolateReceivePort.listen((message) async {
if (message is MessageNewPrompt) {
interaction.complete(message.prompt);
}
if (message is ParsingDemand) {
// mainSendPort.send(ParsingResult(fileSaved.path));
binaryIsolate(
parsingDemand: message,
stopToken: message.stopToken,
mainSendPort: mainSendPort,
);
completer.complete(message);
}
if (message is MessageStopGeneration) {
log("[isolate] Stopping generation");
stopGeneration = true;
}
});

var parsingDemand = await completer.future;
Future.sync(() => Lib().binaryIsolate(
parsingDemand: parsingDemand,
stopToken: parsingDemand.stopToken,
mainSendPort: mainSendPort,
) as FutureOr<void>);
} catch (e) {
mainSendPort.send("[isolate] ERROR : $e");
}
Expand All @@ -252,6 +255,7 @@ class Lib {
required void Function(String log) printLog,
required String promptPassed,
required void Function() done,
required void Function() canStop,
required String stopToken,
required ParamsLlamaValuesOnly paramsLlamaValuesOnly,
}) async {
Expand Down Expand Up @@ -296,6 +300,8 @@ class Lib {
printLnLog(message.message);
} else if (message is MessageCanPrompt) {
done();
} else if (message is MessageCanStop) {
canStop();
} else {
print(message);
}
Expand Down Expand Up @@ -328,7 +334,7 @@ class Lib {

static Completer interaction = Completer();

static binaryIsolate({
binaryIsolate({
required ParsingDemand parsingDemand,
required SendPort mainSendPort,
required String stopToken,
Expand Down Expand Up @@ -480,6 +486,7 @@ class Lib {

var inp_pfx = tokenize(llamaBinded, ctx, "\n\n### Instruction:\n\n", true);
var inp_sfx = tokenize(llamaBinded, ctx, "\n\n### Response:\n\n", false);
var user_token = tokenize(llamaBinded, ctx, "\n$stopTokenTrimed", true);
var llama_token_newline = tokenize(llamaBinded, ctx, "\n", false);

var embd = Vector<Int>(nullptr, 0);
Expand All @@ -496,6 +503,8 @@ class Lib {
int remaining_tokens = gptParams.ref.n_predict;
int n_past = 0;
log('before while loop');
mainSendPort.send(MessageCanStop());

while ((remaining_tokens > 0 || gptParams.ref.interactive)) {
log('remaining tokens : $remaining_tokens');
log('stopGeneration : $stopGeneration');
Expand All @@ -506,15 +515,17 @@ class Lib {
log("error llama_eval");
return;
}
await Future.delayed(Duration(milliseconds: 1));
}
n_past += embd.length;
embd.clear();
if (stopGeneration) {
log('stop Generation initiated by user');
interaction = Completer();
stopGeneration = false;
embd.insertVectorAtEnd(user_token);
}
if (embd_inp.length <= input_consumed) {
if (embd_inp.length <= input_consumed &&
interaction.isCompleted == true) {
// out of user input, sample next token
var top_k = gptParams.ref.top_k;
var top_p = gptParams.ref.top_p;
Expand Down Expand Up @@ -571,27 +582,32 @@ class Lib {
log('input_noecho = $input_noecho embd.length = ${embd.length}');
if (!input_noecho) {
for (int i = 0; i < embd.length; ++i) {
int id = embd.pointer[i];
var str = llamaBinded
.llama_token_to_str(ctx, id)
.cast<Utf8>()
.toDartString();
logInline(str);
ttlString += str;
if (ttlString.length >= stopTokenLength &&
ttlString.length > prompt.length &&
stopTokenLength > 0) {
var lastPartTtlString = ttlString
.trim()
.substring(ttlString.trim().length - stopTokenLength - 1)
.toLowerCase()
.replaceAll(' ', '');
log('lastPartTtlString = $lastPartTtlString , stopTokenTrimed = $stopTokenTrimed');
if (lastPartTtlString == stopTokenTrimed.toLowerCase()) {
log('is_interacting = true');
interaction = Completer();
break;
try {
int id = embd.pointer[i];
var str = llamaBinded
.llama_token_to_str(ctx, id)
.cast<Utf8>()
.toDartString();
logInline(str);
ttlString += str;
if (ttlString.length >= stopTokenLength &&
ttlString.length > prompt.length &&
stopTokenLength > 0) {
var lastPartTtlString = ttlString
.trim()
.substring(ttlString.trim().length - stopTokenLength - 1)
.toLowerCase()
.replaceAll(' ', '')
.trim();
log('lastPartTtlString = $lastPartTtlString , stopTokenTrimed = ${stopTokenTrimed.toLowerCase()}, equal = ${lastPartTtlString == stopTokenTrimed.toLowerCase()}');
if (lastPartTtlString == stopTokenTrimed.toLowerCase()) {
log('is_interacting = true');
interaction = Completer();
break;
}
}
} catch (e) {
interaction = Completer();
}
}
}
Expand Down Expand Up @@ -684,6 +700,10 @@ class MessageCanPrompt {
MessageCanPrompt();
}

class MessageCanStop {
MessageCanStop();
}

class MessageNewPrompt {
final String prompt;

Expand Down
38 changes: 23 additions & 15 deletions lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,19 @@ class _MyHomePageState extends State<MyHomePage> {
scrollDown();
}

bool canStop = false;
void done() {
setState(() {
inProgress = false;
});
}

void canUseStop() {
setState(() {
canStop = true;
});
}

void _exec() {
//close the keyboard if on mobile
if (Platform.isAndroid || Platform.isIOS) {
Expand Down Expand Up @@ -220,6 +227,7 @@ class _MyHomePageState extends State<MyHomePage> {
promptController.text.trim() +
(promptController.text.isEmpty ? "" : "\n"),
done: done,
canStop: canUseStop,
stopToken: reversePromptController.text,
);
} else {
Expand Down Expand Up @@ -1434,9 +1442,9 @@ class _MyHomePageState extends State<MyHomePage> {
children: [
//top right button to copy the result
ConstrainedBox(
constraints: const BoxConstraints(
maxHeight: 200,
),
constraints: BoxConstraints(
maxHeight:
MediaQuery.of(context).size.height - 200),
child: Padding(
padding: const EdgeInsets.all(8.0),
child: SingleChildScrollView(
Expand Down Expand Up @@ -1553,18 +1561,18 @@ class _MyHomePageState extends State<MyHomePage> {
),
),
const SizedBox(width: 5),
// if (inProgress)
// ElevatedButton(
// onPressed: _cancel,
// style: ElevatedButton.styleFrom(
// padding: const EdgeInsets.symmetric(
// horizontal: 5, vertical: 5),
// ),
// child: const Icon(
// Icons.stop,
// color: Colors.white,
// ),
// ),
if (canStop && inProgress)
ElevatedButton(
onPressed: _cancel,
style: ElevatedButton.styleFrom(
padding: const EdgeInsets.symmetric(
horizontal: 5, vertical: 5),
),
child: const Icon(
Icons.stop,
color: Colors.white,
),
),
],
),
),
Expand Down

0 comments on commit b6bbb71

Please sign in to comment.