Skip to content

Commit

Permalink
update deps, more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Dec 14, 2024
1 parent 64e46c1 commit eb7000b
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/tips.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@
)

ray_results = ray_ctx.plan(df.execution_plan())
df_ctx.create_dataframe([[ray_results]]).show()
df_ctx.create_dataframe([ray_results]).show()
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"datafusion>=42.0.0",
"pyarrow>=11.0.0",
"datafusion>=43.0.0",
"pyarrow>=18.0.0",
"typing-extensions;python_version<'3.13'",
]

Expand Down
4 changes: 2 additions & 2 deletions requirements-in.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ isort
maturin
mypy
numpy
pyarrow
pyarrow>=18.0.0
pytest
ray==2.37.0
datafusion>=42.0.0
datafusion>=43.0.0
toml
importlib_metadata; python_version < "3.8"
34 changes: 32 additions & 2 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

from datafusion_ray.context import DatafusionRayContext
from datafusion import SessionContext
from datafusion import SessionContext, SessionConfig, RuntimeConfig, col, lit, functions as F


def test_basic_query_succeed():
Expand All @@ -27,7 +27,7 @@ def test_basic_query_succeed():
record_batch = ctx.sql("SELECT * FROM tips")
assert record_batch.num_rows == 244

def test_aggregate():
def test_aggregate_csv():
df_ctx = SessionContext()
ctx = DatafusionRayContext(df_ctx)
df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
Expand All @@ -39,6 +39,36 @@ def test_aggregate():
num_rows += record_batch.num_rows
assert num_rows == 4

def test_aggregate_parquet():
runtime = RuntimeConfig()
config = SessionConfig().set('datafusion.execution.parquet.schema_force_view_types', 'true')
df_ctx = SessionContext(config, runtime)
ctx = DatafusionRayContext(df_ctx)
df_ctx.register_parquet("tips", "examples/tips.parquet")
record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker")
assert isinstance(record_batches, list)
# TODO why does this return many empty batches?
num_rows = 0
for record_batch in record_batches:
num_rows += record_batch.num_rows
assert num_rows == 4

def test_aggregate_parquet_dataframe():
df_ctx = SessionContext()
ray_ctx = DatafusionRayContext(df_ctx)
df = df_ctx.read_parquet(f"examples/tips.parquet")
df = (
df.aggregate(
[col("sex"), col("smoker"), col("day"), col("time")],
[F.avg(col("tip") / col("total_bill")).alias("tip_pct")],
)
.filter(col("day") != lit("Dinner"))
.aggregate([col("sex"), col("smoker")], [F.avg(col("tip_pct")).alias("avg_pct")])
)
ray_results = ray_ctx.plan(df.execution_plan())
df_ctx.create_dataframe([ray_results]).show()


def test_no_result_query():
df_ctx = SessionContext()
ctx = DatafusionRayContext(df_ctx)
Expand Down

0 comments on commit eb7000b

Please sign in to comment.