-
Notifications
You must be signed in to change notification settings - Fork 6
/
test.py
33 lines (28 loc) · 1.15 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os

from shimeji import ChatBot
from shimeji.model_provider import ModelProvider, Sukima_ModelProvider, ModelGenRequest, ModelGenArgs, ModelSampleArgs
from shimeji.preprocessor import ContextPreprocessor
from shimeji.postprocessor import NewlinePrunerPostprocessor
# Generation/sampling settings sent with every request to the model server.
# NOTE(review): eos_token_id=198 is presumably the newline token of a
# GPT-2-style tokenizer, so each reply stops at end of line -- TODO confirm
# against the model's vocabulary.
gen_args = ModelGenArgs(max_length=100, min_length=1, eos_token_id=198)
sample_args = ModelSampleArgs(temp=0.75, top_p=0.725, typical_p=0.95, rep_p=1.125)
model_args = ModelGenRequest(model='c1-6B-8bit', prompt='', sample_args=sample_args, gen_args=gen_args)

# Connection details for the Sukima server. They can be supplied through the
# environment so real credentials never have to be committed to source; the
# fallbacks are the original hard-coded values, so behavior is unchanged when
# the variables are unset.
model_provider = Sukima_ModelProvider(
    os.environ.get('SUKIMA_URL', 'http://192.168.0.147:8000'),
    username=os.environ.get('SUKIMA_USERNAME', 'username'),
    password=os.environ.get('SUKIMA_PASSWORD', 'password'),
    args=model_args,
)
# Ask for the bot's display name; it is reused below as the speaker label
# when replies are printed.
bot_name = input('Enter a bot name:')

# Build the chatbot on top of the model provider configured above. The
# pre/post-processor instances are passed straight to shimeji; what each one
# does is defined inside the shimeji package, not in this script.
chatbot = ChatBot(name=bot_name, model_provider=model_provider,
                  preprocessors=[ContextPreprocessor()],
                  postprocessors=[NewlinePrunerPostprocessor()])
# Interactive chat loop. Each user line is sent to the bot prefixed with
# 'User: '; with push_chain=True the exchange is presumably appended to
# chatbot.conversation_chain (which is dumped on exit).
# The session ends on Ctrl-C (KeyboardInterrupt) or on end-of-input
# (Ctrl-D / closed stdin), where input() raises EOFError. Previously EOFError
# escaped as a traceback and the chain was never shown.
while True:
    try:
        user_input = input('User: ')
        response = chatbot.respond('User: ' + user_input, push_chain=True)
        print(f'{bot_name}:{response}')
    except (KeyboardInterrupt, EOFError):
        # The leading newline moves past the dangling 'User: ' prompt.
        # Concatenate instead of passing two print() arguments: print's
        # default sep=' ' would insert a stray space before the first entry.
        # Assumes conversation_chain holds strings (str.join requires it).
        print('\n==Conversation Chain==\n' + '\n'.join(chatbot.conversation_chain))
        break