Skip to content

Commit

Permalink
ChequeDetection (#260)
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxxx-zh authored May 2, 2024
1 parent 13d0a11 commit 4793f78
Show file tree
Hide file tree
Showing 9 changed files with 3,736 additions and 0 deletions.
1,378 changes: 1,378 additions & 0 deletions advanced_tutorials/fraud_cheque_detection/1_feature_pipeline.ipynb

Large diffs are not rendered by default.

573 changes: 573 additions & 0 deletions advanced_tutorials/fraud_cheque_detection/2_training_pipeline.ipynb

Large diffs are not rendered by default.

1,241 changes: 1,241 additions & 0 deletions advanced_tutorials/fraud_cheque_detection/3_inference_pipeline.ipynb

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions advanced_tutorials/fraud_cheque_detection/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
### Dataset Configuration
DATASET_NAME = "shivi/cheques_sample_data"

### Donut Configuration
DONUT_BASE_REPO = "naver-clova-ix/donut-base" #"nielsr/donut-base"
DONUT_FT_REPO = "shivi/donut-cheque-parser"
IMAGE_SIZE = [960, 720] # Input image size
MAX_LENGTH = 768 # Max generated sequence length for text decoder of Donut

### Task Tokens
TASK_START_TOKEN = "<parse-cheque>"
TASK_END_TOKEN = "<parse-cheque>"

### Training Configuration
BATCH_SIZE = 1
NUM_WORKERS = 4
MAX_EPOCHS = 30
VAL_CHECK_INTERVAL = 0.2
CHECK_VAL_EVERY_N_EPOCH = 1
GRADIENT_CLIP_VAL = 1.0
LEARNING_RATE = 3e-5
VERBOSE = True

### Hardware Configuration
ACCELERATOR = "gpu"
DEVICE_NUM = 1
PRECISION = 16
78 changes: 78 additions & 0 deletions advanced_tutorials/fraud_cheque_detection/features/cheque.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from textblob import TextBlob
from word2number import w2n

def spell_check(text):
"""
Checks and corrects the spelling of a given text.
Parameters:
- text (str): The text whose spelling is to be checked.
Returns:
- tuple: A tuple containing a boolean indicating if the original spelling was correct, and the corrected text.
"""
# Convert the text to lower case to standardize it
text_lower = text.lower()

# Return early if the text is 'missing' or an empty string
if text_lower in ['missing', ' ']:
return False, 'missing'

# Correct the text using TextBlob
text_corrected = str(TextBlob(text_lower).correct())

# Determine if the original text was spelled correctly
spelling_is_correct = text_lower == text_corrected

return spelling_is_correct, text_corrected.strip()


def amount_letter_number_match(amount_in_text_corrected, amount_in_number):
"""
Compares the numeric value of a text representation of an amount to its numeric counterpart.
Parameters:
- amount_in_text_corrected (str): The text representation of an amount.
- amount_in_number (str): The numeric representation of an amount.
Returns:
- bool or tuple: True if the amounts match, False otherwise, or a tuple with a message if data is missing.
"""
# Handle missing values
if 'missing' in [amount_in_text_corrected, amount_in_number]:
return False, ('Amount in words is missing' if amount_in_text_corrected == 'missing' else 'Amount in numbers is missing')

try:
# Attempt to convert the textual representation to a number
amount_text_to_num = w2n.word_to_num(amount_in_text_corrected)

# Compare it to the provided numeric value, making sure to convert it to an int
return amount_text_to_num == int(amount_in_number)

# If Spell correction fails (for -> four)
except Exception as e:
return False


def get_amount_match_column(amount_in_text_corrected, amount_in_number):
"""
Retrieves the match status or value for an amount, handling tuples indicating missing data.
Parameters:
- amount_in_text_corrected (str): The text representation of an amount, corrected for spelling.
- amount_in_number (str): The numeric representation of an amount.
Returns:
- bool: True if the amounts match, False otherwise, or the first element of the tuple if an error message is present.
"""
# Determine the match status or value
match_value = amount_letter_number_match(
amount_in_text_corrected,
amount_in_number,
)

# Return the value, handling tuple for error messages
if isinstance(match_value, tuple):
return match_value[0]

return match_value
138 changes: 138 additions & 0 deletions advanced_tutorials/fraud_cheque_detection/functions/donut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import torch
import re
import numpy as np
from transformers import (
DonutProcessor,
VisionEncoderDecoderModel,
)
from features.cheque import (
spell_check,
amount_letter_number_match,
)

# Determine the device to use based on the availability of CUDA (GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_cheque_parser(folder_name):
"""
Loads a cheque parsing processor and model from a specified directory and moves the model to the appropriate device.
This function loads a DonutProcessor and a VisionEncoderDecoderModel. The model is moved to a GPU if available, otherwise to CPU.
Parameters:
- folder_name (str): The directory where the processor and model's pretrained weights are stored.
Returns:
- tuple: A tuple containing the loaded processor and model.
"""
# Load the DonutProcessor from the pretrained model directory
processor = DonutProcessor.from_pretrained(folder_name)

# Load the VisionEncoderDecoderModel from the pretrained model directory
model = VisionEncoderDecoderModel.from_pretrained(folder_name)

# Move the model to the available device (GPU or CPU)
model.to(device)

# Return the processor and model as a tuple
return processor, model


def parse_text(image, processor, model):
"""
Parses text from an image using a pre-trained model and processor, and formats the output.
Parameters:
- image: The image from which to parse text.
- processor: The processor instance equipped with methods for handling input preprocessing and decoding outputs.
- model: The pre-trained VisionEncoderDecoderModel used to generate text from image data.
Returns:
- dict: A dictionary containing parsed and formatted cheque details.
"""
# Prepare the initial task prompt and get decoder input IDs from the tokenizer
task_prompt = "<parse-cheque>"
decoder_input_ids = processor.tokenizer(
task_prompt,
add_special_tokens=False,
return_tensors="pt",
).input_ids

# Convert image to pixel values suitable for the model input
pixel_values = processor(image, return_tensors="pt").pixel_values


# Generate outputs from the model using the provided pixel values and decoder inputs
outputs = model.generate(
pixel_values.to(device), # Ensure pixel values are on the correct device (CPU/GPU)
decoder_input_ids=decoder_input_ids.to(device), # Move decoder inputs to the correct device
max_length=model.decoder.config.max_position_embeddings, # Set the maximum output length
pad_token_id=processor.tokenizer.pad_token_id, # Define padding token
eos_token_id=processor.tokenizer.eos_token_id, # Define end-of-sequence token
use_cache=True, # Enable caching to improve performance
bad_words_ids=[[processor.tokenizer.unk_token_id]], # Prevent generation of unknown tokens
return_dict_in_generate=True, # Return outputs in a dictionary format
)

# Decode the output sequences to text
sequence = processor.batch_decode(outputs.sequences)[0]

# Remove special tokens and clean up the sequence
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")

# Remove the initial task prompt token
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

# Convert the cleaned sequence to JSON and format the output
json = processor.token2json(sequence)

return {
key: (value if value != '' else 'missing')
for attribute
in json['cheque_details']
for key, value
in attribute.items()
}


def evaluate_cheque_fraud(parsed_data, model_fraud_detection):
"""
Evaluates potential fraud in a cheque by analyzing the consistency of spelling and the match between
numerical and textual representations of the amount.
Parameters:
- parsed_data (dict): Dictionary containing parsed data from a cheque, including amounts in words and figures.
- model_fraud_detection: A trained model used to predict whether the cheque is valid or fraudulent.
Returns:
- tuple: A tuple containing the fraud evaluation result ('valid' or 'fraud'), spelling check message,
and amount match message.
"""
# Check spelling for the amount in words and correct it if necessary
spelling_is_correct, amount_in_text_corrected = spell_check(
parsed_data['amt_in_words'],
)

# Check if the corrected amount in words matches the amount in figures
amount_match = amount_letter_number_match(
amount_in_text_corrected,
parsed_data['amt_in_figures'],
)

# Handle the case where amount_match is a tuple, using only the first element if so
amount_match_value = amount_match[0] if isinstance(amount_match, tuple) else amount_match

# Prepare the input for the fraud detection model
model_input = np.array([spelling_is_correct, amount_match_value])

# Predict fraud using the model, reshaping input to match expected format
prediction = model_fraud_detection.predict(
model_input.reshape(1, -1)
)

# Construct messages regarding the spelling and value match
spelling = f'Spelling is correct: {spelling_is_correct}'
value_match = f'Numeric and alphabetic values match: {amount_match}'

# Return the evaluation result along with explanatory messages
return np.where(prediction[0] == 1, 'valid', 'fraud').item(), spelling, value_match
Loading

0 comments on commit 4793f78

Please sign in to comment.