Elmo_embedder.py
import multiprocessing
from pathlib import Path

import numpy as np
import torch
from allennlp.commands.elmo import ElmoEmbedder


class Elmo_embedder:
    def __init__(self, model_dir="/public/home/hxu6/projects/degron/DegronsDB/model/uniref50_v2",
                 weights="weights.hdf5", options="options.json", threads=1000):
        # threads=1000 acts as the "auto" sentinel: use half of the available CPU cores.
        if threads == 1000:
            torch.set_num_threads(multiprocessing.cpu_count() // 2)
        else:
            torch.set_num_threads(threads)
        self.model_dir = Path(model_dir)
        self.weights = self.model_dir / weights
        self.options = self.model_dir / options
        # cuda_device=-1 runs the SeqVec/ELMo model on CPU.
        self.seqvec = ElmoEmbedder(self.options, self.weights, cuda_device=-1)

    def elmo_embedding(self, x, start=None, stop=None):
        assert start is None and stop is None, "start/stop slicing is deprecated; trim sequences beforehand"
        # Accept raw sequence strings as well as pre-tokenized sequences:
        # split each string into a list of upper-case residue characters.
        if isinstance(x[0], str):
            x = np.array([list(i.upper()) for i in x])
        embedding = self.seqvec.embed_sentences(x)
        # Each embedding has shape (3 layers, sequence length, 1024);
        # average over the three ELMo layers to get per-residue vectors.
        X_parsed = []
        for i in embedding:
            X_parsed.append(i.mean(axis=0))
        return X_parsed
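
A minimal usage sketch, assuming the uniref50_v2 weights and options files exist at the default model_dir; the protein sequences and the threads value below are made-up illustrations, not part of the original module.

# Usage sketch (hypothetical inputs)
if __name__ == "__main__":
    embedder = Elmo_embedder(threads=4)
    seqs = ["MKTAYIAKQR", "GSSGSSG"]  # hypothetical protein sequences
    per_residue = embedder.elmo_embedding(seqs)
    # Each element is a (sequence length, 1024) array of per-residue embeddings.
    for seq, emb in zip(seqs, per_residue):
        print(seq, emb.shape)
    # Mean-pool over residues for a fixed-length per-sequence representation.
    per_sequence = [e.mean(axis=0) for e in per_residue]
    print(per_sequence[0].shape)  # (1024,)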