diff --git a/config/service/marie-extract.yml b/config/service/marie-extract.yml
new file mode 100644
index 00000000..e9e32835
--- /dev/null
+++ b/config/service/marie-extract.yml
@@ -0,0 +1,215 @@
+jtype: Flow
+version: '1'
+protocol: grpc
+
+# Shared configuration
+shared_config:
+  storage: &storage
+    psql: &psql_conf_shared
+      provider: postgresql
+      hostname: 127.0.0.1
+      port: 5432
+      username: postgres
+      password: 123456
+      database: postgres
+      default_table: shared_docs
+
+  message: &message
+    amazon_mq: &amazon_mq_conf_shared
+      provider: amazon-rabbitmq
+      hostname: ${{ ENV.AWS_MQ_HOSTNAME }}
+      port: 15672
+      username: ${{ ENV.AWS_MQ_USERNAME }}
+      password: ${{ ENV.AWS_MQ_PASSWORD }}
+      tls: True
+      virtualhost: /
+
+    rabbitmq: &rabbitmq_conf_shared
+      provider: rabbitmq
+      hostname: ${{ ENV.RABBIT_MQ_HOSTNAME }}
+      port: ${{ ENV.RABBIT_MQ_PORT }}
+      username: ${{ ENV.RABBIT_MQ_USERNAME }}
+      password: ${{ ENV.RABBIT_MQ_PASSWORD }}
+      tls: False
+      virtualhost: /
+
+# Toast event tracking system
+# It can be backed by a message queue, a database, or both
+toast:
+  native:
+    enabled: True
+    path: /tmp/marie/events.json
+  rabbitmq:
+    <<: *rabbitmq_conf_shared
+    enabled: True
+  psql:
+    <<: *psql_conf_shared
+    default_table: event_tracking
+    enabled: True
+
+# Document Storage
+# The storage service is used to store the data that is being processed
+# Storage can be backed by any S3-compatible object store
+
+storage:
+  # S3 configuration. Will be used only if the value of backend is "s3"
+  s3:
+    enabled: True
+    metadata_only: False  # If True, only metadata will be stored in the storage backend
+    # API endpoint to connect to. Use AWS S3 or any S3-compatible object storage endpoint.
+    endpoint_url: ${{ ENV.S3_ENDPOINT_URL }}
+    # optional.
+    # access key id when using static credentials.
+    access_key_id: ${{ ENV.S3_ACCESS_KEY_ID }}
+    # optional.
+    # secret key when using static credentials.
+    secret_access_key: ${{ ENV.S3_SECRET_ACCESS_KEY }}
+    # Bucket name in s3
+    bucket_name: ${{ ENV.S3_BUCKET_NAME }}
+    # optional.
+    # Example: "region: us-east-2"
+    region: ${{ ENV.S3_REGION }}
+    # optional.
+    # enable if endpoint is http
+    insecure: True
+    # optional.
+    # enable if you want to use path style requests
+    addressing_style: path
+
+  # postgresql configuration. Will be used only if the value of backend is "psql"
+  psql:
+    <<: *psql_conf_shared
+    default_table: store_metadata
+    enabled: False
+
+# Job Queue scheduler
+scheduler:
+  psql:
+    <<: *psql_conf_shared
+    default_table: job_queue
+    enabled: True
+
+# FLOW / GATEWAY configuration
+
+with:
+  port:
+    - 51000
+    - 52000
+  protocol:
+    - http
+    - grpc
+  discovery: True
+  discovery_host: 127.0.0.1
+  discovery_port: 8500
+
+  host: 127.0.0.1
+
+  # monitoring
+  monitoring: true
+  port_monitoring: 57844
+
+  event_tracking: True
+
+  expose_endpoints:
+    /document/extract:
+      methods: ["POST"]
+      summary: Extract data-POC
+      tags:
+        - extract
+    /status:
+      methods: ["POST"]
+      summary: Status
+      tags:
+        - extract
+
+    /text/status:
+      methods: ["POST"]
+      summary: Extract data
+      tags:
+        - extract
+
+    /ner/extract:
+      methods: ["POST"]
+      summary: Extract NER
+      tags:
+        - ner
+
+    /document/classify:
+      methods: ["POST"]
+      summary: Classify document at page level
+      tags:
+        - classify
+
+prefetch: 4
+
+executors:
+  - name: extract_t
+    uses:
+      jtype: TextExtractionExecutor
+#      jtype: TextExtractionExecutorMock
+      with:
+        storage:
+          # postgresql configuration.
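+          # NOTE: `<<: *psql_conf_shared` merges the shared mapping declared as
+          # `&psql_conf_shared` under shared_config above; keys written next to the
+          # merge override the inherited ones. A minimal sketch of the resolved
+          # result, assuming standard YAML merge-key semantics (illustrative only,
+          # not part of this config):
+          #
+          #   psql:
+          #     provider: postgresql              # inherited
+          #     hostname: 127.0.0.1               # inherited
+          #     default_table: extract_metadata   # local override of shared_docs
+          #     enabled: True                     # local addition
+          #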
Will be used only if value of backend is "psql" + psql: + <<: *psql_conf_shared + default_table: extract_metadata + enabled: True + pipelines: + - pipeline: + name: 'default' + default: True + page_classifier: + - model_name_or_path: 'marie/lmv3-medical-document-classification' + type: 'transformers' + task: 'text-classification-multimodal' + device: 'cuda' + enabled: True + batch_size: 1 + name: 'medical_page_classifier' + group: 'medical-page-classifier' + + - model_name_or_path: 'marie/lmv3-medical-document-payer' + type: 'transformers' + task: 'text-classification-multimodal' + enabled: True + batch_size: 1 + device: 'cuda' + name: 'medical_payer_classifier' + group: 'medical-payer-classifier' + + page_indexer: + - model_name_or_path: 'rms/layoutlmv3-large-corr-ner' + enabled: True + type: 'transformers' + device: 'cuda' + name: 'page_indexer_patient' + group: 'medical-payer-classifier' + filter: + type: 'regex' + pattern: '.*' + page_splitter: + model_name_or_path: 'marie/layoutlmv3-medical-document-splitter' + enabled: True + metas: + py_modules: + - marie.executor.text + timeout_ready: 3000000 + replicas: 1 + # replicas: ${{ CONTEXT.gpu_device_count }} + env: + CUDA_VISIBLE_DEVICES: RR +# Authentication and Authorization configuration + +auth: + keys: + - name : service-A + api_key : mas_0aPJ9Q9nUO1Ac1vJTfffXEXs9FyGLf9BzfYgZ_RaHm707wmbfHJNPQ + enabled : True + roles : [admin, user] + + - name : service-B + api_key : mau_t6qDi1BcL1NkLI8I6iM8z1va0nZP01UQ6LWecpbDz6mbxWgIIIZPfQ + enabled : True + roles : [admin, user] diff --git a/config/tests-integration/pipeline-classify-004.partial.yml b/config/tests-integration/pipeline-classify-004.partial.yml new file mode 100644 index 00000000..7520d7fe --- /dev/null +++ b/config/tests-integration/pipeline-classify-004.partial.yml @@ -0,0 +1,129 @@ +pipelines: + - pipeline: + name: 'default' # name of the pipeline, used for logging and asset saving + default: True + id2label: + '0': additional_information + '1': attorney + '2': auth_approval + '3': auth_denial + '4': bankruptcy + '5': cms_letter + '6': dispute + '7': eligibility + '8': medical_certificate + '9': medical_record + '10': misc + '11': newborn + '12': noop_blank + '13': noop_check + '14': noop_cover + '15': noop_envelope + '16': noop_eob + '17': noop_hicfa + '18': noop_notice + '19': noop_patpay + '20': noop_w9 + '21': pa_162 + '22': referral + '23': refund_request + '24': tax_1099 + + page_classifier: + - model_name_or_path: 'rms/corr-layoutlmv3-classifier' + name: 'corr_page_classifier_layoutlmv3' + type: 'transformers' + task: 'text-classification-multimodal' + device: 'cuda' + enabled: True + group: 'corr-classifier' + + - model_name_or_path: 'rms/corr-longformer-classifier' + task: 'text-classification' + name: 'corr_page_classifier_longformer' + type: 'transformers' + enabled: True + batch_size: 1 # batch size > 1 causes errors due to wrong batch aggregation + device: 'cuda' + group: 'corr-classifier' + - model_name_or_path: 'rms/corr-layoutlmv3-classifier' + name: 'corr_page_classifier_layoutlmv3' + type: 'transformers' + task: 'text-classification-multimodal' + device: 'cuda' + enabled: True + group: 'jpmc-classifier' +# - model_name_or_path: 'rms/corr-payer-longformer-classifier' +# task: 'text-classification' +# name: 'corr_payer_longformer' +# type: 'transformers' +# enabled: True +# batch_size: 1 # batch size > 1 causes errors due to wrong batch aggregation +# device: 'cuda' +# group: 'corr-payer-classifier' + + sub_classifier: + - model_name_or_path: 
'rms/corr-auth-longformer-classifier' + task: 'text-classification' + name: 'corr_auth_sub_classifier' + type: 'transformers' + enabled: True + batch_size: 1 # batch size > 1 causes errors due to wrong batch aggregation + device: 'cuda' + group: 'corr-classifier' + + id2label: + '0': auth_denial_in + '1': auth_denial_op + # Filter should be on the same level as the sub-classifier, for now this is just a global filter + filter: + type: 'exact' + pattern: 'auth_denial' + page_indexer: + - model_name_or_path: 'rms/layoutlmv3-large-corr-ner' + enabled: True + type: 'transformers' + device: 'cuda' + name: 'page_indexer_patient' + filter: + type: 'regex' + pattern: '.*' + - pipeline: + default : false + name: 'jpmc-corr' + device: cuda + id2label: + '0': additional_information + '1': attorney + '2': auth_approval + '3': auth_denial + '4': bankruptcy + '5': cms_letter + '6': dispute + '7': eligibility + '8': medical_certificate + '9': medical_record + '10': misc + '11': newborn + '12': noop_blank + '13': noop_check + '14': noop_cover + '15': noop_envelope + '16': noop_eob + '17': noop_hicfa + '18': noop_notice + '19': noop_patpay + '20': noop_w9 + '21': pa_162 + '22': referral + '23': refund_request + '24': tax_1099 + + page_classifier: + - model_name_or_path: 'rms/corr-layoutlmv3-classifier' + name: 'corr_page_classifier_layoutlmv3' + type: 'transformers' + task: 'text-classification-multimodal' + device: 'cuda' + enabled: True + group: 'corr-classifier' diff --git a/config/tests-integration/pipeline-integration-001.partial.yml b/config/tests-integration/pipeline-integration-001.partial.yml new file mode 100644 index 00000000..b529e1df --- /dev/null +++ b/config/tests-integration/pipeline-integration-001.partial.yml @@ -0,0 +1,46 @@ +pipelines: + - pipeline: + name: 'default' + default: True + page_classifier: + - model_name_or_path: 'marie/lmv3-medical-document-classification' + type: 'transformers' + task: 'text-classification-multimodal' + device: 'cuda' + enabled: True + batch_size: 1 + name: 'medical_page_classifier' + group: 'medical-classifier' + + - model_name_or_path: 'marie/lmv3-medical-document-payer' + type: 'transformers' + task: 'text-classification-multimodal' + enabled: True + batch_size: 1 + device: 'cuda' + name: 'medical_payer_classifier' + group: 'medical-payer-classifier' + + page_indexer: + - model_name_or_path: 'rms/layoutlmv3-large-corr-ner' + enabled: True + type: 'transformers' + device: 'cuda' + name: 'page_indexer_patient' + filter: + type: 'regex' + pattern: '.*' + group: 'medical-classifier' + - model_name_or_path: 'rms/layoutlmv3-large-corr-ner' + enabled: True + type: 'transformers' + device: 'cuda' + name: 'page_indexer_payer' + filter: + type: 'regex' + pattern: '.*' + group: 'medical-payer-classifier' + + page_splitter: + model_name_or_path: 'marie/layoutlmv3-medical-document-splitter' + enabled: True \ No newline at end of file diff --git a/marie/executor/text/text_extraction_executor.py b/marie/executor/text/text_extraction_executor.py index f675d0fa..5b0bdff9 100644 --- a/marie/executor/text/text_extraction_executor.py +++ b/marie/executor/text/text_extraction_executor.py @@ -1,6 +1,6 @@ import os import warnings -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import numpy as np import torch @@ -40,7 +40,7 @@ def __init__( device: Optional[str] = None, num_worker_preprocess: int = 4, storage: dict[str, any] = None, - pipeline: dict[str, any] = None, + pipelines: List[dict[str, any]] = None, dtype: Optional[Union[str, 
torch.dtype]] = None, **kwargs, ): @@ -61,7 +61,7 @@ def __init__( logger.info(f"Starting executor : {self.__class__.__name__}") logger.info(f"Runtime args : {kwargs.get('runtime_args')}") logger.info(f"Storage config: {storage}") - logger.info(f"Pipeline config: {pipeline}") + logger.info(f"Pipelines config: {pipelines}") logger.info(f"Device : {device}") logger.info(f"use_cuda : {use_cuda}") logger.info(f"Num worker preprocess : {num_worker_preprocess}") @@ -101,7 +101,7 @@ def __init__( setup_torch_optimizations(num_threads=num_threads) self.show_error = True # show prediction errors - self.pipeline = ExtractPipeline(pipeline_config=pipeline, cuda=has_cuda) + self.pipeline = ExtractPipeline(pipelines_config=pipelines, cuda=has_cuda) instance_name = "not_defined" if kwargs is not None: diff --git a/marie/models/pix2pix/data/__init__.py b/marie/models/pix2pix/data/__init__.py index 3b4fd4ef..17cee9db 100644 --- a/marie/models/pix2pix/data/__init__.py +++ b/marie/models/pix2pix/data/__init__.py @@ -23,15 +23,13 @@ def find_dataset_using_name(dataset_name): and it is case-insensitive. """ - # FIXME : this needs to be fixes as the import_module is not working mapping = { 'single': SingleDataset, 'base': None } + # return mapping[dataset_name] - return mapping[dataset_name] - - dataset_filename = "data." + dataset_name + "_dataset" + dataset_filename = "marie.models.pix2pix.data." + dataset_name + "_dataset" datasetlib = importlib.import_module(dataset_filename) dataset = None diff --git a/marie/models/pix2pix/models/__init__.py b/marie/models/pix2pix/models/__init__.py index 5f4b615e..f362303c 100644 --- a/marie/models/pix2pix/models/__init__.py +++ b/marie/models/pix2pix/models/__init__.py @@ -31,15 +31,13 @@ def find_model_using_name(model_name): and it is case-insensitive. """ - # FIXME : this needs to be fixes as the import_module is not working mapping = { 'test': TestModel, 'base': BaseModel } + # return mapping[model_name] - return mapping[model_name] - - model_filename = "models." + model_name + "_model" + model_filename = "marie.models.pix2pix.models." 
+ model_name + "_model" modellib = importlib.import_module(model_filename) model = None target_model_name = model_name.replace('_', '') + 'model' diff --git a/marie/pipe/classification_pipeline.py b/marie/pipe/classification_pipeline.py index a34db4a1..01502bcc 100644 --- a/marie/pipe/classification_pipeline.py +++ b/marie/pipe/classification_pipeline.py @@ -1,19 +1,18 @@ import os import shutil +import types from datetime import datetime from typing import List, Optional, Union import numpy as np -import torch from docarray import DocList from PIL import Image from marie.boxes import PSMode from marie.excepts import BadConfigSource from marie.logging.logger import MarieLogger -from marie.logging.profile import TimeContext from marie.models.utils import initialize_device_settings -from marie.ocr import CoordinateFormat, OcrEngine +from marie.ocr import CoordinateFormat from marie.ocr.util import get_known_ocr_engines, get_words_and_boxes from marie.pipe import ( ClassifierPipelineComponent, @@ -23,17 +22,16 @@ ) from marie.pipe.components import ( burst_frames, + load_pipeline, ocr_frames, + reload_pipeline, restore_assets, - setup_classifiers, - setup_indexers, - split_filename, store_assets, + store_metadata, ) from marie.pipe.voting import ClassificationResult, get_voting_strategy from marie.utils.docs import docs_from_image from marie.utils.image_utils import hash_frames_fast -from marie.utils.json import store_json_object from marie.utils.utils import ensure_exists @@ -71,8 +69,10 @@ def __init__( ) -> None: self.show_error = True # show prediction errors self.logger = MarieLogger(context=self.__class__.__name__) + self.load_pipeline = types.MethodType(load_pipeline, self) self.pipelines_config = pipelines_config + self.reload_pipeline = types.MethodType(reload_pipeline, self) self.default_pipeline_config = None self.silence_exceptions = silence_exceptions @@ -112,66 +112,6 @@ def __init__( self.default_pipeline_config, self.ocr_engines["default"] ) - def load_pipeline( - self, pipeline_config: dict[str, any], ocr_engine: Optional[OcrEngine] = None - ) -> tuple[str, dict[str, any], dict[str, any]]: - - # TODO : Need to refactor this (use the caller to get the device and then fallback to the pipeline config) - # sometimes we have CUDA/GPU support but want to only use CPU - use_cuda = torch.cuda.is_available() - if os.environ.get("MARIE_DISABLE_CUDA"): - use_cuda = False - device = pipeline_config.get("device", "cpu" if not use_cuda else "cuda") - if device == "cuda" and not use_cuda: - device = "cpu" - - if "name" not in pipeline_config: - raise BadConfigSource("Invalid pipeline config, missing name field") - - pipeline_name = pipeline_config["name"] - document_classifiers = setup_classifiers( - pipeline_config, key="page_classifier", device=device, ocr_engine=ocr_engine - ) - - document_sub_classifiers = setup_classifiers( - pipeline_config, key="sub_classifier", device=device, ocr_engine=ocr_engine - ) - - classifier_groups = dict() - for classifier_group, classifiers in document_classifiers.items(): - sub_classifiers = document_sub_classifiers.get(classifier_group, {}) - classifier_groups[classifier_group] = { - "group": classifier_group, - "classifiers": classifiers, - "sub_classifiers": sub_classifiers, - } - - document_indexers = setup_indexers( - pipeline_config, key="page_indexer", device=device, ocr_engine=ocr_engine - ) - - indexer_groups = dict() - for group, indexer in document_indexers.items(): - indexer_groups[group] = { - "group": group, - "indexer": indexer, - } - - # 
dump information about the loaded classifiers that are grouped by the classifier group - for classifier_group, classifiers in document_classifiers.items(): - self.logger.info( - f"Loaded classifiers :{classifier_group}, {len(classifiers)}, {classifiers.keys()}" - ) - for classifier_group, classifiers in document_sub_classifiers.items(): - self.logger.info( - f"Loaded sub-classifiers : {classifier_group}, {len(classifiers)}, {classifiers.keys()}" - ) - self.logger.info( - f"Loaded indexers : {len(document_indexers)}, {document_indexers.keys()}" - ) - - return pipeline_name, classifier_groups, indexer_groups - def execute_frames_pipeline( self, ref_id: str, @@ -275,34 +215,12 @@ def execute_frames_pipeline( } ) - self.store_metadata(ref_id, ref_type, root_asset_dir, metadata) + store_metadata(ref_id, ref_type, root_asset_dir, metadata) store_assets(ref_id, ref_type, root_asset_dir, match_wildcard="*.json") del metadata["ocr"] return metadata - def store_metadata( - self, - ref_id: str, - ref_type: str, - root_asset_dir: str, - metadata: dict[str, any], - infix: str = "meta", - ) -> None: - """ - Store current metadata for the document. Format is {ref_id}.meta.json in the root asset directory - :param ref_id: reference id of the document - :param ref_type: reference type of the document - :param root_asset_dir: root asset directory - :param metadata: metadata to store - :param infix: infix to use for the metadata file, default is "meta" e.g. {ref_id}.meta.json - :return: None - """ - filename, prefix, suffix = split_filename(ref_id) - metadata_path = os.path.join(root_asset_dir, f"{filename}.{infix}.json") - self.logger.info(f"Storing metadata : {metadata_path}") - store_json_object(metadata, metadata_path) - def execute( self, ref_id: str, @@ -494,34 +412,3 @@ def group_results_by_page(self, group_key: str, page_meta: List[dict[str, any]]) group_by_page[page].append(detail) return group_by_page - - def reload_pipeline(self, pipeline_name) -> None: - with TimeContext(f"### Reloading pipeline : {pipeline_name}", self.logger): - try: - self.logger.info(f"Reloading pipeline : {pipeline_name}") - if self.pipelines_config is None: - raise BadConfigSource( - "Invalid pipeline configuration, no pipelines found" - ) - - pipeline_config = None - for conf in self.pipelines_config: - conf = conf["pipeline"] - if conf.get("name") == pipeline_name: - pipeline_config = conf - break - - if pipeline_config is None: - raise BadConfigSource( - f"Invalid pipeline configuration, pipeline not found : {pipeline_name}" - ) - - ( - self.pipeline_name, - self.classifier_groups, - self.indexer_groups, - ) = self.load_pipeline(pipeline_config) - self.logger.info(f"Reloaded successfully pipeline : {pipeline_name} ") - except Exception as e: - self.logger.error(f"Error reloading pipeline : {e}") - raise e diff --git a/marie/pipe/components.py b/marie/pipe/components.py index 2b393abc..0a80c7eb 100644 --- a/marie/pipe/components.py +++ b/marie/pipe/components.py @@ -11,6 +11,7 @@ from marie.components import TransformersDocumentClassifier, TransformersDocumentIndexer from marie.excepts import BadConfigSource from marie.logging.predefined import default_logger as logger +from marie.logging.profile import TimeContext from marie.ocr import CoordinateFormat, OcrEngine from marie.overlay.overlay import NoopOverlayProcessor, OverlayProcessor from marie.storage import StorageManager @@ -286,6 +287,99 @@ def setup_indexers( return document_indexers +def load_pipeline( + self, pipeline_config: dict[str, any], ocr_engine: 
Optional[OcrEngine] = None +) -> tuple[str, dict[str, any], dict[str, any]]: + + # TODO : Need to refactor this (use the caller to get the device and then fallback to the pipeline config) + # sometimes we have CUDA/GPU support but want to only use CPU + use_cuda = torch.cuda.is_available() + if os.environ.get("MARIE_DISABLE_CUDA"): + use_cuda = False + device = pipeline_config.get("device", "cpu" if not use_cuda else "cuda") + if device == "cuda" and not use_cuda: + device = "cpu" + + if "name" not in pipeline_config: + raise BadConfigSource("Invalid pipeline config, missing name field") + + pipeline_name = pipeline_config["name"] + document_classifiers = setup_classifiers( + pipeline_config, key="page_classifier", device=device, ocr_engine=ocr_engine + ) + + document_sub_classifiers = setup_classifiers( + pipeline_config, key="sub_classifier", device=device, ocr_engine=ocr_engine + ) + + classifier_groups = dict() + for classifier_group, classifiers in document_classifiers.items(): + sub_classifiers = document_sub_classifiers.get(classifier_group, {}) + classifier_groups[classifier_group] = { + "group": classifier_group, + "classifiers": classifiers, + "sub_classifiers": sub_classifiers, + } + + document_indexers = setup_indexers( + pipeline_config, key="page_indexer", device=device, ocr_engine=ocr_engine + ) + + indexer_groups = dict() + for group, indexer in document_indexers.items(): + indexer_groups[group] = { + "group": group, + "indexer": indexer, + } + + # dump information about the loaded classifiers that are grouped by the classifier group + for classifier_group, classifiers in document_classifiers.items(): + self.logger.info( + f"Loaded classifiers :{classifier_group}, {len(classifiers)}, {classifiers.keys()}" + ) + for classifier_group, classifiers in document_sub_classifiers.items(): + self.logger.info( + f"Loaded sub-classifiers : {classifier_group}, {len(classifiers)}, {classifiers.keys()}" + ) + self.logger.info( + f"Loaded indexers : {len(document_indexers)}, {document_indexers.keys()}" + ) + + return pipeline_name, classifier_groups, indexer_groups + + +def reload_pipeline(self, pipeline_name) -> None: + with TimeContext(f"### Reloading pipeline : {pipeline_name}", self.logger): + try: + self.logger.info(f"Reloading pipeline : {pipeline_name}") + if self.pipelines_config is None: + raise BadConfigSource( + "Invalid pipeline configuration, no pipelines found" + ) + + pipeline_config = None + for conf in self.pipelines_config: + conf = conf["pipeline"] + if conf.get("name") == pipeline_name: + pipeline_config = conf + break + + if pipeline_config is None: + raise BadConfigSource( + f"Invalid pipeline configuration, pipeline not found : {pipeline_name}" + ) + + ( + self.pipeline_name, + self.classifier_groups, + self.indexer_groups, + ) = self.load_pipeline(pipeline_config, self.ocr_engines["default"]) + self.logger.info(f"Reloaded successfully pipeline : {pipeline_name} ") + except Exception as e: + self.logger.error(f"Error reloading pipeline : {e}") + raise e + + def restore_assets( ref_id: str, ref_type: str, @@ -372,6 +466,28 @@ def store_assets( logger.error(f"Error storing assets : {e}") +def store_metadata( + ref_id: str, + ref_type: str, + root_asset_dir: str, + metadata: dict[str, any], + infix: str = "meta", +) -> None: + """ + Store current metadata for the document. 
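+
+    Example (illustrative; the file name and metadata values are made up)::
+
+        store_metadata("157154493.tif", "pid", "/tmp/assets", {"pages": "4"})
+        # -> writes /tmp/assets/157154493.meta.json
+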
Format is {ref_id}.meta.json in the root asset directory + :param ref_id: reference id of the document + :param ref_type: reference type of the document + :param root_asset_dir: root asset directory + :param metadata: metadata to store + :param infix: infix to use for the metadata file, default is "meta" e.g. {ref_id}.meta.json + :return: None + """ + filename, prefix, suffix = split_filename(ref_id) + metadata_path = os.path.join(root_asset_dir, f"{filename}.{infix}.json") + logger.info(f"Storing metadata : {metadata_path}") + store_json_object(metadata, metadata_path) + + def burst_frames( ref_id: str, frames: List[np.ndarray], @@ -492,8 +608,3 @@ def ocr_frames( results = load_json_file(json_path) return results - - -# force False -# json_path /tmp/generators/f918a115f8474f63da92b4676483caf3/results/157154493_4.json -# json_path 157154493_4.json diff --git a/marie/pipe/extract_pipeline.py b/marie/pipe/extract_pipeline.py index c11067fb..cfa8ffa2 100644 --- a/marie/pipe/extract_pipeline.py +++ b/marie/pipe/extract_pipeline.py @@ -1,18 +1,20 @@ import glob import os import shutil +import types from datetime import datetime from pathlib import Path from typing import List, Optional, Union import numpy as np -import torch from docarray import DocList from PIL import Image from marie.boxes import PSMode from marie.common.file_io import get_file_count +from marie.excepts import BadConfigSource from marie.logging.logger import MarieLogger +from marie.models.utils import initialize_device_settings from marie.ocr import CoordinateFormat from marie.ocr.util import get_known_ocr_engines, get_words_and_boxes from marie.pipe import ( @@ -23,21 +25,22 @@ ) from marie.pipe.components import ( burst_frames, + load_pipeline, ocr_frames, + reload_pipeline, restore_assets, s3_asset_path, - setup_classifiers, - setup_indexers, setup_overlay, split_filename, store_assets, + store_metadata, ) +from marie.pipe.voting import ClassificationResult, get_voting_strategy from marie.renderer import PdfRenderer, TextRenderer from marie.renderer.adlib_renderer import AdlibRenderer from marie.renderer.blob_renderer import BlobRenderer from marie.utils.docs import docs_from_image, frames_from_file from marie.utils.image_utils import hash_frames_fast -from marie.utils.json import store_json_object from marie.utils.tiff_ops import merge_tiff, save_frame_as_tiff_g4 from marie.utils.utils import ensure_exists from marie.utils.zip_ops import merge_zip @@ -74,33 +77,66 @@ class ExtractPipeline: def __init__( self, - pipeline_config: dict[str, any] = None, + pipelines_config: List[dict[str, any]] = None, cuda: bool = True, + device: Optional[str] = "cuda", + silence_exceptions: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) self.show_error = True # show prediction errors + self.logger = MarieLogger(context=self.__class__.__name__) + self.load_pipeline = types.MethodType(load_pipeline, self) + + self.pipelines_config = pipelines_config + self.reload_pipeline = types.MethodType(reload_pipeline, self) + self.default_pipeline_config = None + self.silence_exceptions = silence_exceptions + + for conf in pipelines_config: + conf = conf["pipeline"] + if conf.get("default", False): + if self.default_pipeline_config is not None: + raise BadConfigSource( + "Invalid pipeline configuration, multiple defaults found" + ) + self.default_pipeline_config = conf + + if self.default_pipeline_config is None: + raise BadConfigSource("Invalid pipeline configuration, default not found") + # sometimes we have CUDA/GPU support but want 
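+        # NOTE: `pipelines_config` is the parsed `pipelines:` list from the YAML
+        # config, i.e. a list of {"pipeline": {...}} entries of which exactly one
+        # must set default: True (the loop above raises otherwise). Illustrative
+        # shape, values elided:
+        #
+        #   [{"pipeline": {"name": "default", "default": True, "page_classifier": [...]}},
+        #    {"pipeline": {"name": "jpmc-corr", "default": False, ...}}]
+        #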
to only use CPU - use_cuda = torch.cuda.is_available() - if os.environ.get("MARIE_DISABLE_CUDA"): - use_cuda = False + resolved_devices, _ = initialize_device_settings( + devices=[device], use_cuda=True, multi_gpu=False + ) + if len(resolved_devices) > 1: + self.logger.warning( + "Multiple devices are not supported in %s inference, using the first device %s.", + self.__class__.__name__, + resolved_devices[0], + ) + self.device = resolved_devices[0] + has_cuda = True if self.device.type.startswith("cuda") else False - device = pipeline_config.get("device", "cpu" if not use_cuda else "cuda") - if device == "cuda" and not use_cuda: - device = "cpu" + self.overlay_processor = setup_overlay(self.default_pipeline_config) + self.ocr_engines = get_known_ocr_engines( + device=self.device.type, engine="default" + ) - self.logger = MarieLogger(context=self.__class__.__name__) - self.overlay_processor = setup_overlay(pipeline_config) - self.ocr_engines = get_known_ocr_engines(device=device) - self.document_classifiers = setup_classifiers(pipeline_config) - self.document_indexers = setup_indexers(pipeline_config) + ( + self.pipeline_name, + self.classifier_groups, + self.indexer_groups, + ) = self.load_pipeline( + self.default_pipeline_config, self.ocr_engines["default"] + ) self.logger.info( - f"Loaded classifiers : {len(self.document_classifiers)}, {self.document_classifiers.keys()}" + f"Loaded classifiers : {len(self.classifier_groups)}, {self.classifier_groups.keys()}" ) self.logger.info( - f"Loaded indexers : {len(self.document_indexers)}, {self.document_indexers.keys()}" + f"Loaded indexers : {len(self.indexer_groups)}, {self.indexer_groups.keys()}" ) def segment( @@ -214,28 +250,20 @@ def execute_frames_pipeline( self.logger.info(f"Feature : page indexer enabled : {page_indexer_enabled}") self.logger.info(f"Feature : page cleaner enabled : {page_cleaner_enabled}") - post_processing_pipeline = [] - - if page_classifier_enabled: - post_processing_pipeline.append( - ClassifierPipelineComponent( - name="classifier_pipeline_component", - document_classifiers=self.document_classifiers, + # check if the current pipeline name is the default pipeline name + if "name" in runtime_conf: + expected_pipeline_name = runtime_conf["name"] + if expected_pipeline_name != self.pipeline_name: + self.logger.warning( + f"pipeline name : {expected_pipeline_name}, expected : {self.pipeline_name} , reloading pipeline" ) - ) - - if page_indexer_enabled: - post_processing_pipeline.append( - NamedEntityPipelineComponent( - name="ner_pipeline_component", - document_indexers=self.document_indexers, - ) - ) + self.reload_pipeline(expected_pipeline_name) metadata = { "ref_id": ref_id, "ref_type": ref_type, "job_id": job_id, + "pipeline": self.pipeline_name, "pages": f"{len(frames)}", } @@ -246,14 +274,58 @@ def execute_frames_pipeline( # burst frames into individual images burst_frames(ref_id, frames, root_asset_dir) - clean_frames = self.segment( ref_id, frames, root_asset_dir, enabled=page_cleaner_enabled ) ocr_results = ocr_frames(self.ocr_engines, ref_id, clean_frames, root_asset_dir) metadata["ocr"] = ocr_results + metadata["classifications"] = [] + metadata["indexers"] = [] + + # Extract and Classify now done in groups + for group, classifier_group in self.classifier_groups.items(): + self.logger.info( + f"Loaded extract pipeline/group : {self.pipeline_name}, {group}" + ) + document_classifiers = classifier_group["classifiers"] + sub_classifiers = classifier_group["sub_classifiers"] + post_processing_pipeline = [] + + if 
page_classifier_enabled:
+                post_processing_pipeline.append(
+                    ClassifierPipelineComponent(
+                        name="classifier_pipeline_component",
+                        document_classifiers=document_classifiers,
+                    )
+                )
+
+            if page_indexer_enabled:
+                if group in self.indexer_groups:
+                    document_indexers = self.indexer_groups[group]["indexer"]
+                    post_processing_pipeline.append(
+                        NamedEntityPipelineComponent(
+                            name="ner_pipeline_component",
+                            document_indexers=document_indexers,
+                        )
+                    )
-        self.execute_pipeline(post_processing_pipeline, frames, ocr_results, metadata)
+            results = self.execute_pipeline(
+                post_processing_pipeline, sub_classifiers, frames, ocr_results
+            )
+            metadata["classifications"].append(
+                {
+                    "group": group,
+                    "classification": results["classifier"]
+                    if "classifier" in results
+                    else {},
+                }
+            )
+            metadata["indexers"].append(
+                {
+                    "group": group,
+                    "indexer": results["indexer"] if "indexer" in results else {},
+                }
+            )
 
         # TODO : Convert to execution pipeline
         self.render_pdf(ref_id, frames, ocr_results, root_asset_dir)
@@ -261,21 +333,11 @@
         self.render_adlib(ref_id, frames, ocr_results, root_asset_dir)
         self.pack_assets(ref_id, ref_type, root_asset_dir, metadata)
 
-        self.store_metadata(ref_id, ref_type, root_asset_dir, metadata)
+        store_metadata(ref_id, ref_type, root_asset_dir, metadata)
         store_assets(ref_id, ref_type, root_asset_dir)
 
         return metadata
 
-    def store_metadata(
-        self, ref_id: str, ref_type: str, root_asset_dir: str, metadata: dict[str, any]
-    ) -> None:
-        """
-        Store current metadata for the document. Format is {ref_id}.meta.json in the root asset directory
-        """
-        filename, prefix, suffix = split_filename(ref_id)
-        metadata_path = os.path.join(root_asset_dir, f"{filename}.meta.json")
-        store_json_object(metadata, metadata_path)
-
     def execute_regions_pipeline(
         self,
         ref_id: str,
@@ -317,6 +379,7 @@
             "ref_id": ref_id,
             "ref_type": ref_type,
             "job_id": job_id,
+            "pipeline": self.pipeline_name,
             "pages": f"{len(frames)}",
             "ocr": results,
         }
@@ -509,13 +572,12 @@ def pack_assets(
     def execute_pipeline(
         self,
         processing_pipeline: List[PipelineComponent],
+        sub_classifiers: dict[str, any],
         frames: List,
         ocr_results: dict,
-        metadata: dict,
     ):
-        """Execute the post processing pipeline
-        TODO : This is temporary, we need to make this configurable
-        """
+        """Execute the post processing pipeline"""
         words = []
         boxes = []
         documents = docs_from_image(frames)
@@ -530,7 +592,7 @@
         assert len(words) == len(boxes)
 
         context = PipelineContext(pipeline_id="post_processing_pipeline")
-        context["metadata"] = metadata
+        context["metadata"] = {}
 
         for pipe in processing_pipeline:
             try:
@@ -544,3 +606,98 @@
                 documents = pipe_results.state
             except Exception as e:
                 self.logger.error(f"Error executing pipe : {e}")
+
+        # TODO : This is temporary, we need to make this configurable
+        self.logger.info("### ExtractPipeline results")
+        self.logger.info(context["metadata"])
+
+        page_indexer_meta = (
+            context["metadata"]["page_indexer"]
+            if "page_indexer" in context["metadata"]
+            else []
+        )
+        page_classifier_meta = (
+            context["metadata"]["page_classifier"]
+            if "page_classifier" in context["metadata"]
+            else []
+        )
+
+        for idx, page_result in enumerate(page_classifier_meta):
+            for detail in page_result["details"]:
+                page = int(detail["page"])
+                classification = detail["classification"]
+                filtered_classifiers = {}
+
+                for key, val in sub_classifiers.items():
+                    filter_config = val["filter"]
+                    filter_type = filter_config["type"]
+                    filter_pattern = filter_config["pattern"]
+
+                    if filter_type == "exact" and classification == filter_pattern:
+                        self.logger.info(f"Adding sub-classifier : {key}")
+                        filtered_classifiers[key] = val
+
+                if filtered_classifiers:
+                    self.logger.info(
+                        f"Filtered classifiers : {filtered_classifiers.keys()}"
+                    )
+                    sub_classifier_pipeline = ClassifierPipelineComponent(
+                        name="sub_classifier_pipeline",
+                        document_classifiers=filtered_classifiers,
+                    )
+
+                    ctx = PipelineContext(pipeline_id="sub_classification_pipeline")
+                    ctx["metadata"] = {}
+                    pipe_results = sub_classifier_pipeline.run(
+                        documents[page : page + 1],
+                        ctx,
+                        words=[words[page]],
+                        boxes=[boxes[page]],
+                    )
+                    detail["sub_classifier"] = ctx["metadata"]["page_classifier"]
+
+        # TODO : Read from config
+        # Classification strategy: max_score, max_votes, max_score_with_diff
+        prediction_agent = "majority"
+        tie_break_policy = "best_with_diff"
+        voter = get_voting_strategy(prediction_agent, tie_break_policy, max_diff=0.25)
+
+        class_by_page = self.group_results_by_page("classifier", page_classifier_meta)
+        score_by_page = {}
+        for page, details in class_by_page.items():
+            score_by_page[page] = voter([ClassificationResult(**x) for x in details])
+
+        classifier_results = {
+            "strategy": prediction_agent,
+            "tie_break_policy": tie_break_policy,
+            "pages": {},
+        }
+
+        for page in list(class_by_page.keys()):
+            classifier_results["pages"][page] = {
+                "details": class_by_page[page],
+                "best": score_by_page[page],
+            }
+
+        # Indexer results
+        indexer_by_page = self.group_results_by_page("indexer", page_indexer_meta)
+        indexer_results = {"strategy": "default", "pages": {}}
+
+        for page in list(indexer_by_page.keys()):
+            indexer_results["pages"][page] = {"details": indexer_by_page[page]}
+
+        return {"classifier": classifier_results, "indexer": indexer_results}
+
+    def group_results_by_page(self, group_key: str, page_meta: List[dict[str, any]]):
+        """Group the results by page"""
+        group_by_page = {}
+        for idx, page_result in enumerate(page_meta):
+            indexer = page_result[group_key]
+            for detail in page_result["details"]:
+                page = int(detail["page"])
+                if page not in group_by_page:
+                    group_by_page[page] = []
+                detail[group_key] = indexer
+                group_by_page[page].append(detail)
+
+        return group_by_page
diff --git a/tests/integration/check_extract_pipeline.py b/tests/integration/check_extract_pipeline.py
index 583da121..07acaaf1 100644
--- a/tests/integration/check_extract_pipeline.py
+++ b/tests/integration/check_extract_pipeline.py
@@ -40,6 +40,9 @@ def setup_storage():
     MDC.put("request_id", "test")
     img_path = "~/tmp/address-001.png"
     img_path = "~/tmp/analysis/marie-issues/107/195668453-0004.png"
+    img_path = "~/Desktop/11302023_21100_5_102_.tif"
+    # img_path = "~/Desktop/11302023_28082_5_452_.tif"
+
     img_path = os.path.expanduser(img_path)
 
     # StorageManager.mkdir("s3://marie")
@@ -51,9 +54,11 @@
     # s3_path = s3_asset_path(ref_id=filename, ref_type="pid", include_filename=True)
     # StorageManager.write(img_path, s3_path, overwrite=True)
 
-    pipeline_config = load_yaml(os.path.join(__config_dir__, "tests-integration", "pipeline-integration.partial.yml"))
+
+    config = load_yaml(os.path.join(__config_dir__, "tests-integration", "pipeline-integration-001.partial.yml"))
     # pipeline_config = load_yaml(os.path.join(__config_dir__, "tests-integration", "pipeline-integration-region.partial.yml"))
-    pipeline = ExtractPipeline(pipeline_config=pipeline_config["pipeline"], cuda=True)
+    pipelines_config = config["pipelines"]
pipeline = ExtractPipeline(pipelines_config=pipelines_config, cuda=True) regions = [ {