Skip to content

Commit

Permalink
Merge pull request #113 from dbt-labs/fix/order-total-rounding-errors
Browse files Browse the repository at this point in the history
Fix order total rounding errors in output
  • Loading branch information
gwenwindflower authored Apr 8, 2024
2 parents 83c5592 + ab5040b commit 468fdc5
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 43 deletions.
6 changes: 4 additions & 2 deletions jafgen/customers/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ def to_dict(self) -> dict[str, Any]:
"id": self.order_id,
"customer": self.customer.customer_id,
"ordered_at": self.day.date.isoformat(),
# "order_month": self.day.date.strftime("%Y-%m"),
"store_id": self.store.store_id,
"subtotal": int(self.subtotal * 100),
"tax_paid": int(self.tax_paid * 100),
"order_total": int(self.order_total * 100),
# TODO: figure out why this is doesn't cause a test failure
# in tests/test_order_totals.py
# "order_total": int(self.order_total * 100),
"order_total": int(int(self.subtotal * 100) + int(self.tax_paid * 100)),
}

def items_to_dict(self) -> list[dict[str, Any]]:
Expand Down
39 changes: 16 additions & 23 deletions jafgen/simulation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import csv
import os
import uuid

import pandas as pd
from rich.progress import track

from jafgen.curves import Day
Expand Down Expand Up @@ -114,30 +114,23 @@ def run_simulation(self):
def save_results(self) -> None:
stock: Stock = Stock()
inventory: Inventory = Inventory()
entities: dict[str, pd.DataFrame] = {
"customers": pd.DataFrame.from_records(
[customer.to_dict() for customer in self.customers.values()]
),
"orders": pd.DataFrame.from_records(
[order.to_dict() for order in self.orders]
),
"items": pd.DataFrame.from_records(
[item.to_dict() for order in self.orders for item in order.items]
),
"stores": pd.DataFrame.from_records(
[market.store.to_dict() for market in self.markets]
),
"supplies": pd.DataFrame.from_records(stock.to_dict()),
"products": pd.DataFrame.from_records(inventory.to_dict()),
entities: dict[str, list[dict]] = {
"customers": [customer.to_dict() for customer in self.customers.values()],
"orders": [order.to_dict() for order in self.orders],
"items": [item.to_dict() for order in self.orders for item in order.items],
"stores": [market.store.to_dict() for market in self.markets],
"supplies": stock.to_dict(),
"products": inventory.to_dict(),
}

if not os.path.exists("./jaffle-data"):
os.makedirs("./jaffle-data")
# save output
for entity, df in track(
for entity, data in track(
entities.items(), description="🚚 Delivering jaffles..."
):
df.to_csv(
f"./jaffle-data/{self.prefix}_{entity}.csv",
header=df.columns.to_list(),
index=False,
)
with open(
f"./jaffle-data/{self.prefix}_{entity}.csv", "w", newline=""
) as file:
writer = csv.DictWriter(file, fieldnames=data[0].keys())
writer.writeheader()
writer.writerows(data)
3 changes: 1 addition & 2 deletions requirements.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
numpy
pandas
Faker
typer[all]
typer
18 changes: 4 additions & 14 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,22 @@
# uv pip compile requirements.in -o requirements.txt
click==8.1.7
# via typer
colorama==0.4.6
# via typer
faker==24.4.0
faker==24.7.1
markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
numpy==1.26.4
# via pandas
pandas==2.2.1
pygments==2.17.2
# via rich
python-dateutil==2.9.0.post0
# via
# faker
# pandas
pytz==2024.1
# via pandas
# via faker
rich==13.7.1
# via typer
shellingham==1.5.4
# via typer
six==1.16.0
# via python-dateutil
typer==0.10.0
typing-extensions==4.10.0
typer==0.12.2
typing-extensions==4.11.0
# via typer
tzdata==2024.1
# via pandas
38 changes: 36 additions & 2 deletions tests/test_order_totals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from jafgen.customers.order import Order
from jafgen.curves import Day
from jafgen.stores.store import Store
from jafgen.customers.customers import RemoteWorker
from jafgen.customers.customers import RemoteWorker, BrunchCrowd, Student
from jafgen.stores.inventory import Inventory


Expand All @@ -18,15 +18,49 @@ def test_order_totals():
items=[
inventory.get_food()[0],
inventory.get_drink()[0],
inventory.get_food()[0],
],
store=store,
day=Day(date_index=i),
)
)
orders.append(
Order(
customer=BrunchCrowd(store=store),
items=[
inventory.get_food()[0],
inventory.get_drink()[0],
inventory.get_food()[0],
],
store=store,
day=Day(date_index=i),
)
)
orders.append(
Order(
customer=Student(store=store),
items=[
inventory.get_food()[0],
inventory.get_drink()[0],
inventory.get_food()[0],
],
store=store,
day=Day(date_index=i),
)
)
for order in orders:
assert order.subtotal == order.items[0].item.price + order.items[1].item.price
assert (
order.subtotal
== order.items[0].item.price
+ order.items[1].item.price
+ order.items[2].item.price
)
assert order.tax_paid == order.subtotal * order.store.tax_rate
assert order.order_total == order.subtotal + order.tax_paid
assert round(float(order.order_total), 2) == round(
float(order.subtotal), 2
) + round(float(order.tax_paid), 2)
order_dict = order.to_dict()
assert (
order_dict["order_total"] == order_dict["subtotal"] + order_dict["tax_paid"]
)

0 comments on commit 468fdc5

Please sign in to comment.