-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert.py
254 lines (231 loc) · 10.1 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
### Imports ###
from pydantic import BaseModel, Field, field_validator
from typing import List
from enum import Enum
import math
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.chat_models import ChatOpenAI
import os
import openai
import typing
# local prompt_examples module
from prompt_examples import get_prompt_examples
'''
This object represents a single item (good/service) that was purchased in the receipt text.
'''
class Item(BaseModel):
    # One purchased item (good or service) parsed from the receipt text.
    #
    # NOTE: documented with comments rather than a class docstring on purpose:
    # pydantic copies a model's docstring into its JSON schema "description",
    # and that schema presumably feeds the prompt's format instructions, so a
    # docstring here would change the text sent to the model.

    description: str = Field(description="item name")
    # Unabbreviated to assist in classification (ORNG -> Orange).
    unabbreviatedDescription: str = Field(default="", description="unabbreviated name of field:description")
    # Some items have options (e.g. Happy Meals - Cheese Burger + Fries + Apples).
    includedItems: List[str] = Field(default_factory=list)
    quantity: int = Field(default=0, description="number of items")
    unitPrice: float = Field(default=0.00, description="cost per unit")
    totalPrice: float = Field(default=0.00, description="total cost of unit(s) purchased")
    # Discounts sometimes appear on receipts.
    discountAmount: float = Field(default=0.00, description="discount for item")

    @field_validator('unitPrice', 'totalPrice', 'discountAmount', mode='before')
    @classmethod
    def validate_float_Item(cls, input_value: typing.Any) -> float:
        """Coerce a raw price value to ``float``.

        Strings, ints and floats are converted with ``float()``. Anything
        else, or a value that cannot be converted, becomes ``0.0``.
        """
        if not isinstance(input_value, (str, int, float)):
            return 0.00
        try:
            return float(input_value)
        except (ValueError, OverflowError):
            # ValueError: non-numeric text; OverflowError: int too large for a float.
            return 0.00

    @field_validator('quantity', mode='before')
    @classmethod
    def validate_quantity(cls, quantity: typing.Any) -> int:
        """Coerce a raw item count to a whole number.

        Ints pass through unchanged. Strings and floats are rounded up to
        the next whole number. Anything else, or a value that cannot be
        converted, becomes ``0``.
        """
        if isinstance(quantity, int):
            return quantity
        if isinstance(quantity, (str, float)):
            try:
                # If decimal, take the upper bound.
                return math.ceil(float(quantity))
            except (ValueError, OverflowError):
                # ValueError: non-numeric text or NaN; OverflowError: infinity.
                return 0
        return 0

    @field_validator('description', 'unabbreviatedDescription', mode='after')
    @classmethod
    def validate_string_Item(cls, input_value: str) -> str:
        """Strip UNKNOWN placeholders and collapse runs of whitespace."""
        cleaned = input_value
        # "<UNKNOWN>" must go first so no stray "<>" is left behind.
        for placeholder in ("<UNKNOWN>", "UNKNOWN"):
            cleaned = cleaned.replace(placeholder, "")
        return " ".join(cleaned.split())

    @field_validator('includedItems', mode='after')
    @classmethod
    def validate_includedItems(cls, includedItems):
        """Strip UNKNOWN placeholders from each entry and drop empty entries."""
        cleaned_items = []
        for item in includedItems:
            cleaned = item.replace("<UNKNOWN>", "").replace("UNKNOWN", "")
            cleaned = " ".join(cleaned.split())
            if cleaned:
                cleaned_items.append(cleaned)
        return cleaned_items
'''
This object represents all of the information residing in one receipt text file.
Raw receipt text files are to be parsed into JSON object format for use in later analysis.
'''
class ReceiptInfo(BaseModel):
    # Everything extracted from one receipt text file: merchant details,
    # payment details, totals, and the list of purchased items.
    #
    # NOTE: documented with comments rather than a class docstring on purpose:
    # pydantic copies a model's docstring into its JSON schema "description",
    # and that schema presumably feeds the prompt's format instructions, so a
    # docstring here would change the text sent to the model.

    merchant: str = Field(description="name of merchant")
    address: str = Field(description="address")
    city: str = Field(description="city")
    state: str = Field(description="state")
    phoneNumber: str = Field(description="phone number")
    receiptDate: str = Field(description="purchase date")
    receiptTime: str = Field(description="time purchased")
    totalItems: int = Field(description="number of items")
    # Assists in classifying restaurants vs grocery stores.
    diningOptions: str = Field(default="", description="here or to-go items for consumable items")
    paymentType: str = Field(default="cash", description="payment method")
    creditCardType: str = Field(default="", description="credit card type")
    totalDiscount: float = Field(default=0.00, description="total discount")
    tax: float = Field(description="tax amount")
    total: float = Field(description="total amount paid")
    ITEMS: List[Item]

    @field_validator('totalItems', mode='before')
    @classmethod
    def validate_totalItems(cls, totalItems: typing.Any) -> int:
        """Coerce the raw item count to a whole number.

        Ints pass through unchanged. Strings and floats are rounded up.
        Anything else, or a value that cannot be converted, becomes ``0``.
        """
        if isinstance(totalItems, int):
            return totalItems
        if isinstance(totalItems, (str, float)):
            try:
                return math.ceil(float(totalItems))
            except (ValueError, OverflowError):
                # ValueError: non-numeric text or NaN; OverflowError: infinity.
                return 0
        return 0

    @field_validator('paymentType', mode='before')
    @classmethod
    def validate_paymentType(cls, paymentType: typing.Any) -> str:
        """Normalise the payment method to 'debit', 'credit' or 'cash'.

        'debit' wins if it appears anywhere in the text; otherwise any
        credit-card related term yields 'credit'. Non-string input and
        text with no indicator default to 'cash'.
        """
        if not isinstance(paymentType, str):
            return 'cash'
        text = paymentType.lower()
        if 'debit' in text:
            return 'debit'
        credit_card_names = ('visa', 'discover', 'mastercard', 'american', 'express',
                             'amex', 'chase', 'citi', 'credit', 'card')
        if any(term in text for term in credit_card_names):
            return 'credit'
        return 'cash'

    @field_validator('diningOptions', mode='before')
    @classmethod
    def validate_diningOptions(cls, diningOptions: typing.Any) -> str:
        """Classify the dining option as 'DINE IN', 'TO GO' or ''.

        Counts how many dine-in terms and how many to-go terms occur in the
        lower-cased text. The higher count wins, a non-zero tie resolves to
        'TO GO', and no match (or non-string input) yields ''.
        """
        if not isinstance(diningOptions, str):
            return ''
        text = diningOptions.lower()
        dine_in_terms = ('for', 'here', 'dine', 'in', 'house', 'on')
        to_go_terms = ('take', 'out', 'carry', 'to', 'go', 'pick', 'up', 'delivery', 'grab', 'away')
        # Substring (not whole-word) matching, so run-together forms such as
        # "takeout" or "pickup" still score.
        dine_in_score = sum(term in text for term in dine_in_terms)
        to_go_score = sum(term in text for term in to_go_terms)
        if dine_in_score > to_go_score:
            return 'DINE IN'
        if to_go_score > 0:
            # Covers both "to-go strictly ahead" and "non-zero tie".
            return 'TO GO'
        return ''

    @field_validator('tax', 'total', 'totalDiscount', mode='before')
    @classmethod
    def validate_float_ReceiptInfo(cls, input_value: typing.Any) -> float:
        """Coerce a raw monetary value to ``float``.

        Strings, ints and floats are converted with ``float()``. Anything
        else, or a value that cannot be converted, becomes ``0.0``.
        """
        if not isinstance(input_value, (str, int, float)):
            return 0.00
        try:
            return float(input_value)
        except (ValueError, OverflowError):
            # ValueError: non-numeric text; OverflowError: int too large for a float.
            return 0.00

    @field_validator('merchant', 'address', 'city', 'state', 'phoneNumber',
                     'receiptDate', 'receiptTime', 'creditCardType', mode='after')
    @classmethod
    def validate_string_ReceiptInfo(cls, input_value: str) -> str:
        """Strip 'unknown' placeholders and collapse runs of whitespace."""
        cleaned = input_value
        # "<UNKNOWN>" must go first so no stray "<>" is left behind.
        for placeholder in ("<UNKNOWN>", "UNKNOWN", "Unknown", "unknown"):
            cleaned = cleaned.replace(placeholder, "")
        return " ".join(cleaned.split())
def make_receiptParser():
    """Build the output parser that targets the ``ReceiptInfo`` model."""
    receipt_parser = PydanticOutputParser(pydantic_object=ReceiptInfo)
    return receipt_parser
"""
Prompt prefix sent to the model to generate the JSON; it states the requirements so that the output is generated as expected.
"""
def get_prompt_prefix():
    """Return the instruction text that opens the few-shot prompt.

    The text keeps a ``{format_instructions}`` placeholder for the prompt
    template to fill in later.

    Returns:
        str: the prompt prefix, ending with a newline.
    """
    # The backslash is written doubled: in a non-raw string a single
    # backslash before a double-quote is an escape sequence and would drop
    # out, leaving the prompt without the character it tells the model to use.
    return '''You are a capable large language model.
Your task is to extract data from a given receipt and format it into the JSON schema below.
Use the default values if you're not sure. From the item description please predict the unabbreviatedDescription.
The values for the fields "description" and "unabbreviatedDescription" can not be the same.
Please wrap all numeric values in double-quotes. Some items may be priced at a weighted rate, such as "per pound" or "per ounce".
Text can be used for multiple fields. Please use double-quotes for all string values.
If there are double-quotes inside string values, please escape those characters with the "\\" character.
{format_instructions}
'''
def get_example_prompt(input_variables=None, template="input:\n{ExampleInput}\noutput:\n{ExampleOutput}"):
    """Build the template used to render one few-shot example.

    Args:
        input_variables: names of the template variables. Defaults to
            ``["ExampleInput", "ExampleOutput"]``.
        template: the template text for a single example.

    Returns:
        A ``PromptTemplate`` built from the given variables and template.
    """
    # A fresh list per call avoids sharing one mutable default across calls.
    if input_variables is None:
        input_variables = ["ExampleInput", "ExampleOutput"]
    return PromptTemplate(input_variables=input_variables, template=template)
def get_suffix():
    """Return the prompt suffix: the input slot followed by an open output slot."""
    suffix_lines = ("input:", "{input}", "output:", "")
    return "\n".join(suffix_lines)
def make_fewshot_prompt(format_instructions):
    """Assemble the few-shot prompt template.

    Args:
        format_instructions: text substituted for the
            ``format_instructions`` variable as a partial variable.

    Returns:
        A ``FewShotPromptTemplate`` whose only remaining input variable
        is ``input``.
    """
    # Gather the pieces in the same order the template receives them.
    prefix_text = get_prompt_prefix()
    partials = {'format_instructions': format_instructions}
    example_rows = get_prompt_examples()
    per_example_prompt = get_example_prompt()
    suffix_text = get_suffix()
    fewshot_prompt = FewShotPromptTemplate(
        prefix=prefix_text,
        input_variables=["input"],
        partial_variables=partials,
        examples=example_rows,
        example_prompt=per_example_prompt,
        example_separator="\n",
        suffix=suffix_text,
    )
    return fewshot_prompt
def make_model(model="gpt-3.5-turbo-16k", temperature=1.00, openai_api_key="INSERT_OPENAI_API KEY"):
    """Create the chat model used to convert receipts.

    Args:
        model: name of the OpenAI chat model.
        temperature: sampling temperature passed to the model.
        openai_api_key: API key. When left at the placeholder default, the
            ``OPENAI_API_KEY`` environment variable is used if it is set.

    Returns:
        A configured ``ChatOpenAI`` instance.
    """
    if openai_api_key == "INSERT_OPENAI_API KEY":
        # Do not send the literal placeholder when a real key is available
        # from the environment. An explicitly supplied key is never overridden.
        openai_api_key = os.environ.get("OPENAI_API_KEY") or openai_api_key
    return ChatOpenAI(model=model, temperature=temperature, openai_api_key=openai_api_key)
def make_chain(fewshot_prompt, model, receiptParser):
    """Pipe the prompt into the model and the model into the parser.

    Returns the result of ``fewshot_prompt | model | receiptParser``.
    """
    prompt_to_model = fewshot_prompt | model
    return prompt_to_model | receiptParser