Skip to content

Commit

Permalink
Fixed piccolo
Browse files Browse the repository at this point in the history
  • Loading branch information
gauravr committed Jan 10, 2025
1 parent a873671 commit 0b36460
Showing 1 changed file with 43 additions and 40 deletions.
83 changes: 43 additions & 40 deletions apphelpers/rest/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,45 @@ def f_wrapped(*args, **kw):
return wrapper


if peewee_enabled:

def dbtransaction(db):
"""
wrapper that make db transactions automic
note db connections are used only when it is needed (hence there is no
usual connection open/close)
"""

def wrapper(f):
if inspect.iscoroutinefunction(f):

@wraps(f)
async def async_wrapper(*ar, **kw):
with dbtransaction_ctx(db):
return await f(*ar, **kw)

return async_wrapper
else:

@wraps(f)
async def sync_wrapper(*ar, **kw):
with dbtransaction_ctx(db):
return f(*ar, **kw)

return sync_wrapper

return wrapper

else:
# for piccolo db
def dbtransaction(engine, allow_nested=True):
async def dependency():
async with dbtransaction_ctx(engine, allow_nested=allow_nested):
yield

return Depends(dependency)


async def get_current_user(request: Request):
return request.state.user if request.state.user.id else None

Expand Down Expand Up @@ -168,45 +207,6 @@ async def get_user_ip(request: Request):
header = Annotated[str, Header()]


if peewee_enabled:

def dbtransaction(db):
"""
wrapper that make db transactions automic
note db connections are used only when it is needed (hence there is no
usual connection open/close)
"""

def wrapper(f):
if inspect.iscoroutinefunction(f):

@wraps(f)
async def async_wrapper(*ar, **kw):
with dbtransaction_ctx(db):
return await f(*ar, **kw)

return async_wrapper
else:

@wraps(f)
async def sync_wrapper(*ar, **kw):
with dbtransaction_ctx(db):
return f(*ar, **kw)

return sync_wrapper

return wrapper

else:

def dbtransaction(engine, allow_nested=True):
async def dependency():
async with dbtransaction_ctx(engine, allow_nested=allow_nested):
yield

return Depends(dependency)


class SecureRouter(APIRoute):
sessions = None

Expand Down Expand Up @@ -377,7 +377,10 @@ def enable_multi_site(self, site_identifier: str):
self.site_identifier = site_identifier

def setup_db_transaction(self, db):
self.db_tr_wrapper = dbtransaction(db)
if peewee_enabled:
self.db_tr_wrapper = dbtransaction(db)
else:
self.router.dependencies.append(dbtransaction(db))

def setup_honeybadger_monitoring(self):
api_key = settings.HONEYBADGER_API_KEY
Expand Down

0 comments on commit 0b36460

Please sign in to comment.