Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New drivers #1279

Merged
merged 5 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 115 additions & 34 deletions asyncdb/drivers/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
from collections.abc import Iterable, Sequence
from pathlib import Path, PurePath
import threading
from functools import partial
import jaydebeapi
import jpype
Expand All @@ -17,12 +18,23 @@
from .sql import SQLDriver


class jdbc(SQLDriver, DatabaseBackend, ModelBackend):
class jdbc(
SQLDriver,
DatabaseBackend,
ModelBackend
):
_provider = "JDBC"
_syntax = "sql"

def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None:
def __init__(
self,
dsn: str = "",
loop: asyncio.AbstractEventLoop = None,
params: dict = None,
**kwargs
) -> None:
self._test_query = "SELECT 1"
self.max_memory: int = kwargs.pop('max_memory', 12000)
try:
if isinstance(params["classpath"], str):
params["classpath"] = Path(params["classpath"])
Expand Down Expand Up @@ -94,10 +106,14 @@ def start_jvm(self, jarpath):
classpath = None
path = ";".join(jarpath)
_jvmArgs.append("-Djava.class.path=" + path)
_jvmArgs.append("-Xmx12000m")
_jvmArgs.append(f"-Xmx{self.max_memory}m")
_jvmArgs.append("-Dfile.encoding=UTF8")
Comment on lines 108 to 110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code-quality): We've found these issues:

Suggested change
_jvmArgs.append("-Djava.class.path=" + path)
_jvmArgs.append("-Xmx12000m")
_jvmArgs.append(f"-Xmx{self.max_memory}m")
_jvmArgs.append("-Dfile.encoding=UTF8")
_jvmArgs.append(f"-Djava.class.path={path}")
_jvmArgs.extend((f"-Xmx{self.max_memory}m", "-Dfile.encoding=UTF8"))

jpype.startJVM(
jvmpath=jpype.getDefaultJVMPath(), classpath=[classpath], *_jvmArgs, interrupt=True, convertStrings=True
jvmpath=jpype.getDefaultJVMPath(),
classpath=[classpath],
*_jvmArgs,
interrupt=True,
convertStrings=True
)

async def connection(self):
Expand All @@ -116,11 +132,16 @@ async def connection(self):
jpype.java.lang.ClassLoader.getSystemClassLoader()
)
if "options" in self._params:
options = ";".join({f"{k}={v}" for k, v in self._params["options"].items()})
options = ";".join(
{f"{k}={v}" for k, v in self._params["options"].items()}
)
self._dsn = f"{self._dsn};{options}"
user = self._params["user"]
password = self._params["password"]
self._executor = self.get_executor(executor=None, max_workers=10)
self._executor = self.get_executor(
executor="thread",
max_workers=10
)
self._connection = await self._thread_func(
jaydebeapi.connect,
self._classname,
Expand All @@ -130,50 +151,89 @@ async def connection(self):
executor=self._executor,
)
if self._connection:
print(f'{self._provider}: Connected at {self._params["driver"]}:{self._params["host"]}')
print(
f'{self._provider}: Connected at {self._params["driver"]}:{self._params["host"]}'
)
self._connected = True
self._initialized_on = time.time()
if self._init_func is not None and callable(self._init_func):
await self._init_func(self._connection) # pylint: disable=E1102
await self._init_func(self._connection) # pylint: disable=E1102 # no-qa
except jpype.JException as ex:
if "does not exist" in str(ex):
raise DriverError(
f"Database does not exist: {self._params.get('database')}"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code-quality): Explicitly raise from a previous error (raise-from-previous-error)

Suggested change
)
) from ex

print(ex.stacktrace())
self._logger.error(f"Driver {self._classname} Error: {ex}")
self._logger.error(
f"Driver {self._classname} Error: {ex}"
)
except TypeError as e:
raise DriverError(f"Driver {self._classname} was not found: {e}") from e
raise DriverError(
f"Driver {self._classname} was not found: {e}"
) from e
except Exception as e:
self._logger.exception(e, stack_info=True)
raise DriverError(f"JDBC Unknown Error: {e!s}") from e
raise DriverError(
f"JDBC Unknown Error: {e!s}"
) from e
return self

connect = connection

async def close(self, timeout: int = 10) -> None:
async def close(self, timeout: int = 5) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Move JVM shutdown logic to a separate method

The JVM shutdown logic in the close method is quite complex. Consider moving this logic into a separate method for better organization and potential reuse in other parts of the driver.

async def close(self, timeout: int = 5) -> None:
    await self._shutdown_jvm(timeout)

async def _shutdown_jvm(self, timeout: int) -> None:
    print("JVM started: ", jpype.isJVMStarted())
    if not self._connected or not self._connection:
        return

print("JVM started: ", jpype.isJVMStarted())
if not self._connected or not self._connection:
print('Connection already closed.')
return # Prevent double close
try:
if self._connection:
close = self._thread_func(self._connection.close)
close = self._thread_func(
self._connection.close,
executor=self._executor
)
await asyncio.wait_for(close, timeout)
print(f'{self._provider}: Closed connection to {self._params["driver"]}:{self._params["host"]}')
self._connected = False
self._connection = None
except Exception as e:
print(e)
self._logger.exception(e, stack_info=True)
raise DriverError(f"JDBC Closing Error: {e!s}") from e
raise DriverError(
f"JDBC Closing Error: {e!s}"
) from e
finally:
self._connected = False
self._connection = None
# Shutdown the executor
if self._executor:
self._executor.shutdown(wait=True)
self._executor = None
# Detach all threads before shutting down JVM
if jpype.isThreadAttachedToJVM():
jpype.detachThreadFromJVM()
# Ensure JVM shutdown is called from the main thread
if jpype.isJVMStarted() and threading.current_thread() is threading.main_thread():
try:
# Force garbage collection on the Java side
jpype.java.lang.System.gc()
jpype.shutdownJVM()
self._logger.info(
'JDBC: JVM shutdown successfully.'
)
except Exception as e:
self._logger.warning(
f"Error shutting down JVM: {e}"
)

disconnect = close

def __del__(self) -> None:
try:
if jpype.isThreadAttachedToJVM():
jpype.detachThreadFromJVM()
jpype.shutdownJVM()
except Exception as e:
self._logger.exception(e, stack_info=True)

def get_columns(self):
return self._columns

async def _query(self, sentence, cursor: Any, fetch: Any, *args, **kwargs) -> Iterable:
async def _query(
self,
sentence,
cursor: Any,
fetch: Any,
*args,
**kwargs
) -> Iterable:
loop = asyncio.get_event_loop()

def _execute(sentence, cursor, fetch, *args, **kwargs):
Expand Down Expand Up @@ -210,7 +270,10 @@ async def query(self, sentence: str, **kwargs):
cursor = None
await self.valid_operation(sentence)
try:
cursor = await self._thread_func(self._connection.cursor)
cursor = await self._thread_func(
self._connection.cursor,
executor=self._executor
)
rows = await self._query(sentence, cursor, cursor.fetchall, **kwargs)
self._result = [dict(zip(self._columns, row)) for row in rows]
if not self._result:
Expand All @@ -230,7 +293,10 @@ async def fetch_all(self, sentence: str, **kwargs) -> Iterable:
result = None
await self.valid_operation(sentence)
try:
cursor = await self._thread_func(self._connection.cursor)
cursor = await self._thread_func(
self._connection.cursor,
executor=self._executor
)
result = await self._query(sentence, cursor, cursor.fetchall, **kwargs)
if not result:
return NoDataFound()
Expand All @@ -248,7 +314,10 @@ async def queryrow(self, sentence: str, **kwargs):
cursor = None
await self.valid_operation(sentence)
try:
cursor = await self._thread_func(self._connection.cursor)
cursor = await self._thread_func(
self._connection.cursor,
executor=self._executor
)
row = await self._query(sentence, cursor, cursor.fetchone, **kwargs)
self._result = dict(zip(self._columns, row))
if not self._result:
Expand All @@ -270,7 +339,10 @@ async def fetch_one(self, sentence: str, **kwargs) -> Iterable[Any]:
result = None
await self.valid_operation(sentence)
try:
cursor = await self._thread_func(self._connection.cursor)
cursor = await self._thread_func(
self._connection.cursor,
executor=self._executor
)
row = await self._query(sentence, cursor, cursor.fetchone, **kwargs)
result = dict(zip(self._columns, row))
if not result:
Expand All @@ -291,7 +363,10 @@ async def fetch_many(self, sentence: str, size: int = None, **kwargs) -> Iterabl
result = None
await self.valid_operation(sentence)
try:
cursor = await self._thread_func(self._connection.cursor)
cursor = await self._thread_func(
self._connection.cursor,
executor=self._executor
)
rows = await self._query(sentence, cursor, cursor.fetchmany, size=size, **kwargs)
result = [dict(zip(self._columns, row)) for row in rows]
if not result:
Expand All @@ -311,7 +386,10 @@ async def execute(self, sentence: str, *args, **kwargs) -> Union[None, Sequence]
result = None
await self.valid_operation(sentence)
try:
cursor = await self._thread_func(self._connection.cursor)
cursor = await self._thread_func(
self._connection.cursor,
executor=self._executor
)
result = await self._execute(sentence, cursor, *args, **kwargs)
return result
except Exception as err:
Expand All @@ -327,7 +405,10 @@ async def execute_many(self, sentence: Union[str, list], *args, **kwargs) -> Uni
result = None
await self.valid_operation(sentence)
try:
cursor = await self._thread_func(self._connection.cursor)
cursor = await self._thread_func(
self._connection.cursor,
executor=self._executor
)
if isinstance(sentence, list):
results = []
for st in sentence:
Expand Down
Loading