diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py index bf278c94e..a6a5acc6d 100644 --- a/connectorx-python/connectorx/__init__.py +++ b/connectorx-python/connectorx/__init__.py @@ -1,9 +1,11 @@ from __future__ import annotations -from typing import Any, Literal, TYPE_CHECKING, overload +import importlib from importlib.metadata import version +from typing import Any, Literal, TYPE_CHECKING, overload + from .connectorx import ( read_sql as _read_sql, partition_sql as _partition_sql, @@ -311,10 +313,7 @@ def read_sql( if return_type == "pandas": df = df.to_pandas(date_as_object=False, split_blocks=False) if return_type == "polars": - try: - import polars as pl - except ModuleNotFoundError: - raise ValueError("You need to install polars first") + pl = try_import_module("polars") try: # api change for polars >= 0.8.* @@ -350,10 +349,7 @@ def read_sql( conn, protocol = rewrite_conn(conn, protocol) if return_type in {"modin", "dask", "pandas"}: - try: - import pandas - except ModuleNotFoundError: - raise ValueError("You need to install pandas first") + try_import_module("pandas") result = _read_sql( conn, @@ -368,25 +364,14 @@ def read_sql( df.set_index(index_col, inplace=True) if return_type == "modin": - try: - import modin.pandas as mpd - except ModuleNotFoundError: - raise ValueError("You need to install modin first") - + mpd = try_import_module("modin.pandas") df = mpd.DataFrame(df) elif return_type == "dask": - try: - import dask.dataframe as dd - except ModuleNotFoundError: - raise ValueError("You need to install dask first") - + dd = try_import_module("dask.dataframe") df = dd.from_pandas(df, npartitions=1) elif return_type in {"arrow", "arrow2", "polars", "polars2"}: - try: - import pyarrow - except ModuleNotFoundError: - raise ValueError("You need to install pyarrow first") + try_import_module("pyarrow") result = _read_sql( conn, @@ -397,11 +382,7 @@ def read_sql( ) df = reconstruct_arrow(result) if return_type in {"polars", "polars2"}: - try: - import polars as pl - except ModuleNotFoundError: - raise ValueError("You need to install polars first") - + pl = try_import_module("polars") try: df = pl.DataFrame.from_arrow(df) except AttributeError: @@ -488,3 +469,10 @@ def remove_ending_semicolon(query: str) -> str: if query.endswith(";"): query = query[:-1] return query + + +def try_import_module(name: str): + try: + return importlib.import_module(name) + except ModuleNotFoundError: + raise ValueError(f"You need to install {name.split('.')[0]} first")