Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Dec 14, 2024
1 parent eb7000b commit 4b3ccf3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
5 changes: 2 additions & 3 deletions datafusion_ray/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def execute_query_partition(
"ph": "X",
}
print(json.dumps(event), end=",")
return ret[0] if len(ret) == 1 else ret
return ret


class DatafusionRayContext:
Expand Down Expand Up @@ -143,7 +143,7 @@ def sql(self, sql: str) -> pa.RecordBatch:
df = self.df_ctx.sql(sql)
return self.plan(df.execution_plan())

def plan(self, execution_plan: Any) -> pa.RecordBatch:
def plan(self, execution_plan: Any) -> List[pa.RecordBatch]:

graph = self.ctx.plan(execution_plan)
final_stage_id = graph.get_final_query_stage().id()
Expand All @@ -161,4 +161,3 @@ def plan(self, execution_plan: Any) -> pa.RecordBatch:
# assert len(partitions) == 1, len(partitions)
result_set = ray.get(partitions[0])
return result_set

9 changes: 3 additions & 6 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def test_basic_query_succeed():
ctx = DatafusionRayContext(df_ctx)
df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
# TODO why does this return a single batch and not a list of batches?
record_batch = ctx.sql("SELECT * FROM tips")
assert record_batch.num_rows == 244
record_batches = ctx.sql("SELECT * FROM tips")
assert record_batches[0].num_rows == 244

def test_aggregate_csv():
df_ctx = SessionContext()
Expand All @@ -40,13 +40,10 @@ def test_aggregate_csv():
assert num_rows == 4

def test_aggregate_parquet():
runtime = RuntimeConfig()
config = SessionConfig().set('datafusion.execution.parquet.schema_force_view_types', 'true')
df_ctx = SessionContext(config, runtime)
df_ctx = SessionContext()
ctx = DatafusionRayContext(df_ctx)
df_ctx.register_parquet("tips", "examples/tips.parquet")
record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker")
assert isinstance(record_batches, list)
# TODO why does this return many empty batches?
num_rows = 0
for record_batch in record_batches:
Expand Down

0 comments on commit 4b3ccf3

Please sign in to comment.