From 82bafca0c7dbed3305ee4c28c6a8ac086705a63e Mon Sep 17 00:00:00 2001
From: Peter <74869040+pszemraj@users.noreply.github.com>
Date: Sun, 18 Feb 2024 20:18:19 +0100
Subject: [PATCH] Batch processing (#13)

This PR contains more behind-the-scenes improvements for ease of use and batch processing with the `Summarizer` class:

- disable the progress bar for within-loop summarization of a single long string
- add a 'smart' `__call__` method that dispatches to the text- and filepath-processing functions
- small improvements and updates to the docs

---------

Signed-off-by: peter szemraj
---
 CHANGELOG.md             |   6 ++
 README.md                |   5 +-
 src/textsum/__init__.py  |   8 +-
 src/textsum/app.py       |  26 +++---
 src/textsum/cli.py       |  10 ++-
 src/textsum/pdf2text.py  |  18 ++--
 src/textsum/summarize.py | 188 +++++++++++++++++++++++++++------------
 src/textsum/utils.py     |   4 +-
 8 files changed, 168 insertions(+), 97 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8472253..1ab576b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. Dates are d
 
 Generated by [`auto-changelog`](https://github.com/CookPete/auto-changelog).
 
+#### [v0.2.0](https://github.com/pszemraj/textsum/compare/v0.1.5...v0.2.0)
+
+> 8 July 2023
+
+- Draft: support faster inference methods [`#8`](https://github.com/pszemraj/textsum/pull/8)
+
 #### [v0.1.5](https://github.com/pszemraj/textsum/compare/v0.1.3...v0.1.5)
 
 > 31 January 2023
 
diff --git a/README.md b/README.md
index b161b8b..7899baa 100644
--- a/README.md
+++ b/README.md
@@ -268,8 +268,6 @@ summarizer = Summarizer(load_in_8bit=True)
 
 If using the python API, it's better to initiate tf32 yourself; see [here](https://huggingface.co/docs/transformers/perf_train_gpu_one#tf32) for how.
 
-Here are some suggestions for additions to the README in order to reflect the latest changes in the `__init__` method of your `Summarizer` class:
-
 ### Using Optimum ONNX Runtime
 
 > ⚠️ **Note:** This feature is experimental and might not work as expected. Use at your own risk. ⚠️🧪
@@ -324,7 +322,8 @@ See the [CONTRIBUTING.md](CONTRIBUTING.md) file for details on how to contribute
 - [x] LLM.int8 inference
 - [x] optimum inference integration
 - [ ] better documentation [in the wiki](https://github.com/pszemraj/textsum/wiki), details on improving performance (speed, quality, memory usage, etc.)
-- [ ] improvements to the PDF OCR helper module
+  - [x] in-progress
+- [ ] improvements to the PDF OCR helper module (_TBD - may focus more on being a summarization tool_)
 
 _Other ideas? Open an issue or PR!_
 
diff --git a/src/textsum/__init__.py b/src/textsum/__init__.py
index cf08060..86c24e3 100644
--- a/src/textsum/__init__.py
+++ b/src/textsum/__init__.py
@@ -4,13 +4,13 @@
 """
 
 import sys
-from . 
import summarize, utils - if sys.version_info[:2] >= (3, 8): # Import directly (no need for conditional) when `python_requires = >= 3.8` - from importlib.metadata import PackageNotFoundError, version # pragma: no cover + from importlib.metadata import PackageNotFoundError # pragma: no cover + from importlib.metadata import version else: - from importlib_metadata import PackageNotFoundError, version # pragma: no cover + from importlib_metadata import PackageNotFoundError # pragma: no cover + from importlib_metadata import version try: # Change here if project is renamed and does not equal the package name diff --git a/src/textsum/app.py b/src/textsum/app.py index b3a566f..c0a9fc4 100644 --- a/src/textsum/app.py +++ b/src/textsum/app.py @@ -1,23 +1,14 @@ """ app.py - a module to run the text summarization app (gradio interface) """ + import contextlib import logging import os -import random import re import time from pathlib import Path -os.environ["USE_TORCH"] = "1" -os.environ["DEMO_MAX_INPUT_WORDS"] = "2048" # number of words to truncate input to -os.environ["DEMO_MAX_INPUT_PAGES"] = "20" # number of pages to truncate PDFs to -os.environ["TOKENIZERS_PARALLELISM"] = "false" # parallelism is buggy with gradio - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) - import gradio as gr import nltk from cleantext import clean @@ -25,8 +16,16 @@ from textsum.pdf2text import convert_PDF_to_Text from textsum.summarize import Summarizer -from textsum.utils import truncate_word_count, get_timestamp +from textsum.utils import get_timestamp, truncate_word_count +os.environ["USE_TORCH"] = "1" +os.environ["DEMO_MAX_INPUT_WORDS"] = "2048" # number of words to truncate input to +os.environ["DEMO_MAX_INPUT_PAGES"] = "20" # number of pages to truncate PDFs to +os.environ["TOKENIZERS_PARALLELISM"] = "false" # parallelism is buggy with gradio + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) _here = Path.cwd() nltk.download("stopwords") # TODO=find where this requirement originates from @@ -214,7 +213,6 @@ def main(): demo = gr.Blocks() with demo: - gr.Markdown("# Summarization UI with `textsum`") gr.Markdown( f""" @@ -224,21 +222,18 @@ def main(): """ ) with gr.Column(): - gr.Markdown("## Load Inputs & Select Parameters") gr.Markdown( "Enter text below in the text area. The text will be summarized [using the selected parameters](https://huggingface.co/blog/how-to-generate). Optionally load an example below or upload a file. (`.txt` or `.pdf` - _[link to guide](https://i.imgur.com/c6Cs9ly.png)_)" ) with gr.Row(variant="compact"): with gr.Column(scale=0.5, variant="compact"): - num_beams = gr.Radio( choices=[2, 3, 4], label="Beam Search: # of Beams", value=2, ) with gr.Column(variant="compact"): - uploaded_file = gr.File( label="File Upload", file_count="single", @@ -251,7 +246,6 @@ def main(): placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. 
May take a bit to generate depending on the input text :)", ) with gr.Column(min_width=100, scale=0.5): - load_file_button = gr.Button("Upload File") with gr.Column(): diff --git a/src/textsum/cli.py b/src/textsum/cli.py index 1a28128..faa58a8 100644 --- a/src/textsum/cli.py +++ b/src/textsum/cli.py @@ -4,6 +4,7 @@ Usage: textsum-dir --help """ + import logging import pprint as pp import random @@ -31,7 +32,7 @@ def main( batch_length: int = 4096, batch_stride: int = 16, num_beams: int = 4, - length_penalty: float = 0.8, + length_penalty: float = 1.0, repetition_penalty: float = 2.5, max_length_ratio: float = 0.25, min_length: int = 8, @@ -44,6 +45,7 @@ def main( logfile: Optional[str] = None, file_extension: str = "txt", skip_completed: bool = False, + disable_progress_bar: bool = False, ): """ Main function to summarize text files in a directory. @@ -61,7 +63,7 @@ def main( batch_length (int, optional): The length of each batch. Default: 4096. batch_stride (int, optional): The stride of each batch. Default: 16. num_beams (int, optional): The number of beams to use for beam search. Default: 4. - length_penalty (float, optional): The length penalty to use for decoding. Default: 0.8. + length_penalty (float, optional): The length penalty to use for decoding. Default: 1.0. repetition_penalty (float, optional): The repetition penalty to use for beam search. Default: 2.5. max_length_ratio (float, optional): The maximum length of the summary as a ratio of the batch length. Default: 0.25. min_length (int, optional): The minimum length of the summary. Default: 8. @@ -74,6 +76,7 @@ def main( logfile (str, optional): Path to the log file. This will set loglevel to INFO (if not set) and write to the file. file_extension (str, optional): The file extension to use when searching for input files., defaults to "txt" skip_completed (bool, optional): Skip files that have already been summarized. Default: False. + disable_progress_bar (bool, optional): Disable the progress bar for intra-file summarization batches. Default: False. 
Returns: None @@ -107,6 +110,7 @@ def main( compile_model=compile, optimum_onnx=optimum_onnx, force_cache=force_cache, + disable_progress_bar=disable_progress_bar, **params, ) summarizer.print_config() @@ -142,7 +146,7 @@ def main( failed_files.append(f) if isinstance(e, RuntimeError): # if a runtime error occurs, exit immediately - logging.error("Not continuing summarization due to runtime error") + logging.error("Stopping summarization: runtime error") failed_files.extend(input_files[input_files.index(f) + 1 :]) break diff --git a/src/textsum/pdf2text.py b/src/textsum/pdf2text.py index a967a03..2c303d1 100644 --- a/src/textsum/pdf2text.py +++ b/src/textsum/pdf2text.py @@ -7,15 +7,6 @@ """ import logging -from pathlib import Path - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s %(message)s", - datefmt="%m/%d/%Y %I:%M:%S", -) - - import os import re import shutil @@ -29,6 +20,12 @@ from doctr.models import ocr_predictor from spellchecker import SpellChecker +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + datefmt="%m/%d/%Y %I:%M:%S", +) + def simple_rename(filepath, target_ext=".txt"): _fp = Path(filepath) @@ -136,7 +133,6 @@ def clean_OCR(ugly_text: str): def move2completed(from_dir, filename, new_folder="completed", verbose=False): - # this is the better version old_filepath = join(from_dir, filename) @@ -275,7 +271,6 @@ def cleantxt_ocr(ugly_text, lower=False, lang: str = "en") -> str: def format_ocr_out(OCR_data): - if isinstance(OCR_data, list): text = " ".join(OCR_data) else: @@ -322,7 +317,6 @@ def convert_PDF_to_Text( ocr_model=None, max_pages: int = 20, ): - st = time.perf_counter() PDF_file = Path(PDF_file) ocr_model = ocr_predictor(pretrained=True) if ocr_model is None else ocr_model diff --git a/src/textsum/summarize.py b/src/textsum/summarize.py index a10487a..a0b4211 100644 --- a/src/textsum/summarize.py +++ b/src/textsum/summarize.py @@ -1,12 +1,14 @@ """ summarize.py - a module that contains functions for summarizing text """ + import json import logging +import pprint as pp import sys import warnings -import pprint as pp from pathlib import Path +from typing import Union import torch from cleantext import clean @@ -52,9 +54,10 @@ def __init__( compile_model: bool = False, optimum_onnx: bool = False, force_cache: bool = False, + disable_progress_bar: bool = False, **kwargs, ): - """ + f""" __init__ - initialize the Summarizer class :param str model_name_or_path: the name or path of the model to load, defaults to "pszemraj/long-t5-tglobal-base-16384-book-summary" @@ -67,12 +70,16 @@ def __init__( :param bool compile_model: whether to compile the model (pytorch 2.0+ only), defaults to False :param bool optimum_onnx: whether to load the model in ONNX Runtime, defaults to False :param bool force_cache: whether to force the model to use cache, defaults to False - :param kwargs: additional keyword arguments to pass to the model as inference parameters + :param bool disable_progress_bar: whether to disable the progress bar, defaults to False + :param kwargs: additional keyword arguments to pass to the model as inference parameters, any of: {self.settable_inference_params} """ self.logger = logging.getLogger(__name__) - - self.model_name_or_path = model_name_or_path self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu" + self.disable_progress_bar = disable_progress_bar + self.force_cache = force_cache + self.is_general_attention_model = is_general_attention_model + 
self.model_name_or_path = model_name_or_path + self.use_cuda = use_cuda self.logger.debug(f"loading model {model_name_or_path} to {self.device}") if load_in_8bit: @@ -88,8 +95,8 @@ def __init__( device_map="auto", ) elif optimum_onnx: - from optimum.onnxruntime import ORTModelForSeq2SeqLM import onnxruntime + from optimum.onnxruntime import ORTModelForSeq2SeqLM if self.device == "cuda": self.logger.warning( @@ -153,7 +160,7 @@ def __init__( "repetition_penalty": 2.5, "num_beams": 4, "num_beam_groups": 1, - "length_penalty": 0.8, + "length_penalty": 1.0, "early_stopping": True, "do_sample": False, } # default inference parameters @@ -181,19 +188,22 @@ def __init__( "textsum_version": textsum.__version__, } + def __str__(self): + return f"Summarizer({json.dumps(self.config)})" + def __repr__(self): - return f"Summarizer(model_name_or_path={self.model_name_or_path}, use_cuda={self.use_cuda}, token_batch_length={self.token_batch_length}, batch_stride={self.batch_stride}, max_length_ratio={self.max_length_ratio}, load_in_8bit={self.load_in_8bit}, compile_model={self.compile_model}, optimum_onnx={self.optimum_onnx})" + return self.__str__() def set_inference_params( self, new_params: dict = None, - config_file: str or Path = None, + config_file: Union[str, Path] = None, ): """ set_inference_params - update the inference parameters to use when summarizing text :param dict new_params: a dictionary of new inference parameters to use, defaults to None - :param str or Path config_file: a path to a json file containing inference parameters, defaults to None + :param Union[str, Path] config_file: a path to a json file containing inference parameters, defaults to None NOTE: if both new_params and config_file are provided, entries in the config_file will overwrite entries in new_params if they have the same key """ @@ -232,7 +242,7 @@ def print_config(self): """print the current configuration""" print(json.dumps(self.config, indent=2)) - def save_config(self, path: str or Path = "textsum_config.json"): + def save_config(self, path: Union[str, Path] = "textsum_config.json"): """save the current configuration to a json file""" with open(path, "w", encoding="utf-8") as f: json.dump(self.config, f, indent=2) @@ -241,12 +251,13 @@ def update_loglevel(self, loglevel: int = logging.INFO): """update the loglevel of the logger""" self.logger.setLevel(loglevel) - def summarize_and_score(self, ids, mask, **kwargs): + def summarize_and_score(self, ids, mask, autocast_enabled: bool = False, **kwargs): """ - summarize_and_score - summarize a batch of text and return the summary and output scores + summarize_and_score - run inference on a batch of ids with the given attention mask :param ids: the token ids of the tokenized batch to summarize :param mask: the attention mask of the tokenized batch to summarize + :param bool autocast_enabled: whether to use autocast for inference :return tuple: a tuple containing the summary and output scores """ @@ -260,27 +271,27 @@ def summarize_and_score(self, ids, mask, **kwargs): # put global attention on token global_attention_mask[:, 0] = 1 - self.logger.debug( - f"generating summary for batch of size {input_ids.shape} with {kwargs}" - ) - if self.is_general_attention_model: - summary_pred_ids = self.model.generate( - input_ids, - attention_mask=attention_mask, - output_scores=True, - return_dict_in_generate=True, - **kwargs, - ) - else: - # this is for LED etc. 
- summary_pred_ids = self.model.generate( - input_ids, - attention_mask=attention_mask, - global_attention_mask=global_attention_mask, - output_scores=True, - return_dict_in_generate=True, - **kwargs, - ) + self.logger.debug(f"gen. summary batch, size {input_ids.shape} with {kwargs}") + with torch.autocast(device_type=self.device, enabled=autocast_enabled): + if self.is_general_attention_model: + summary_pred_ids = self.model.generate( + input_ids, + attention_mask=attention_mask, + output_scores=True, + return_dict_in_generate=True, + **kwargs, + ) + else: + # this is for LED etc. + summary_pred_ids = self.model.generate( + input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + output_scores=True, + return_dict_in_generate=True, + **kwargs, + ) + summary = self.tokenizer.batch_decode( summary_pred_ids.sequences, skip_special_tokens=True, @@ -298,6 +309,7 @@ def summarize_via_tokenbatches( batch_stride: int = None, min_batch_length: int = 512, pad_incomplete_batch: bool = True, + disable_progress_bar: bool = None, **kwargs, ): """ @@ -306,12 +318,21 @@ def summarize_via_tokenbatches( :param str input_text: the text to summarize :param int batch_length: number of tokens to include in each input batch, default None (self.token_batch_length) :param int batch_stride: number of tokens to stride between batches, default None (self.token_batch_stride) + :param int min_batch_length: minimum number of tokens in a batch, default 512 :param bool pad_incomplete_batch: whether to pad the last batch to the length of the longest batch, default True + :param bool disable_progress_bar: whether to disable the progress bar, default None + :param kwargs: additional keyword arguments to pass to the summarize_and_score function + :return: a list of summaries, a list of scores, and a list of the input text for each batch """ batch_length = self.token_batch_length if batch_length is None else batch_length batch_stride = self.batch_stride if batch_stride is None else batch_stride + disable_progress_bar = ( + self.disable_progress_bar + if disable_progress_bar is None + else disable_progress_bar + ) if batch_length < min_batch_length: self.logger.warning( @@ -340,8 +361,11 @@ def summarize_via_tokenbatches( in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask gen_summaries = [] - pbar = tqdm(total=len(in_id_arr), desc="Generating Summaries") - + pbar = tqdm( + total=len(in_id_arr), + desc="Generating Summaries", + disable=disable_progress_bar, + ) for _id, _mask in zip(in_id_arr, att_arr): # If the batch is smaller than batch_length, pad it with the model's pad token if len(_id) < batch_length and pad_incomplete_batch: @@ -375,8 +399,9 @@ def summarize_via_tokenbatches( def save_summary( self, summary_data: dict, - target_file: str or Path = None, + target_file: Union[str, Path] = None, postprocess: bool = True, + batch_delimiter: str = "\n\n", custom_phrases: list = None, save_scores: bool = True, return_string: bool = False, @@ -385,8 +410,9 @@ def save_summary( save_summary - a function that takes the output of summarize_via_tokenbatches and saves it to a file after postprocessing :param dict summary_data: output of summarize_via_tokenbatches containing the summary and score for each batch - :param str or Path target_file: the file to save the summary to, defaults to None + :param Union[str, Path] target_file: the file to save the summary to, defaults to None :param bool postprocess: whether to postprocess the summary, defaults to True + :param str 
batch_delimiter: text delimiter between summary batches, defaults to "\n\n" :param list custom_phrases: a list of custom phrases to use in postprocessing, defaults to None :param bool save_scores: whether to save the scores for each batch, defaults to True :param bool return_string: whether to return the summary as a string, defaults to False @@ -410,7 +436,7 @@ def save_summary( sum_scores = [f"\n - {round(s['summary_score'],4)}" for s in summary_data] scores_text = "\n".join(sum_scores) - full_summary = "\n".join(sum_text) + full_summary = batch_delimiter.join(sum_text) if return_string: return full_summary @@ -448,6 +474,8 @@ def summarize_string( input_text: str, batch_length: int = None, batch_stride: int = None, + batch_delimiter: str = "\n\n", + disable_progress_bar: bool = None, **kwargs, ) -> str: """ @@ -456,6 +484,10 @@ def summarize_string( :param str input_text: the text to summarize :param int batch_length: number of tokens to use in each batch, defaults to None (self.token_batch_length) :param int batch_stride: number of tokens to stride between batches, defaults to None (self.batch_stride) + :param str batch_delimiter: text delimiter between summary batches, defaults to "\n\n" + :param bool disable_progress_bar: whether to disable the progress bar, defaults to None + :param kwargs: additional parameters to pass to summarize_via_tokenbatches + :return str: the summary """ @@ -475,56 +507,71 @@ def summarize_string( input_text, batch_length=batch_length, batch_stride=batch_stride, + disable_progress_bar=disable_progress_bar, **kwargs, ) - return self.save_summary(summary_data=gen_summaries, return_string=True) + return self.save_summary( + summary_data=gen_summaries, + return_string=True, + batch_delimiter=batch_delimiter, + ) def summarize_file( self, - file_path: str or Path, - output_dir: str or Path = None, - batch_length=None, - batch_stride=None, + file_path: Union[str, Path], + output_dir: Union[str, Path] = None, lowercase: bool = False, + batch_length: int = None, + batch_stride: int = None, + batch_delimiter: str = "\n\n", + save_scores: bool = True, + disable_progress_bar: bool = None, **kwargs, ) -> Path: """ - summarize_file - summarize a text file and save the summary to a file - - :param str or Path file_path: the path to the text file - :param str or Path output_dir: the directory to save the summary to, defaults to None (current working directory) - :param int batch_length: number of tokens to use in each batch, defaults to None (self.token_batch_length) - :param int batch_stride: number of tokens to stride between batches, defaults to None (self.batch_stride) - :param bool lowercase: whether to lowercase the text prior to summarization, defaults to False - - :return Path: the path to the summary file + summarize_file - generate a summary for a text file + + :param Union[str, Path] file_path: The path to the text file. 
+        :param Union[str, Path] output_dir: The path to the output directory, defaults to None
+        :param bool lowercase: whether to lowercase the text, defaults to False
+        :param int batch_length: Number of tokens to use in each batch, defaults to None
+        :param int batch_stride: Number of tokens to stride between batches, defaults to None
+        :param str batch_delimiter: Text delimiter between output summary batches, defaults to "\n\n"
+        :param bool save_scores: Whether to save the scores to the output file, defaults to True
+        :param bool disable_progress_bar: whether to disable the progress bar, defaults to None (use self.disable_progress_bar)
+        :return Path: The path to the output file
         """
         file_path = Path(file_path)
         output_dir = Path(output_dir) if output_dir is not None else Path.cwd()
-        output_file = output_dir / f"{file_path.stem}_summary.txt"
-        with open(file_path, "r") as f:
+
+        with open(file_path, "r", encoding="utf-8") as f:
             text = clean(f.read(), lower=lowercase)
 
+        # Generate summaries using token batches
         gen_summaries = self.summarize_via_tokenbatches(
             text,
             batch_length=batch_length,
             batch_stride=batch_stride,
+            disable_progress_bar=disable_progress_bar,
             **kwargs,
         )
 
+        # Save the generated summaries to the output file
+        output_file = output_dir / f"{file_path.stem}_summary.txt"
         self.save_summary(
             gen_summaries,
             output_file,
+            batch_delimiter=batch_delimiter,
+            save_scores=save_scores,
         )
 
         return output_file
 
     def save_params(
         self,
-        output_path: str or Path = None,
+        output_path: Union[str, Path] = None,
         hf_tag: str = None,
         verbose: bool = False,
     ) -> None:
@@ -532,7 +579,7 @@ def save_params(
         save_params - save the parameters of the run to a json file
 
         :param dict params: parameters to save
-        :param str or Path output_path: directory or filepath to save the parameters to
+        :param Union[str, Path] output_path: directory or filepath to save the parameters to
         :param str hf_tag: the model tag on huggingface (will be used instead of self.model_name_or_path)
         :param bool verbose: whether to log the parameters
 
@@ -563,3 +610,30 @@ def save_params(
         if verbose:
             self.logger.info(f"parameters: {exported_params}")
             print(f"saved parameters to {metadata_path}")
+
+    def __call__(self, input_data, **kwargs):
+        """
+        Smart __call__ that routes the input to summarize_file or summarize_string, depending on whether a valid file path is passed.
+
+        :param input_data: Can be either a string (text to summarize) or a file path.
+        :param kwargs: Additional keyword arguments to pass to the summarization methods.
+        :return: The summary string, or the path to the saved summary file if a file path was provided.
+ + Example usage: + summarizer = Summarizer() + summary = summarizer("This is a test string to summarize.") + # or + summary = summarizer("/path/to/textfile.txt") + """ + if ( + len(str(input_data)) < 1000 # assume > 1000 characters is plaintext + and isinstance(input_data, (str, Path)) + and Path(input_data).is_file() + ): + self.logger.debug("Summarizing from file...") + return self.summarize_file(file_path=input_data, **kwargs) + elif isinstance(input_data, str): + self.logger.debug("Summarizing from string...") + return self.summarize_string(input_text=input_data, **kwargs) + else: + raise ValueError("Input must be a valid string or a file path.") diff --git a/src/textsum/utils.py b/src/textsum/utils.py index 92bf577..a102804 100644 --- a/src/textsum/utils.py +++ b/src/textsum/utils.py @@ -176,7 +176,7 @@ def setup_logging(loglevel, logfile=None) -> None: logfile = Path(logfile) loglevel = ( logging.INFO - if not loglevel in [logging.DEBUG, logging.INFO, logging.WARNING] + if loglevel not in [logging.DEBUG, logging.INFO, logging.WARNING] else loglevel ) if loglevel == logging.DEBUG: @@ -218,7 +218,7 @@ def check_bitsandbytes_available(): check_bitsandbytes_available - check if the bitsandbytes library is available """ try: - import bitsandbytes + import bitsandbytes # noqa: F401 except ImportError: return False return True
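
Usage sketch for reviewers: the snippet below exercises the smart `__call__` routing and the new `disable_progress_bar` option added in this patch. It assumes the package is installed as `textsum`; `notes.txt` is a hypothetical input file, and instantiating `Summarizer()` with defaults downloads the `pszemraj/long-t5-tglobal-base-16384-book-summary` checkpoint.

    from textsum.summarize import Summarizer

    # suppress the per-batch tqdm bar introduced by this PR's loop changes
    summarizer = Summarizer(disable_progress_bar=True)

    # a plain string (no file on disk with this name) routes to summarize_string()
    summary = summarizer("Some long text to summarize ...")
    print(summary)

    # an existing file path routes to summarize_file(), which returns the output Path
    out_path = summarizer("notes.txt", batch_delimiter="\n\n")  # notes.txt is a hypothetical file
    print(out_path)

Note that `__call__` only treats inputs shorter than 1000 characters as candidate file paths, so long raw text is never mistaken for a path; a short string that happens to name an existing file will, however, be summarized from disk.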