diff --git a/cement/cement_document.py b/cement/cement_document.py
index 4817db3..6566746 100644
--- a/cement/cement_document.py
+++ b/cement/cement_document.py
@@ -3,7 +3,8 @@ import logging
 
 
 from concrete import Communication, AnnotationMetadata, EntityMention, UUID, EntityMentionSet, EntitySet, \
-    SituationMentionSet, SituationSet, SituationMention, MentionArgument, TokenizationKind, Entity
+    SituationMentionSet, SituationSet, SituationMention, MentionArgument, TokenizationKind, Entity, Argument, \
+    Justification, TimeML, Situation
 from concrete.util import read_communication_from_file, \
     add_references_to_communication, write_communication_to_file, read_thrift_from_file
 from concrete.validate import validate_communication
@@ -258,12 +259,13 @@ def get_entity_mention_by_indices(self, start: int, end: int) -> Optional[Entity
     def add_entity_singleton(self,
                              mention: Union[EntityMention, CementEntityMention],
                              entity_type: str,
+                             entity_id: Optional[str] = None,
                              confidence: float = 1.,
                              update: bool = True) -> UUID:
         # TODO(@Yunmo): this assumption might not always hold, please visit later
         entity_mention_uuid = self.add_entity_mention(mention=mention, update=update)
         entity = Entity(uuid=augf.next(),
-                        id=None,
+                        id=entity_id,
                         mentionIdList=[entity_mention_uuid],
                         rawMentionList=None,
                         type=entity_type,
@@ -272,6 +274,37 @@ def add_entity_singleton(self,
         self._entity_set.entityList.append(entity)
         return entity.uuid
 
+    def add_entity(self,
+                   mentions: List[Union[EntityMention, CementEntityMention]],
+                   entity_type: str,
+                   entity_id: Optional[str] = None,
+                   confidence: float = 1.,
+                   update: bool = True) -> UUID:
+        """Add an entity that groups several mentions.
+
+        Every mention is first passed to `add_entity_mention`; the UUIDs it returns
+        become the `mentionIdList` of the new `Entity`.
+
+        :param mentions: mentions that belong to the entity
+        :param entity_type: value stored in the entity's `type` field
+        :param entity_id: optional value stored in the entity's `id` field
+        :param confidence: value stored in the entity's `confidence` field
+        :param update: forwarded to `add_entity_mention` for every mention
+        :return: UUID of the newly created entity
+        """
+        added_mention_uuids = [self.add_entity_mention(mention=mention, update=update)
+                               for mention in mentions]
+
+        entity = Entity(uuid=augf.next(),
+                        id=entity_id,
+                        mentionIdList=added_mention_uuids,
+                        rawMentionList=None,
+                        type=entity_type,
+                        confidence=confidence,
+                        canonicalName=None)
+        self._entity_set.entityList.append(entity)
+        return entity.uuid
+
     def add_entity_mention(self,
                            mention: Union[EntityMention, CementEntityMention],
                            update: bool = True) -> UUID:
@@ -313,6 +346,44 @@ def add_entity_mention(self,
 
         return entity_mention.uuid
 
+    def add_raw_situation(self,
+                          situation_type: str,
+                          situation_kind: Optional[str] = None,
+                          arguments: Optional[List[Argument]] = None,
+                          mention_ids: Optional[List[UUID]] = None,
+                          justifications: Optional[List[Justification]] = None,
+                          time_ml: Optional[TimeML] = None,
+                          intensity: Optional[float] = None,
+                          polarity: Optional[str] = None,
+                          confidence: float = 1.) -> UUID:
+        """Create a `Situation` from already-built parts and add it to the situation set.
+
+        The arguments are stored on the new `Situation` as given.
+
+        :param situation_type: stored as `situationType`
+        :param situation_kind: stored as `situationKind`
+        :param arguments: stored as `argumentList`
+        :param mention_ids: stored as `mentionIdList`
+        :param justifications: stored as `justificationList`
+        :param time_ml: stored as `timeML`
+        :param intensity: stored as `intensity`
+        :param polarity: stored as `polarity`
+        :param confidence: stored as `confidence`
+        :return: UUID of the newly created situation
+        """
+        situation = Situation(uuid=augf.next(),
+                              situationType=situation_type,
+                              situationKind=situation_kind,
+                              argumentList=arguments,
+                              mentionIdList=mention_ids,
+                              justificationList=justifications,
+                              timeML=time_ml,
+                              intensity=intensity,
+                              polarity=polarity,
+                              confidence=confidence)
+        self._situation_set.situationList.append(situation)
+        return situation.uuid
+
     def add_situation_mention(self, mention: SituationMention, trigger: Optional[CementSpan] = None) -> UUID:
         # TODO(@Yunmo): verify this assumption?
         if trigger:
@@ -521,17 +592,3 @@ def from_communication_file(cls, file_path: str, annotation_set: str = TOOL_NAME
     @classmethod
     def from_communication(cls, comm: Communication, annotation_set: str = TOOL_NAME) -> 'CementDocument':
         return cls(comm=comm, annotation_set=annotation_set)
-
-
-if __name__ == '__main__':
-    # from transformers import BasicTokenizer
-    # tokenizer = BasicTokenizer()
-    import json
-    import numpy as np
-
-    with open('out/downloadRAMS/Baseline.baseline/out/RAMS_1.0/data/train.jsonlines') as f:
-        json_doc = json.loads(next(f))
-        doc = CementDocument.from_tokens(tokens={'paragraph': json_doc['sentences']})
-        indices = global_to_local_indices(np.array([19, 20, 23, 24]), doc._tokenization_offsets)
-        doc[:]
-        pass
diff --git a/scripts/oneie-to-concrete.py b/scripts/oneie-to-concrete.py
new file mode 100644
index 0000000..fd6da81
--- /dev/null
+++ b/scripts/oneie-to-concrete.py
@@ -0,0 +1,218 @@
+"""Convert OneIE-style JSON annotations into Concrete communication files.
+
+Each input JSON object is expected to provide a ``doc_id`` and a list of
+``sentences``; every sentence provides ``tokens``, ``entities``, ``events`` and
+``relations``. One ``.concrete`` file is written per document.
+"""
+import argparse
+import copy
+import json
+import logging
+import os
+from collections import Counter, defaultdict
+from typing import Dict, Generator, Iterable, List
+
+from concrete import SituationMention
+from tqdm import tqdm
+
+from cement.cement_document import CementDocument
+from cement.cement_entity_mention import CementEntityMention
+from cement.cement_span import CementSpan
+
+logger = logging.getLogger(__name__)
+
+
+def read_json(input_path: str, use_dir: bool = False) -> Generator[Dict, None, None]:
+    """Yield parsed JSON objects from a JSON-lines file or from a directory of JSON files.
+
+    :param input_path: path of a JSON-lines file, or of a directory when `use_dir` is set
+    :param use_dir: if True, load every file in `input_path` whose name contains '.json'
+        as a single JSON document; otherwise parse one JSON object per line
+    :return: generator over the parsed objects
+    """
+    if use_dir:
+        # `os.listdir` returns names in arbitrary order; sort them for reproducible runs.
+        for fn in sorted(os.listdir(input_path)):
+            if '.json' not in fn:
+                continue
+            with open(os.path.join(input_path, fn), encoding='utf-8') as f:
+                yield json.load(f)
+    else:
+        with open(input_path, encoding='utf-8') as f:
+            for line in f:
+                # Skip blank lines (e.g. a trailing newline) instead of failing in `json.loads`.
+                if not line.strip():
+                    continue
+                yield json.loads(line)
+
+
+def to_cement_doc_stream(json_stream: Iterable[Dict]) -> Iterable[CementDocument]:
+    """Lazily convert JSON documents into `CementDocument` objects.
+
+    Per-document counts are logged as each document is produced; the overall totals are
+    logged only after the input stream has been exhausted.
+
+    :param json_stream: JSON objects such as those yielded by `read_json`
+    :return: generator yielding one `CementDocument` per input object
+    """
+    entity_counter = Counter()
+    entity_mention_counter = Counter()
+    event_counter = Counter()
+    event_mention_counter = Counter()
+    relation_mention_counter = Counter()
+    for json_obj in json_stream:
+        doc_id = json_obj['doc_id']
+        sentences = json_obj['sentences']
+
+        # create a `CementDocument`
+        doc = CementDocument.from_tokens(tokens={'passage': [sent_obj['tokens'] for sent_obj in sentences]},
+                                         doc_id=doc_id)
+
+        doc.write_kv_map(prefix='meta', key='ner-iterator', suffix='sentence', value='True')
+        doc.write_kv_map(prefix='meta', key='events-iterator', suffix='sentence', value='True')
+        doc.write_kv_map(prefix='meta', key='relations-iterator', suffix='sentence', value='True')
+
+        entity_id_to_mentions: Dict[str, List[CementEntityMention]] = defaultdict(list)
+        event_id_to_mentions: Dict[str, List[SituationMention]] = defaultdict(list)
+        em_id_to_cem: Dict[str, CementEntityMention] = {}
+        sm_id_to_sm: Dict[str, SituationMention] = {}
+        for line_id, sent_obj in enumerate(sentences):
+            # extract entity mentions (EMD or NER)
+            if sent_obj['entities']:
+                uuids = []
+                for em_obj in sent_obj['entities']:
+                    # NOTE(review): `end - 1` suggests the source `end` offset is exclusive - confirm
+                    start, end = doc.to_global_indices(sent_ids=[line_id],
+                                                       indices=[em_obj['start'], em_obj['end'] - 1])
+                    cem = CementEntityMention(start=start,
+                                              end=end,
+                                              entity_type=f'{em_obj["entity_type"]}:{em_obj["entity_subtype"]}',
+                                              phrase_type=em_obj['mention_type'],
+                                              text=em_obj['text'],
+                                              document=doc)
+                    em_id = doc.add_entity_mention(mention=cem)
+                    em_id_to_cem[em_obj['mention_id']] = cem
+                    uuids.append(em_id.uuidString)
+                    entity_id_to_mentions[em_obj['entity_id']].append(cem)
+                    entity_mention_counter[doc_id] += 1
+                doc.write_kv_map(prefix='ner', key=str(line_id), suffix='sentence', value=','.join(uuids))
+
+            # extract event mentions
+            if sent_obj['events']:
+                uuids = []
+                for event_mention_obj in sent_obj['events']:
+                    trigger_obj = event_mention_obj['trigger']
+                    trigger_start, trigger_end = doc.to_global_indices(
+                        sent_ids=[line_id],
+                        indices=[trigger_obj['start'], trigger_obj['end'] - 1]
+                    )
+                    trigger = CementSpan(start=trigger_start,
+                                         end=trigger_end,
+                                         text=trigger_obj['text'],
+                                         document=doc)
+                    arguments = []
+                    for arg_obj in event_mention_obj['arguments']:
+                        # copy the mention so the role is attached to this argument only
+                        mention = copy.deepcopy(em_id_to_cem[arg_obj['mention_id']])
+                        mention.attrs.add(k='role', v=arg_obj['role'])
+                        arguments.append(mention)
+
+                    sm_id = doc.add_event_mention(
+                        trigger=trigger,
+                        arguments=arguments,
+                        event_type=f'{event_mention_obj["event_type"]}:{event_mention_obj["event_subtype"]}'
+                    )
+                    event_mention = doc.comm.situationMentionForUUID[sm_id.uuidString]
+                    sm_id_to_sm[event_mention_obj['mention_id']] = event_mention
+                    event_id_to_mentions[event_mention_obj['event_id']].append(event_mention)
+                    uuids.append(sm_id.uuidString)
+                    event_mention_counter[doc_id] += 1
+                doc.write_kv_map(prefix='event', key=str(line_id), suffix='sentence', value=','.join(uuids))
+
+            # extract relation mentions
+            if sent_obj['relations']:
+                uuids = []
+                for relation_mention_obj in sent_obj['relations']:
+                    arguments = []
+                    for arg_obj in (relation_mention_obj['arg1'], relation_mention_obj['arg2']):
+                        mention = copy.deepcopy(em_id_to_cem[arg_obj['mention_id']])
+                        mention.attrs.add(k='role', v=arg_obj['role'])
+                        arguments.append(mention)
+
+                    sm_id = doc.add_relation_mention(
+                        arguments=arguments,
+                        relation_type=f'{relation_mention_obj["relation_type"]}:'
+                                      f'{relation_mention_obj["relation_subtype"]}'
+                    )
+                    relation_mention = doc.comm.situationMentionForUUID[sm_id.uuidString]
+                    sm_id_to_sm[relation_mention_obj['relation_id']] = relation_mention
+                    uuids.append(sm_id.uuidString)
+                    relation_mention_counter[doc_id] += 1
+                doc.write_kv_map(prefix='relation', key=str(line_id), suffix='sentence', value=','.join(uuids))
+
+        for entity_id, mentions in entity_id_to_mentions.items():
+            # NOTE(review): these mentions were already passed to `add_entity_mention` above and
+            # `add_entity` passes them again (with update=False) - confirm this does not duplicate them.
+            doc.add_entity(mentions=mentions,
+                           entity_type=mentions[0].attrs.entity_type,
+                           entity_id=entity_id,
+                           update=False)
+            entity_counter[doc_id] += 1
+        for mentions in event_id_to_mentions.values():
+            doc.add_raw_situation(situation_type='EVENT',
+                                  situation_kind=mentions[0].situationKind,
+                                  mention_ids=[mention.uuid for mention in mentions])
+            event_counter[doc_id] += 1
+
+        logger.info(
+            f'{doc_id} - #events={event_counter[doc_id]}, '
+            f'#event_mentions={event_mention_counter[doc_id]}, '
+            f'#entities={entity_counter[doc_id]}, '
+            f'#entity_mentions={entity_mention_counter[doc_id]}, '
+            f'#relation_mentions={relation_mention_counter[doc_id]}'
+        )
+
+        yield doc
+
+    logger.info(
+        f'Total - #events={sum(event_counter.values())}, '
+        f'#event_mentions={sum(event_mention_counter.values())}, '
+        f'#entities={sum(entity_counter.values())}, '
+        f'#entity_mentions={sum(entity_mention_counter.values())}, '
+        f'#relation_mentions={sum(relation_mention_counter.values())}'
+    )
+
+
+def serialize_doc(doc_stream: Iterable[CementDocument], base_path: str) -> None:
+    """Write each document of `doc_stream` to `<base_path>/<communication id>.concrete`.
+
+    :param doc_stream: documents to serialize; consumed lazily
+    :param base_path: output directory, created if it does not exist yet
+    """
+    os.makedirs(base_path, exist_ok=True)
+    for doc in tqdm(doc_stream):
+        doc.to_communication_file(file_path=os.path.join(base_path, f'{doc.comm.id}.concrete'))
+
+
+if __name__ == '__main__':
+    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                        datefmt='%m/%d/%Y %H:%M:%S',
+                        level=logging.INFO)
+
+    parser = argparse.ArgumentParser()
+    # Both paths are mandatory: without them the conversion fails later with an opaque error.
+    parser.add_argument('--input-path', type=str, required=True)
+    parser.add_argument('--output-path', type=str, required=True)
+    parser.add_argument('--use-dir', action='store_true')
+    parser.add_argument('--show-cement-warnings', action='store_true')
+    args = parser.parse_args()
+
+    # Warnings from `cement.cement_document` are hidden unless explicitly requested.
+    if args.show_cement_warnings:
+        logging.getLogger('cement.cement_document').setLevel(logging.WARNING)
+    else:
+        logging.getLogger('cement.cement_document').setLevel(logging.CRITICAL)
+
+    serialize_doc(doc_stream=to_cement_doc_stream(json_stream=read_json(input_path=args.input_path,
+                                                                        use_dir=args.use_dir)),
+                  base_path=args.output_path)