fix the library
cminus01 committed Oct 7, 2022
1 parent 3cd1114 commit 0d0400c
Showing 5 changed files with 76 additions and 124 deletions.
91 changes: 52 additions & 39 deletions graph4nlp/pytorch/data/dataset.py
@@ -36,7 +36,7 @@
from ..modules.utils.tree_utils import Tree
from ..modules.utils.tree_utils import Vocab as VocabForTree
from ..modules.utils.tree_utils import VocabForAll
from ..modules.utils.vocab_utils import VocabModel
from ..modules.utils.vocab_utils import Vocab, VocabModel


class DataItem(object):
@@ -146,6 +146,16 @@ def extract(self):
output_tokens = self.tokenizer(self.output_text)

return input_tokens, output_tokens

def extract_edge_tokens(self):
g: GraphData = self.graph
edge_tokens = []
for i in range(g.get_edge_num()):
if "token" in g.edge_attributes[i]:
edge_tokens.append(g.edge_attributes[i]["token"])
else:
edge_tokens.append("")
return edge_tokens


class Text2LabelDataItem(DataItem):
@@ -312,6 +322,7 @@ def __init__(
reused_vocab_model=None,
nlp_processor_args=None,
init_edge_vocab=False,
is_hetero=False,
**kwargs,
):
"""
@@ -358,6 +369,10 @@
vocabulary data is located.
nlp_processor_args: dict, default=None
It contains the parameters for the NLP processor, such as ``stanza``.
init_edge_vocab: bool, default=False
Whether to initialize the edge vocabulary.
is_hetero: bool, default=False
Whether the graph is heterogeneous.
kwargs
"""
super(Dataset, self).__init__()
@@ -387,6 +402,7 @@ def __init__(
self.topology_subdir = topology_subdir
self.use_val_for_vocab = use_val_for_vocab
self.init_edge_vocab = init_edge_vocab
self.is_hetero = is_hetero
for k, v in kwargs.items():
setattr(self, k, v)
self.__indices__ = None
@@ -428,8 +444,6 @@ def __init__(

vocab = torch.load(self.processed_file_paths["vocab"])
self.vocab_model = vocab
if init_edge_vocab:
self.edge_vocab = torch.load(self.processed_file_paths["edge_vocab"])

if hasattr(self, "reused_label_model"):
self.label_model = LabelModel.build(self.processed_file_paths["label"])
@@ -663,6 +677,7 @@ def build_vocab(self):
target_pretrained_word_emb_name=self.target_pretrained_word_emb_name,
target_pretrained_word_emb_url=self.target_pretrained_word_emb_url,
word_emb_size=self.word_emb_size,
init_edge_vocab=self.init_edge_vocab,
)
self.vocab_model = vocab_model

@@ -709,41 +724,6 @@ def _process(self):
self.test = self.build_topology(self.test)
if "val" in self.__dict__:
self.val = self.build_topology(self.val)
# build_edge_vocab and save
if self.init_edge_vocab:
self.edge_vocab = {}
s = set()
try:
for i in self.train:
graph = i.graph
for edge_idx in range(graph.get_edge_num()):
if "token" in graph.edge_attributes[edge_idx]:
edge_token = graph.edge_attributes[edge_idx]["token"]
s.add(edge_token)
except Exception as e:
pass
try:
for i in self.test:
graph = i.graph
for edge_idx in range(graph.get_edge_num()):
if "token" in graph.edge_attributes[edge_idx]:
edge_token = graph.edge_attributes[edge_idx]["token"]
s.add(edge_token)
except Exception as e:
pass
try:
for i in self.val:
graph = i.graph
for edge_idx in range(graph.get_edge_num()):
if "token" in graph.edge_attributes[edge_idx]:
edge_token = graph.edge_attributes[edge_idx]["token"]
s.add(edge_token)
except Exception as e:
pass
s.add("")
self.edge_vocab = {v: k for k, v in enumerate(s)}
print('edge vocab size:', len(self.edge_vocab))
torch.save(self.edge_vocab, self.processed_file_paths["edge_vocab"])

self.build_vocab()

@@ -1116,6 +1096,11 @@ def build_vocab(self):
pretrained_word_emb_cache_dir=self.pretrained_word_emb_cache_dir,
embedding_dims=self.dec_emb_size,
)
if self.init_edge_vocab:
all_edge_words = VocabModel.collect_edge_vocabs(data_for_vocab, self.tokenizer, lower_case=self.lower_case)
edge_vocab = Vocab(lower_case=self.lower_case, tokenizer=self.tokenizer)
edge_vocab.build_vocab(all_edge_words, max_vocab_size=None, min_vocab_freq=1)
edge_vocab.randomize_embeddings(self.word_emb_size)

if self.share_vocab:
all_words = Counter()
@@ -1158,6 +1143,7 @@ def build_vocab(self):
in_word_vocab=src_vocab_model,
out_word_vocab=tgt_vocab_model,
share_vocab=src_vocab_model if self.share_vocab else None,
edge_vocab=edge_vocab if self.init_edge_vocab else None,
)

return self.vocab_model
@@ -1175,6 +1161,18 @@ def vectorization(self, data_items):
token_matrix = torch.tensor(token_matrix, dtype=torch.long)
graph.node_features["token_id"] = token_matrix

if self.is_hetero:
for edge_idx in range(graph.get_edge_num()):
if "token" in graph.edge_attributes[edge_idx]:
edge_token = graph.edge_attributes[edge_idx]["token"]
else:
edge_token = ""
edge_token_id = self.edge_vocab[edge_token]
graph.edge_attributes[edge_idx]["token_id"] = edge_token_id
token_matrix.append([edge_token_id])
token_matrix = torch.tensor(token_matrix, dtype=torch.long)
graph.edge_features["token_id"] = token_matrix

tgt = item.output_text
tgt_list = self.vocab_model.out_word_vocab.get_symbol_idx_for_list(tgt.split())
output_tree = Tree.convert_to_tree(
@@ -1183,7 +1181,7 @@
item.output_tree = output_tree

@classmethod
def _vectorize_one_dataitem(cls, data_item, vocab_model, use_ie=False):
def _vectorize_one_dataitem(cls, data_item, vocab_model, use_ie=False, is_hetero=False):
item = deepcopy(data_item)
graph: GraphData = item.graph
token_matrix = []
@@ -1195,6 +1193,21 @@ def _vectorize_one_dataitem(cls, data_item, vocab_model, use_ie=False):
token_matrix = torch.tensor(token_matrix, dtype=torch.long)
graph.node_features["token_id"] = token_matrix

if is_hetero:
if not hasattr(vocab_model, "edge_vocab"):
raise ValueError("Vocab model must have edge vocab attribute")
token_matrix = []
for edge_idx in range(graph.get_edge_num()):
if "token" in graph.edge_attributes[edge_idx]:
edge_token = graph.edge_attributes[edge_idx]["token"]
else:
edge_token = ""
edge_token_id = vocab_model.edge_vocab[edge_token]
graph.edge_attributes[edge_idx]["token_id"] = edge_token_id
token_matrix.append([edge_token_id])
token_matrix = torch.tensor(token_matrix, dtype=torch.long)
graph.edge_features["token_id"] = token_matrix

if isinstance(item.output_text, str):
tgt = item.output_text
tgt_list = vocab_model.out_word_vocab.get_symbol_idx_for_list(tgt.split())
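The net effect of the dataset.py changes: when is_hetero is enabled, each edge's "token" attribute (or "" when absent) is looked up in the edge vocabulary and stored both in edge_attributes[edge_idx]["token_id"] and, stacked into a [num_edges, 1] tensor, in graph.edge_features["token_id"], mirroring the node-side token_id feature. A minimal standalone sketch of that lookup step; the toy edge_attributes list and dict-style edge_vocab below are illustrative stand-ins, not objects from this commit:

import torch

# Stand-ins for graph.edge_attributes and the edge vocabulary built during preprocessing.
edge_attributes = [{"token": "nsubj"}, {"token": "obj"}, {}]  # third edge carries no token
edge_vocab = {"": 0, "nsubj": 1, "obj": 2}                    # "" is the fallback entry

token_matrix = []
for edge_idx in range(len(edge_attributes)):
    edge_token = edge_attributes[edge_idx].get("token", "")   # fall back to "" like the code above
    edge_token_id = edge_vocab[edge_token]
    edge_attributes[edge_idx]["token_id"] = edge_token_id
    token_matrix.append([edge_token_id])

edge_token_ids = torch.tensor(token_matrix, dtype=torch.long)
print(edge_token_ids.shape)  # torch.Size([3, 1]), stored as graph.edge_features["token_id"]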
2 changes: 2 additions & 0 deletions graph4nlp/pytorch/datasets/mawps.py
@@ -51,6 +51,7 @@ def __init__(
for_inference=False,
reused_vocab_model=None,
init_edge_vocab=False,
is_hetero=False,
):
"""
Parameters
@@ -120,4 +121,5 @@ def __init__(
for_inference=for_inference,
reused_vocab_model=reused_vocab_model,
init_edge_vocab=init_edge_vocab,
is_hetero=is_hetero,
)
(third changed file; file name missing from the captured page)
@@ -140,9 +140,6 @@ def __init__(
"w2v_bert",
"w2v_bert_bilstm",
"w2v_bert_bigru",
"w2v_amr",
"w2v_bilstm_amr",
"w2v_bilstm_amr_pos",
), "emb_strategy must be one of ('w2v', 'w2v_bilstm', 'w2v_bigru', 'bert', 'bert_bilstm', "
"'bert_bigru', 'w2v_bert', 'w2v_bert_bilstm', 'w2v_bert_bigru')"

@@ -163,14 +160,6 @@
seq_info_encode_strategy = "none"
else:
seq_info_encode_strategy = "none"
if "amr" in emb_strategy:
seq_info_encode_strategy = "bilstm"

if "pos" in emb_strategy:
word_emb_type.add("pos")
#word_emb_type.add("entity_label")
word_emb_type.add("position")

if "w2v" in emb_strategy:
word_emb_type.add("w2v")

@@ -222,17 +211,6 @@
else:
rnn_input_size = word_emb_size

if "pos" in word_emb_type:
self.word_emb_layers["pos"] = WordEmbedding(37, 50)
rnn_input_size += 50

if "entity_label" in word_emb_type:
self.word_emb_layers["entity_label"] = WordEmbedding(26, 50)
rnn_input_size += 50

if "position" in word_emb_type:
pass

if "seq_bert" in word_emb_type:
rnn_input_size += self.word_emb_layers["seq_bert"].bert_model.config.hidden_size

@@ -250,8 +228,6 @@
else:
self.output_size = rnn_input_size
self.seq_info_encode_layer = None

#self.fc = nn.Linear(376, 300)

def forward(self, batch_gd):
"""Compute initial node/edge embeddings.
@@ -284,60 +260,6 @@ def forward(self, batch_gd):
word_feat, self.word_dropout, shared_axes=[-2], training=self.training
)
feat.append(word_feat)
if any(batch_gd.batch_graph_attributes):
tot = 0
gd_list = from_batch(batch_gd)
for i, g in enumerate(gd_list):
sentence_id = g.graph_attributes["sentence_id"].to(batch_gd.device)
seq_feat = []
if "w2v" in self.word_emb_layers:
word_feat = self.word_emb_layers["w2v"](sentence_id)
word_feat = dropout_fn(
word_feat, self.word_dropout, shared_axes=[-2], training=self.training
)
seq_feat.append(word_feat)
else:
RuntimeError("No word embedding layer")
if "pos" in self.word_emb_layers:
sentence_pos = g.graph_attributes["pos_tag_id"].to(batch_gd.device)
pos_feat = self.word_emb_layers["pos"](sentence_pos)
pos_feat = dropout_fn(
pos_feat, self.word_dropout, shared_axes=[-2], training=self.training
)
seq_feat.append(pos_feat)

if "entity_label" in self.word_emb_layers:
sentence_entity_label = g.graph_attributes["entity_label_id"].to(batch_gd.device)
entity_label_feat = self.word_emb_layers["entity_label"](sentence_entity_label)
entity_label_feat = dropout_fn(
entity_label_feat, self.word_dropout, shared_axes=[-2], training=self.training
)
seq_feat.append(entity_label_feat)

seq_feat = torch.cat(seq_feat, dim=-1)

raw_tokens = [dd.strip().split() for dd in g.graph_attributes["sentence"]]
l = [len(s) for s in raw_tokens]
rnn_state = self.seq_info_encode_layer(
seq_feat, torch.LongTensor(l).to(batch_gd.device)
)
if isinstance(rnn_state, (tuple, list)):
rnn_state = rnn_state[0]

# update node features
for j in range(g.get_node_num()):
id = g.node_attributes[j]["sentence_id"]
if g.node_attributes[j]["id"] in batch_gd.batch_graph_attributes[i]["mapping"][id]:
rel_list = batch_gd.batch_graph_attributes[i]["mapping"][id][g.node_attributes[j]["id"]]
state = []
for rel in rel_list:
if rel[1] == "node":
state.append(rnn_state[id][rel[0]])
# replace embedding of the node
if len(state) > 0:
feat[0][tot + j][0] = torch.stack(state, 0).mean(0)

tot += g.get_node_num()

if "node_edge_bert" in self.word_emb_layers:
input_data = [
@@ -352,17 +274,14 @@

if len(feat) > 0:
feat = torch.cat(feat, dim=-1)
if not any(batch_gd.batch_graph_attributes):
node_token_lens = torch.clamp((token_ids != Vocab.PAD).sum(-1), min=1)
feat = self.node_edge_emb_layer(feat, node_token_lens)
else:
feat = feat.squeeze(dim=1)
node_token_lens = torch.clamp((token_ids != Vocab.PAD).sum(-1), min=1)
feat = self.node_edge_emb_layer(feat, node_token_lens)
if isinstance(feat, (tuple, list)):
feat = feat[-1]

feat = batch_gd.split_features(feat)

if (self.seq_info_encode_layer is None and "seq_bert" not in self.word_emb_layers) or any(batch_gd.batch_graph_attributes):
if self.seq_info_encode_layer is None and "seq_bert" not in self.word_emb_layers:
if isinstance(feat, list):
feat = torch.cat(feat, -1)

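The surviving branch above now always computes per-node token lengths by counting non-PAD ids and clamping at 1, so nodes whose token row is all padding still yield a length of 1 before the node/edge embedding layer runs. A standalone illustration; the PAD id of 0 is an assumption for this sketch (graph4nlp's Vocab defines its own PAD index):

import torch

PAD = 0  # assumed padding id for illustration only
token_ids = torch.tensor([
    [5, 7, 0, 0],   # node with two real tokens
    [3, 0, 0, 0],   # node with one real token
    [0, 0, 0, 0],   # all-padding node: clamp keeps the length at 1
])

node_token_lens = torch.clamp((token_ids != PAD).sum(-1), min=1)
print(node_token_lens)  # tensor([2, 1, 1]), then fed to the node/edge embedding layer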
3 changes: 2 additions & 1 deletion graph4nlp/pytorch/modules/utils/tree_utils.py
@@ -132,10 +132,11 @@ def convert_to_tree(r_list, i_left, i_right, tgt_vocab):


class VocabForAll:
def __init__(self, in_word_vocab, out_word_vocab, share_vocab):
def __init__(self, in_word_vocab, out_word_vocab, share_vocab, edge_vocab=None):
self.in_word_vocab = in_word_vocab
self.out_word_vocab = out_word_vocab
self.share_vocab = share_vocab
self.edge_vocab = edge_vocab

def get_vocab_size(self):
if hasattr(self, "share_vocab"):
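Because edge_vocab defaults to None, existing VocabForAll call sites keep working; only the heterogeneous path passes the extra vocabulary through. A tiny sketch using None placeholders for the real Vocab objects, purely to show the stored attribute and its default:

from graph4nlp.pytorch.modules.utils.tree_utils import VocabForAll

# None placeholders stand in for real Vocab instances built in build_vocab().
vocab_model = VocabForAll(in_word_vocab=None, out_word_vocab=None, share_vocab=None)
print(vocab_model.edge_vocab)  # None, so edge-aware code paths must check it before use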
17 changes: 17 additions & 0 deletions graph4nlp/pytorch/modules/utils/vocab_utils.py
@@ -47,6 +47,8 @@ class VocabModel(object):
Word embedding size, default: ``None``.
share_vocab : boolean
Specify whether to share vocab between input and output text, default: ``True``.
init_edge_vocab : boolean
Specify whether to initialize the edge vocab, default: ``False``.
Examples
-------
@@ -82,6 +84,7 @@ def __init__(
# pretrained_word_emb_file=None,
word_emb_size=None,
share_vocab=True,
init_edge_vocab=False,
):
super(VocabModel, self).__init__()
self.tokenizer = tokenizer
@@ -150,6 +153,12 @@ def __init__(
self.out_word_vocab.randomize_embeddings(word_emb_size)
else:
self.out_word_vocab = self.in_word_vocab

if init_edge_vocab:
all_edge_words = VocabModel.collect_edge_vocabs(data_set, self.tokenizer, lower_case=lower_case)
self.edge_vocab = Vocab(lower_case=lower_case, tokenizer=self.tokenizer)
self.edge_vocab.build_vocab(all_edge_words, max_vocab_size=None, min_vocab_freq=1)
self.edge_vocab.randomize_embeddings(word_emb_size)

if share_vocab:
print("[ Initialized word embeddings: {} ]".format(self.in_word_vocab.embeddings.shape))
@@ -265,6 +274,14 @@ def collect_vocabs(all_instances, tokenizer, lower_case=True, share_vocab=True):
all_words[1].update(extracted_tokens[1])

return all_words
@staticmethod
def collect_edge_vocabs(all_instances, tokenizer, lower_case=True):
"""Count vocabulary tokens for edge."""
all_edges = Counter()
for instance in all_instances:
extracted_edge_tokens = instance.extract_edge_tokens()
all_edges.update(extracted_edge_tokens)
return all_edges


class WordEmbModel(Vectors):
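collect_edge_vocabs mirrors collect_vocabs but counts edge tokens through the new DataItem.extract_edge_tokens hook. A minimal sketch of that counting step; the ToyItem class is a hypothetical stand-in for a DataItem, not part of the library:

from collections import Counter


class ToyItem:
    """Hypothetical stand-in for a DataItem whose graph has three edges, one untagged."""

    def extract_edge_tokens(self):
        # Matches DataItem.extract_edge_tokens: "" is returned for edges without a "token" attribute.
        return ["nsubj", "obj", ""]


# Same counting logic as VocabModel.collect_edge_vocabs.
all_edges = Counter()
for instance in [ToyItem(), ToyItem()]:
    all_edges.update(instance.extract_edge_tokens())

print(all_edges)  # Counter({'nsubj': 2, 'obj': 2, '': 2})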
