-
Notifications
You must be signed in to change notification settings - Fork 151
/
Copy pathgenerate.py
67 lines (53 loc) · 2.36 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
Usage: python -m scripts.generate \
--model-name evo-1-131k-base \
--prompt ACGT \
--n-samples 10 \
--n-tokens 100 \
--temperature 1. \
--top-k 4 \
--device cuda:0
Generates a sequence given a prompt. Also enables the user to specify various basic
sampling parameters.
"""
import argparse
from evo import Evo, generate
def main():
# Parse command-line arguments.
parser = argparse.ArgumentParser(description='Generate sequences using the Evo model.')
parser.add_argument('--model-name', type=str, default='evo-1-131k-base', help='Evo model name')
parser.add_argument('--prompt', type=str, default='ACGT', help='Prompt for generation')
parser.add_argument('--n-samples', type=int, default=3, help='Number of sequences to sample at once')
parser.add_argument('--n-tokens', type=int, default=100, help='Number of tokens to generate')
parser.add_argument('--temperature', type=float, default=1.0, help='Temperature during sampling')
parser.add_argument('--top-k', type=int, default=4, help='Top K during sampling')
parser.add_argument('--top-p', type=float, default=1., help='Top P during sampling')
parser.add_argument('--cached-generation', type=bool, default=True, help='Use KV caching during generation')
parser.add_argument('--batched', type=bool, default=True, help='Use batched generation')
parser.add_argument('--prepend-bos', type=bool, default=False, help='Prepend BOS token')
parser.add_argument('--device', type=str, default='cuda:0', help='Device for generation')
parser.add_argument('--verbose', type=int, default=1, help='Verbosity level')
args = parser.parse_args()
# Load model.
evo_model = Evo(args.model_name)
model, tokenizer = evo_model.model, evo_model.tokenizer
model.to(args.device)
model.eval()
# Sample sequences.
print('Generated sequences:')
output_seqs, output_scores = generate(
[ args.prompt ] * args.n_samples,
model,
tokenizer,
n_tokens=args.n_tokens,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
cached_generation=args.cached_generation,
batched=args.batched,
prepend_bos=args.prepend_bos,
device=args.device,
verbose=args.verbose,
)
# Run the CLI entry point only when executed as a script
# (e.g. `python -m scripts.generate`), not on import.
if __name__ == '__main__':
    main()