Merge pull request #373 from CogStack/CU2e77a5x-cdb-merge-function
CU2e77a5x - Add a CDB merge function

Given two CDBs, a new CDB will be created combining the entries from both CDBs.
adam-sutton-1992 authored Dec 18, 2023
2 parents 45cef2b + c74fe1f commit 70305f4
Showing 3 changed files with 195 additions and 0 deletions.
117 changes: 117 additions & 0 deletions medcat/utils/cdb_utils.py
@@ -0,0 +1,117 @@
import logging
import numpy as np

from copy import deepcopy
from medcat.cdb import CDB

logger = logging.getLogger(__name__) # separate logger from the package-level one


def merge_cdb(cdb1: "CDB",
cdb2: "CDB",
overwrite_training: int = 0,
full_build: bool = False):
"""Merge two CDB's together to produce a new, single CDB. The contents of inputs CDBs will not be changed.
`addl_info` can not be perfectly merged, and will prioritise cdb1. see `full_build`
Args:
cdb1 (medcat.cdb.CDB):
The first medcat cdb to merge. In cases where merging isn't suitable isn't ideal (such as
cui2preferred_name), this cdb values will be prioritised over cdb2.
cdb2 (medcat.cdb.CDB):
The second medcat cdb to merge.
overwrite_training (int):
Choose to prioritise a CDB's context vectors values over merging gracefully. 0 - no prio, 1 - CDB1, 2 - CDB2
full_build (bool):
Add additional information from "addl_info" dicts "cui2ontologies" and "cui2description"
"""
config = deepcopy(cdb1.config)
cdb = CDB(config)

# Copy CDB 1 - as all settings from CDB 1 will be carried over
cdb.cui2names = deepcopy(cdb1.cui2names)
cdb.cui2snames = deepcopy(cdb1.cui2snames)
cdb.cui2count_train = deepcopy(cdb1.cui2count_train)
cdb.cui2info = deepcopy(cdb1.cui2info)
cdb.cui2context_vectors = deepcopy(cdb1.cui2context_vectors)
cdb.cui2tags = deepcopy(cdb1.cui2tags)
cdb.cui2type_ids = deepcopy(cdb1.cui2type_ids)
cdb.cui2preferred_name = deepcopy(cdb1.cui2preferred_name)
cdb.name2cuis = deepcopy(cdb1.name2cuis)
cdb.name2cuis2status = deepcopy(cdb1.name2cuis2status)
cdb.name2count_train = deepcopy(cdb1.name2count_train)
cdb.name_isupper = deepcopy(cdb1.name_isupper)
if full_build:
cdb.addl_info = deepcopy(cdb1.addl_info)

# handles cui2names, cui2snames, name_isupper, name2cuis, name2cuis2status, cui2preferred_name
for cui in cdb2.cui2names:
names = dict()
for name in cdb2.cui2names[cui]:
names[name] = {'snames': cdb2.cui2snames.get(cui, set()), 'is_upper': cdb2.name_isupper.get(name, False), 'tokens': {}, 'raw_name': cdb2.get_name(cui)}
            name_status = cdb2.name2cuis2status.get(name, {}).get(cui, 'A')  # get the name status if it exists, default to 'A'
# For addl_info check cui2original_names as they MUST be added
ontologies = set()
description = ''
to_build = False
if full_build and (cui in cdb2.addl_info['cui2original_names'] or cui in cdb2.addl_info['cui2description']):
to_build = True
if 'cui2ontologies' in cdb2.addl_info:
ontologies.update(cdb2.addl_info['cui2ontologies'][cui])
if 'cui2description' in cdb2.addl_info:
description = cdb2.addl_info['cui2description'][cui]
cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status,
type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build)
if cui in cdb1.cui2names:
if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train):
if overwrite_training == 2 and cui in cdb2.cui2count_train:
cdb.cui2count_train[cui] = cdb2.cui2count_train[cui]
else:
cdb.cui2count_train[cui] = cdb1.cui2count_train.get(cui, 0) + cdb2.cui2count_train.get(cui, 0)
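            # merge context vectors as a weighted average, weighting each CDB by its share of the
            # combined training count (overwrite_training == 2 takes cdb2's vectors as-is instead)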
            if cui in cdb1.cui2context_vectors and not (overwrite_training == 1 and cui in cdb1.cui2context_vectors):
if overwrite_training == 2 and cui in cdb2.cui2context_vectors:
weights = [0, 1]
else:
norm = cdb.cui2count_train[cui]
weights = [np.divide(cdb1.cui2count_train.get(cui, 0), norm), np.divide(cdb2.cui2count_train.get(cui, 0), norm)]
contexts = set(list(cdb1.cui2context_vectors.get(cui, {}).keys()) + list(cdb2.cui2context_vectors.get(cui, {}).keys())) # xlong, long, medium, short
for s in contexts:
cdb.cui2context_vectors[cui][s] = (weights[0] * cdb1.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) + (weights[1] * cdb2.cui2context_vectors[cui].get(s, np.zeros(shape=(300))))
if cui in cdb1.cui2tags:
                cdb.cui2tags[cui].extend(cdb2.cui2tags[cui])
if cui in cdb1.cui2type_ids:
cdb.cui2type_ids[cui] = cdb1.cui2type_ids[cui].union(cdb2.cui2type_ids[cui])
else:
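            # the CUI exists only in cdb2, so copy its values across unchanged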
if cui in cdb2.cui2count_train:
                cdb.cui2count_train[cui] = cdb2.cui2count_train[cui]
if cui in cdb2.cui2info:
cdb.cui2info[cui] = cdb2.cui2info[cui]
if cui in cdb2.cui2context_vectors:
cdb.cui2context_vectors[cui] = cdb2.cui2context_vectors[cui]
if cui in cdb2.cui2tags:
cdb.cui2tags[cui] = cdb2.cui2tags[cui]
if cui in cdb2.cui2type_ids:
cdb.cui2type_ids[cui] = cdb2.cui2type_ids[cui]

if overwrite_training != 1:
for name in cdb2.name2cuis:
if name in cdb1.name2cuis and overwrite_training == 0: # if they exist in both cdbs
if name in cdb1.name2count_train and name in cdb2.name2count_train:
cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name])) # these are strings for some reason
else:
if name in cdb2.name2count_train:
cdb.name2count_train[name] = cdb2.name2count_train[name]

# snames
cdb.snames = cdb1.snames.union(cdb2.snames)

# vocab, adding counts if they occur in both
cdb.vocab = deepcopy(cdb1.vocab)
if overwrite_training != 1:
for word in cdb2.vocab:
if word in cdb.vocab and overwrite_training == 0:
cdb.vocab[word] += cdb2.vocab[word]
else:
cdb.vocab[word] = cdb2.vocab[word]

return cdb
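
A minimal usage sketch for the new function, assuming two CDBs previously saved with CDB.save(); the file names below are illustrative, not taken from the repository:

    from medcat.cdb import CDB
    from medcat.utils.cdb_utils import merge_cdb

    # load two previously saved CDBs (paths are only examples)
    cdb1 = CDB.load("base_cdb.dat")
    cdb2 = CDB.load("extra_cdb.dat")

    # combine them; cdb1 is prioritised where values cannot be merged (e.g. cui2preferred_name),
    # and full_build=True also merges the supported addl_info dicts
    merged = merge_cdb(cdb1=cdb1, cdb2=cdb2, overwrite_training=0, full_build=True)
    merged.save("merged_cdb.dat")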
35 changes: 35 additions & 0 deletions tests/helper.py
@@ -6,6 +6,8 @@
import numpy as np

from medcat.vocab import Vocab
from medcat.cdb_maker import CDBMaker
from medcat.config import Config


class AsyncMock(unittest.mock.MagicMock):
@@ -86,3 +88,36 @@ def check_or_download(self):
return
with open(self.vocab_path, 'wb') as f:
f.write(tmp.content)


class ForCDBMerging:

def __init__(self) -> None:
        # generating CDBs - two makers are needed, as a single maker would return the same CDB object
config = Config()
config.general["spacy_model"] = "en_core_web_md"
maker1 = CDBMaker(config)
        maker2 = CDBMaker(config)  # the second maker is required, as both CDBs would otherwise point to the same object
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "model_creator", "umls_sample.csv")
self.cdb1 = maker1.prepare_csvs(csv_paths=[path])
self.cdb2 = maker2.prepare_csvs(csv_paths=[path])

        # generating context vectors here for testing the weighted average function (based on cui2count_train)
zeroes = np.zeros(shape=(1,300))
ones = np.ones(shape=(1,300))
for i, cui in enumerate(self.cdb1.cui2names):
self.cdb1.cui2context_vectors[cui] = {"short": ones}
self.cdb2.cui2context_vectors[cui] = {"short": zeroes}
self.cdb1.cui2count_train[cui] = 1
self.cdb2.cui2count_train[cui] = i + 1
# adding new names and cuis to each cdb to test after merging
test_add = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}}
self.cdb1.add_names("C0006826", test_add)
unique_test = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}}
self.cdb2.add_names("UniqueTest", unique_test)
self.cdb2.cui2context_vectors["UniqueTest"] = {"short": zeroes}
self.cdb2.addl_info["cui2ontologies"] = {}
self.cdb2.addl_info["cui2description"] = {}
for cui in self.cdb2.cui2names:
self.cdb2.addl_info["cui2ontologies"][cui] = {"test_ontology"}
self.cdb2.addl_info["cui2description"][cui] = "test_description"
43 changes: 43 additions & 0 deletions tests/utils/test_cdb_utils.py
@@ -0,0 +1,43 @@
import unittest
import numpy as np
from tests.helper import ForCDBMerging
from medcat.utils.cdb_utils import merge_cdb


class CDBMergeTests(unittest.TestCase):

@classmethod
def setUpClass(cls):
to_merge = ForCDBMerging()
cls.cdb1 = to_merge.cdb1
cls.cdb2 = to_merge.cdb2
cls.merged_cdb = merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2)
cls.overwrite_cdb = merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2, overwrite_training=2, full_build=True)
cls.zeroes = np.zeros(shape=(1,300))
cls.ones = np.ones(shape=(1,300))

def test_merge_inserts(self):
self.assertIn("test", self.merged_cdb.cui2names["C0006826"])
self.assertIn("test_name", self.merged_cdb.cui2snames["C0006826"])
self.assertEqual("Cancer", self.merged_cdb.cui2preferred_name["C0006826"])

def test_no_full_build(self):
        self.assertEqual(self.merged_cdb.addl_info["cui2ontologies"], dict())
        self.assertEqual(self.merged_cdb.addl_info["cui2description"], dict())

def test_full_build(self):
for cui in self.cdb2.cui2names:
self.assertEqual(self.overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"})
self.assertEqual(self.overwrite_cdb.addl_info["cui2description"][cui], "test_description")

def test_vector_merge(self):
self.assertTrue(np.array_equal(self.zeroes, self.merged_cdb.cui2context_vectors["UniqueTest"]["short"]))
for i, cui in enumerate(self.cdb1.cui2names):
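            # ForCDBMerging sets cdb1.cui2count_train[cui] = 1 and cdb2.cui2count_train[cui] = i + 1,
            # so the merged count is i + 2 and the vector weights are 1/(i+2) and (i+1)/(i+2);
            # cdb2's vectors are all zeros, leaving the merged "short" vector as ones/(i+2)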
self.assertTrue(np.array_equal(self.merged_cdb.cui2context_vectors[cui]["short"], np.divide(self.ones, i+2)))


def test_overwrite_parameter(self):
for cui in self.cdb2.cui2names:
self.assertTrue(np.array_equal(self.overwrite_cdb.cui2context_vectors[cui]["short"], self.zeroes))
self.assertEqual(self.overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"})
self.assertEqual(self.overwrite_cdb.addl_info["cui2description"][cui], "test_description")
