
Commit

added a flag to include classifying each sentence in spaCy.
DavidFromPandora committed Apr 3, 2022
1 parent 333ef06 commit f9c19fc
Showing 9 changed files with 168 additions and 113 deletions.
58 changes: 30 additions & 28 deletions classy_classification/__init__.py
@@ -5,31 +5,24 @@
from spacy.language import Language
from spacy.tokens import Doc

-from .classifiers.sentence_transformer import \
-    classySentenceTransformer as classyClassifier
+from .classifiers.sentence_transformer import classySentenceTransformer as classyClassifier
from .classifiers.spacy_few_shot_external import classySpacyFewShotExternal
from .classifiers.spacy_internal import classySpacyInternal
from .classifiers.spacy_zero_shot_external import classySpacyZeroShotExternal

-__all__ = [
-    'classyClassifier',
-    'classySpacyFewShotExternal',
-    'classySpacyZeroShotExternal',
-    'classySpacyInternal'
-]
+__all__ = ["classyClassifier", "classySpacyFewShotExternal", "classySpacyZeroShotExternal", "classySpacyInternal"]


@Language.factory(
"text_categorizer",
default_config={
"data": None,
"model": None,
"device": "cpu",
"config": {
"C": [1, 2, 5, 10, 20, 100],
"kernels": ["linear"],
"max_cross_validation_folds": 5
},
"cat_type": 'few'
"config": {"C": [1, 2, 5, 10, 20, 100], "kernels": ["linear"], "max_cross_validation_folds": 5},
"cat_type": "few",
"include_doc": True,
"include_sent": False,
},
)
def make_text_categorizer(
@@ -39,48 +32,57 @@ def make_text_categorizer(
device: str,
config: dict,
model: str = None,
-    cat_type: str = 'few',
-):
-    if model == 'spacy':
-        if cat_type == 'zero':
-            raise NotImplementedError('cannot use spacy internal embeddings with zero-shot classification')
+    cat_type: str = "few",
+    include_doc: bool = True,
+    include_sent: bool = False,
+):
+    if model == "spacy":
+        if cat_type == "zero":
+            raise NotImplementedError("cannot use spacy internal embeddings with zero-shot classification")
return classySpacyInternal(
-            nlp=nlp,
-            name=name,
-            data=data,
-            config=config,
+            nlp=nlp, name=name, data=data, config=config, include_doc=include_doc, include_sent=include_sent
)
else:
-        if cat_type == 'zero':
+        if cat_type == "zero":
if model:
return classySpacyZeroShotExternal(
nlp=nlp,
name=name,
data=data,
device=device,
-                    model=model
+                    model=model,
+                    include_doc=include_doc,
+                    include_sent=include_sent,
)
else:
return classySpacyZeroShotExternal(
nlp=nlp,
name=name,
data=data,
device=device,
-                    model=model
+                    model=model,
+                    include_doc=include_doc,
+                    include_sent=include_sent,
)
else:
if model:
return classySpacyFewShotExternal(
nlp=nlp,
name=name,
data=data,
device=device,
model=model,
config=config,
+                    include_doc=include_doc,
+                    include_sent=include_sent,
)
else:
return classySpacyFewShotExternal(
nlp=nlp,
name=name,
data=data,
device=device,
config=config,
+                    include_doc=include_doc,
+                    include_sent=include_sent,
)
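
For reference, the new flags can be exercised end to end once this factory is registered. The sketch below is illustrative rather than part of the commit: the training sentences are made up, and en_core_web_md simply stands in for any spaCy pipeline with word vectors.

import spacy

import classy_classification  # noqa: F401  (importing registers the "text_categorizer" factory)

# Illustrative few-shot training data: label -> example sentences.
data = {
    "furniture": [
        "This couch is very comfortable.",
        "The table wobbles on one leg.",
        "We bought a new bookshelf.",
        "The armchair fabric is soft.",
        "That desk is too small for a monitor.",
    ],
    "kitchen": [
        "The oven heats unevenly.",
        "These knives need sharpening.",
        "The blender is far too loud.",
        "Our fridge stopped cooling.",
        "The dishwasher leaves spots on glasses.",
    ],
}

nlp = spacy.load("en_core_web_md")  # assumed: any model with vectors should do
nlp.add_pipe(
    "text_categorizer",
    config={
        "data": data,
        "model": "spacy",      # route to the spaCy-internal classifier
        "include_doc": True,   # document-level predictions on Doc._.cats (default)
        "include_sent": True,  # new flag: per-sentence predictions on Span._.cats
    },
)

doc = nlp("The couch is comfy. The oven burns everything.")
print(doc._.cats)                  # e.g. label -> probability dict for the whole document
for sent in doc.sents:
    print(sent.text, sent._.cats)  # one prediction per sentence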


47 changes: 32 additions & 15 deletions classy_classification/classifiers/spacy_few_shot_external.py
@@ -1,17 +1,21 @@
import os

from spacy import util
from spacy.language import Language
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Span

from .sentence_transformer import classySentenceTransformer


class classySpacyFewShotExternal(classySentenceTransformer):
-    def __init__(self, name, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, nlp, name, data, device, config, include_doc, include_sent, *args, **kwargs):
+        super().__init__(data=data, device=device, config=config, *args, **kwargs)
self.name = name
-        Doc.set_extension("cats", default=None, force=True)
+        self.include_doc = include_doc
+        self.include_sent = include_sent
+        if include_sent:
+            Span.set_extension("cats", default=None, force=True)
+            if "sentencizer" not in nlp.pipe_names:
+                nlp.add_pipe("sentencizer")
+        if include_doc:
+            Doc.set_extension("cats", default=None, force=True)

def __call__(self, doc: Doc) -> Doc:
"""
@@ -23,8 +27,11 @@ def __call__(self, doc: Doc) -> Doc:
Returns:
Doc: spacy doc with ._.cats key-class proba-value dict
"""
-        pred_result = super(self.__class__, self).__call__(doc.text.replace("\n", " "))
-        doc._.cats = pred_result
+        if self.include_doc:
+            pred_result = super(self.__class__, self).__call__(doc.text.replace("\n", " "))
+            doc._.cats = pred_result
+        if self.include_sent:
+            doc = self.set_pred_results_for_doc(doc)

return doc

@@ -39,11 +46,21 @@ def pipe(self, stream, batch_size=128):
Doc: spacy doc with ._.cats key-class proba-value dict
"""
for docs in util.minibatch(stream, size=batch_size):
-            texts = [doc.text.replace("\n", " ") for doc in docs]
-            pred_results = super(self.__class__, self).pipe(texts)
+            pred_results = [doc.text.replace("\n", " ") for doc in docs]
+
+            if self.include_doc:
+                pred_results = super(self.__class__, self).pipe(pred_results)

for doc, pred_result in zip(docs, pred_results):
-                doc._.cats = pred_result
+                if self.include_doc:
+                    doc._.cats = pred_result
+                if self.include_sent:
+                    doc = self.set_pred_results_for_doc(doc)
yield doc


+    def set_pred_results_for_doc(self, doc: Doc):
+        pred_results = super(self.__class__, self).pipe([sent.text for sent in list(doc.sents)])
+        for sent, pred in zip(doc.sents, pred_results):
+            sent._.cats = pred
+
+        return doc
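
The same flags flow through the sentence-transformer path above, including the batched pipe() route. A sketch of that usage, again with made-up data and an assumed sentence-transformers model name:

import spacy

import classy_classification  # noqa: F401  (registers the "text_categorizer" factory)

data = {
    "positive": [
        "Great product, works perfectly.",
        "I love it, five stars.",
        "Battery life is excellent.",
        "Setup was quick and easy.",
        "Very happy with this purchase.",
    ],
    "negative": [
        "It broke after two days.",
        "Terrible quality, waste of money.",
        "Very disappointing experience.",
        "The screen cracked immediately.",
        "Would not buy this again.",
    ],
}

nlp = spacy.blank("en")  # no parser needed: the component adds a sentencizer when include_sent=True
nlp.add_pipe(
    "text_categorizer",
    config={
        "data": data,
        "model": "sentence-transformers/all-MiniLM-L6-v2",  # assumed external embedding model
        "device": "cpu",
        "include_sent": True,  # include_doc keeps its default of True
    },
)

reviews = [
    "Arrived on time. The screen cracked after two days.",
    "Setup was easy. Battery life is excellent.",
]
for doc in nlp.pipe(reviews):
    print(doc._.cats)
    for sent in doc.sents:
        print("  ", sent.text, sent._.cats)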
52 changes: 36 additions & 16 deletions classy_classification/classifiers/spacy_internal.py
@@ -1,22 +1,29 @@
from typing import List

from spacy import util
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Span

from .classy_skeleton import classySkeleton


class classySpacyInternal(classySkeleton):
-    def __init__(self, nlp, name, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        Doc.set_extension("cats", default=None, force=True)
+    def __init__(self, nlp, name, data, config, include_doc, include_sent):
+        super().__init__(data=data, config=config)
+        self.include_doc = include_doc
+        self.include_sent = include_sent
+        if include_sent:
+            Span.set_extension("cats", default=None, force=True)
+            if "sentencizer" not in nlp.pipe_names:
+                nlp.add_pipe("sentencizer")
+        if include_doc:
+            Doc.set_extension("cats", default=None, force=True)
self.name = name
self.nlp = nlp
self.set_training_data()
self.set_svc()

def get_embeddings(self, text: List[str]) -> List[float]:
""" Retrieve embeddings from text.
"""Retrieve embeddings from text.
Overwrites function from the classySkeleton that is used to get embeddings for training data.
Args:
@@ -27,9 +34,9 @@ def get_embeddings(self, text: List[str]) -> List[float]:
"""
docs = self.nlp.pipe(text)
embeddings = [self.get_embeddings_from_doc(doc) for doc in docs]

return embeddings

def get_embeddings_from_doc(self, doc: Doc) -> List[float]:
"""Retrieve a vector from a spacy doc and internal embeddings.
@@ -59,9 +66,12 @@ def __call__(self, doc: Doc):
Returns:
Doc: spacy doc with ._.cats key-class proba-value dict
"""
-        embeddings = self.get_embeddings_from_doc(doc)
-        embeddings = embeddings.reshape(1, -1)
-        doc._.cats = self.get_prediction(embeddings)[0]
+        if self.include_doc:
+            embeddings = self.get_embeddings_from_doc(doc)
+            embeddings = embeddings.reshape(1, -1)
+            doc._.cats = self.get_prediction(embeddings)[0]
+        if self.include_sent:
+            doc = self.set_pred_results_for_doc(doc)

return doc

@@ -76,11 +86,21 @@ def pipe(self, stream, batch_size=128):
Doc: spacy doc with ._.cats key-class proba-value dict
"""
for docs in util.minibatch(stream, size=batch_size):
-            embeddings = [self.get_embeddings_from_doc(doc) for doc in docs]
-            pred_results = self.get_prediction(embeddings)
+            pred_results = [self.get_embeddings_from_doc(doc) for doc in docs]
+            if self.include_doc:
+                pred_results = self.get_prediction(pred_results)

for doc, pred_result in zip(docs, pred_results):
-                doc._.cats = pred_result
+                if self.include_doc:
+                    doc._.cats = pred_result
+                if self.include_sent:
+                    doc = self.set_pred_results_for_doc(doc)

yield doc


+    def set_pred_results_for_doc(self, doc: Doc):
+        embeddings = [sent.as_doc().vector for sent in list(doc.sents)]
+        pred_results = self.get_prediction(embeddings)
+        for sent, pred in zip(doc.sents, pred_results):
+            sent._.cats = pred
+        return doc
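
What the per-sentence path adds for the spaCy-internal classifier is essentially this: each sentence becomes its own Doc, and its vector is classified independently of the rest of the document. A minimal illustration of that embedding step (the en_core_web_md model choice is an assumption):

import spacy

nlp = spacy.load("en_core_web_md")  # assumed: a pipeline whose Doc/Span vectors are meaningful
nlp.add_pipe("sentencizer")         # the component adds this itself when include_sent=True

doc = nlp("The couch is comfy. The oven burns everything.")

# Mirrors set_pred_results_for_doc above: one embedding per sentence, classified separately.
sentence_vectors = [sent.as_doc().vector for sent in doc.sents]
print(len(sentence_vectors), sentence_vectors[0].shape)  # 2 sentences, 300-dimensional vectors here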
