diff --git a/datafusion_ray/context.py b/datafusion_ray/context.py index f2ef86f..0070220 100644 --- a/datafusion_ray/context.py +++ b/datafusion_ray/context.py @@ -115,7 +115,7 @@ def execute_query_partition( "ph": "X", } print(json.dumps(event), end=",") - return ret[0] if len(ret) == 1 else ret + return ret class DatafusionRayContext: @@ -143,7 +143,7 @@ def sql(self, sql: str) -> pa.RecordBatch: df = self.df_ctx.sql(sql) return self.plan(df.execution_plan()) - def plan(self, execution_plan: Any) -> pa.RecordBatch: + def plan(self, execution_plan: Any) -> List[pa.RecordBatch]: graph = self.ctx.plan(execution_plan) final_stage_id = graph.get_final_query_stage().id() @@ -161,4 +161,3 @@ def plan(self, execution_plan: Any) -> pa.RecordBatch: # assert len(partitions) == 1, len(partitions) result_set = ray.get(partitions[0]) return result_set - diff --git a/tests/test_context.py b/tests/test_context.py index 58c413e..ecc3324 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -24,8 +24,8 @@ def test_basic_query_succeed(): ctx = DatafusionRayContext(df_ctx) df_ctx.register_csv("tips", "examples/tips.csv", has_header=True) # TODO why does this return a single batch and not a list of batches? - record_batch = ctx.sql("SELECT * FROM tips") - assert record_batch.num_rows == 244 + record_batches = ctx.sql("SELECT * FROM tips") + assert record_batches[0].num_rows == 244 def test_aggregate_csv(): df_ctx = SessionContext() @@ -40,13 +40,10 @@ def test_aggregate_csv(): assert num_rows == 4 def test_aggregate_parquet(): - runtime = RuntimeConfig() - config = SessionConfig().set('datafusion.execution.parquet.schema_force_view_types', 'true') - df_ctx = SessionContext(config, runtime) + df_ctx = SessionContext() ctx = DatafusionRayContext(df_ctx) df_ctx.register_parquet("tips", "examples/tips.parquet") record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker") - assert isinstance(record_batches, list) # TODO why does this return many empty batches? num_rows = 0 for record_batch in record_batches: