diff --git a/demos/askem-var.py b/demos/askem-var.py index 71c1d86d..451e1597 100644 --- a/demos/askem-var.py +++ b/demos/askem-var.py @@ -49,9 +49,8 @@ class Variable(Schema): file_path = "testdata/askem-tiny/" if run_pz: - # reference, plan, stats = run_workload() - df_input = pd.DataFrame(dict_of_excerpts) - excerpts = Dataset(dataset, schema=Papersnippet) + df_input = pd.DataFrame(list_of_strings) + excerpts = Dataset(df_input, schema=Papersnippet) output = excerpts.convert( Variable, desc="A variable used or introduced in the context", cardinality=Cardinality.ONE_TO_MANY ).filter("The value name is 'a'", depends_on="name") @@ -64,6 +63,8 @@ class Variable(Schema): execution_strategy="sequential", optimizer_strategy="pareto", ) + + # Option 1: Use QueryProcessorFactory to create a processor and generate a plan processor = QueryProcessorFactory.create_processor(excerpts, config) plan = processor.generate_plan(output, policy) print(processor.plan) @@ -71,7 +72,7 @@ class Variable(Schema): with st.container(): st.write("### Executed plan: \n") # st.write(" " + str(plan).replace("\n", " \n ")) - for idx, op in enumerate(processor.plan.operators): + for idx, op in enumerate(plan.operators): strop = f"{idx + 1}. {str(op)}" strop = strop.replace("\n", " \n") st.write(strop) @@ -85,7 +86,7 @@ class Variable(Schema): start_time = time.time() # for idx, (vars, plan, stats) in enumerate(iterable): for idx, record in enumerate(input_records): - print(f"idx: {idx}\n vars: {vars}") + print(f"idx: {idx}\n record: {record}") index = idx vars = processor.execute_opstream(processor.plan, record) if idx == len(input_records) - 1: @@ -130,8 +131,8 @@ class Variable(Schema): st.write(" **value:** ", var.value, "\n") # write variables to a json file with readable format - # with open(f"askem-variables-{dataset}.json", "w") as f: - # json.dump(variables, f, indent=4) + with open(f"askem-variables-{dataset}.json", "w") as f: + json.dump(variables, f, indent=4) vars_df = pd.DataFrame(variables) # G = nx.DiGraph() diff --git a/demos/bdf-suite.py b/demos/bdf-suite.py index 4989a951..b78d0fcd 100644 --- a/demos/bdf-suite.py +++ b/demos/bdf-suite.py @@ -198,7 +198,7 @@ def extract_references(processing_strategy, execution_strategy, optimizer_strate run_pz = st.button("Run Palimpzest on dataset") # st.radio("Biofabric Data Integration") -run_pz = False +run_pz = True dataset = "bdf-usecase3-tiny" if run_pz: @@ -220,24 +220,27 @@ def extract_references(processing_strategy, execution_strategy, optimizer_strate execution_strategy="sequential", optimizer_strategy="pareto", ) - iterable = output.run(config) - + data_record_collection = output.run(config) + references = [] statistics = [] - for idx, (reference, plan, stats) in enumerate(iterable): + for idx, record_collection in enumerate(data_record_collection): record_time = time.time() + stats = record_collection.plan_stats + references = record_collection.data_records + plan = record_collection.executed_plans[0] statistics.append(stats) if not idx: with st.container(): st.write("### Executed plan: \n") - # st.write(" " + str(plan).replace("\n", " \n ")) - for idx, op in enumerate(plan.operators): - strop = f"{idx+1}. {str(op)}" - strop = strop.replace("\n", " \n") - st.write(strop) - for ref in reference: + st.write(" " + str(plan).replace("\n", " \n ")) + # for idx, op in enumerate(stats.plan_strs[0].operators): + # strop = f"{idx+1}. {str(op)}" + # strop = strop.replace("\n", " \n") + # st.write(strop) + for ref in references: try: index = ref.index except Exception: diff --git a/demos/bdf-usecase3.py b/demos/bdf-usecase3.py index 20a2c555..5154b003 100644 --- a/demos/bdf-usecase3.py +++ b/demos/bdf-usecase3.py @@ -18,6 +18,7 @@ from palimpzest.datamanager.datamanager import DataDirectory from palimpzest.policy import MaxQuality, MinCost from palimpzest.query.processor.config import QueryProcessorConfig +from palimpzest.query.processor.query_processor_factory import QueryProcessorFactory from palimpzest.sets import Dataset if not os.environ.get("OPENAI_API_KEY"): @@ -75,11 +76,15 @@ def run_workload(): tables = [] statistics = [] - for table, plan, stats in iterable: # noqa: B007 + for data_record_collection in iterable: # noqa: B007 # record_time = time.time() + table = data_record_collection.data_records + stats = data_record_collection.plan_stats tables += table statistics.append(stats) + processor = QueryProcessorFactory.create_processor(output, config) + plan = processor.generate_plan(output, policy) return tables, plan, stats @@ -115,24 +120,26 @@ def run_workload(): execution_strategy="sequential", optimizer_strategy="pareto", ) + processor = QueryProcessorFactory.create_processor(output, config) + plan =processor.generate_plan(output, policy) iterable = output.run(config) references = [] statistics = [] - for idx, (reference, plan, stats) in enumerate(iterable): + for idx, data_record_collection in enumerate(iterable): record_time = time.time() + references = data_record_collection.data_records + stats = data_record_collection.plan_stats + plan = data_record_collection.executed_plans[0] statistics.append(stats) if not idx: with st.container(): st.write("### Executed plan: \n") - # st.write(" " + str(plan).replace("\n", " \n ")) - for idx, op in enumerate(plan.operators): - strop = f"{idx+1}. {str(op)}" - strop = strop.replace("\n", " \n") - st.write(strop) - for ref in reference: + st.write(" " + str(plan).replace("\n", " \n ")) + + for ref in references: try: index = ref.index except Exception: diff --git a/demos/biofabric-demo-matching.ipynb b/demos/biofabric-demo-matching.ipynb index 2779901d..fe401fc9 100644 --- a/demos/biofabric-demo-matching.ipynb +++ b/demos/biofabric-demo-matching.ipynb @@ -163,9 +163,9 @@ " nocache=True,\n", " processing_strategy=\"no_sentinel\",\n", ")\n", - "records, plan, stats = output.run(config)\n", + "data_record_collection = output.run(config)\n", "\n", - "print_tables(records)" + "print_tables(data_record_collection.data_records)" ] }, { @@ -229,9 +229,9 @@ " nocache=True,\n", " processing_strategy=\"no_sentinel\",\n", ")\n", - "tables, plan, stats = patient_tables.run(config)\n", + "data_record_collection = patient_tables.run(config)\n", "\n", - "for table in tables:\n", + "for table in data_record_collection.data_records:\n", " header = table.header\n", " subset_rows = table.rows[:3]\n", "\n", @@ -241,7 +241,7 @@ " print(\" | \".join(row)[:100], \"...\")\n", " print()\n", "\n", - "print(stats)" + "print(data_record_collection.execution_stats)" ] }, { @@ -287,9 +287,9 @@ " processing_strategy=\"no_sentinel\",\n", " execution_strategy=\"pipelined_parallel\",\n", ")\n", - "tables, plan, stats = patient_tables.run(config)\n", + "data_record_collection = patient_tables.run(config)\n", "\n", - "for table in tables:\n", + "for table in data_record_collection.data_records:\n", " header = table.header\n", " subset_rows = table.rows[:3]\n", "\n", @@ -299,7 +299,7 @@ " print(\" | \".join(row)[:100], \"...\")\n", " print()\n", "\n", - "print(stats)" + "print(data_record_collection.execution_stats)" ] }, { @@ -330,8 +330,8 @@ ], "source": [ "print(\"Chosen plan:\")\n", - "print(plan, \"\\n\")\n", - "print(\"Stats:\", stats)" + "print(data_record_collection.executed_plans, \"\\n\")\n", + "print(\"Stats:\", data_record_collection.execution_stats)" ] }, { @@ -606,10 +606,10 @@ " processing_strategy=\"no_sentinel\",\n", " execution_strategy=\"pipelined_parallel\",\n", ")\n", - "matched_tables, plan, stats = case_data.run(config) \n", + "data_record_collection = case_data.run(config) \n", "\n", "output_rows = []\n", - "for output_table in matched_tables:\n", + "for output_table in data_record_collection.data_records:\n", " output_rows.append(output_table.to_dict()) \n", "\n", "output_df = pd.DataFrame(output_rows)\n", @@ -650,8 +650,8 @@ } ], "source": [ - "print(plan, \"\\n\")\n", - "print(\"Stats:\", stats)" + "print(data_record_collection.executed_plans, \"\\n\")\n", + "print(\"Stats:\", data_record_collection.execution_stats)" ] }, { @@ -903,10 +903,10 @@ " processing_strategy=\"no_sentinel\",\n", " execution_strategy=\"pipelined_parallel\",\n", ")\n", - "matched_tables, plan, stats = case_data.run(config)\n", + "data_record_collection = case_data.run(config)\n", "\n", "output_rows = []\n", - "for output_table in matched_tables:\n", + "for output_table in data_record_collection.data_records:\n", " output_rows.append(output_table.to_dict()) \n", "\n", "output_df = pd.DataFrame(output_rows)\n", @@ -946,8 +946,8 @@ } ], "source": [ - "print(plan, \"\\n\")\n", - "print(\"Stats:\", \"\")" + "print(data_record_collection.executed_plans, \"\\n\")\n", + "print(\"Stats:\", data_record_collection.execution_stats)" ] }, { @@ -1128,8 +1128,8 @@ "iterable = case_data.run(config)\n", "\n", "output_rows = []\n", - "for matched_tables, plan, stats in iterable: # noqa: B007\n", - " for output_table in matched_tables:\n", + "for data_record_collection in iterable: # noqa: B007\n", + " for output_table in data_record_collection.data_records:\n", " print(output_table.to_dict().keys())\n", " output_rows.append(output_table.to_dict()) \n", "\n", @@ -1619,8 +1619,8 @@ "iterable = case_data.run(config)\n", "\n", "output_rows = []\n", - "for matched_tables, plan, stats in iterable: # noqa: B007\n", - " for output_table in matched_tables:\n", + "for data_record_collection in iterable: # noqa: B007\n", + " for output_table in data_record_collection.data_records:\n", " print(output_table.to_dict().keys())\n", " output_rows.append(output_table.to_dict()) \n", "\n", diff --git a/demos/biofabric-demo.py b/demos/biofabric-demo.py index 6d634f62..e8e42c1d 100644 --- a/demos/biofabric-demo.py +++ b/demos/biofabric-demo.py @@ -136,11 +136,11 @@ def print_table(output): processing_strategy="no_sentinel", execution_strategy=executor, ) - tables, plan, stats = output.run(config) + data_record_collection = output.run(config) - print_table(tables) - print(plan) - print(stats) + print_table(data_record_collection.data_records) + print(data_record_collection.executed_plans) + # print(data_record_collection.execution_stats) end_time = time.time() print("Elapsed time:", end_time - start_time) diff --git a/demos/demo_core.py b/demos/demo_core.py index d48b470e..87f57f62 100644 --- a/demos/demo_core.py +++ b/demos/demo_core.py @@ -192,14 +192,14 @@ def execute_task(task, datasetid, policy, verbose=False, profile=False, processi execution_strategy=execution_strategy, optimizer_strategy=optimizer_strategy, ) - records, execution_stats = root_set.run(config) + data_record_collection = root_set.run(config) if profile: os.makedirs("profiling-data", exist_ok=True) with open(stat_path, "w") as f: - json.dump(execution_stats.to_json(), f) + json.dump(data_record_collection.execution_stats.to_json(), f) - return records, execution_stats, cols + return data_record_collection.data_records, data_record_collection.execution_stats, cols def format_results_table(records: list[DataRecord], cols=None): """Format records as a table""" diff --git a/demos/df-newinterface.py b/demos/df-newinterface.py new file mode 100644 index 00000000..9e34d8d6 --- /dev/null +++ b/demos/df-newinterface.py @@ -0,0 +1,23 @@ +import pandas as pd + +import palimpzest as pz +from palimpzest.query.processor.config import QueryProcessorConfig + +df = pd.read_csv("testdata/enron-tiny.csv") +qr2 = pz.Dataset(df) +qr2 = qr2.add_columns({"sender": ("The email address of the sender", "string"), + "subject": ("The subject of the email", "string"),# + "date": ("The date the email was sent", "string")}) +qr3 = qr2.filter("It is an email").filter("It has Vacation in the subject") + +config = QueryProcessorConfig( + verbose=True, + execution_strategy="pipelined_parallel", +) + +output = qr3.run(config) +output_df = output.to_df() +print(output_df) + +output_df = output.to_df(project_cols=["sender", "subject", "date"]) +print(output_df) diff --git a/demos/fever-demo.py b/demos/fever-demo.py index 85406e54..dde69903 100644 --- a/demos/fever-demo.py +++ b/demos/fever-demo.py @@ -291,7 +291,7 @@ def get_item(self, idx: int, val: bool=False, include_label: bool=False): verbose=verbose, allow_code_synth=allow_code_synth ) -records, execution_stats = output.run(config) +data_record_collection = output.run(config) # create filepaths for records and stats records_path = ( @@ -306,7 +306,7 @@ def get_item(self, idx: int, val: bool=False, include_label: bool=False): ) record_jsons = [] -for record in records: +for record in data_record_collection.data_records: record_dict = record.to_dict() ### field_to_keep = ["claim", "id", "label"] ### record_dict = {k: v for k, v in record_dict.items() if k in fields_to_keep} @@ -316,6 +316,6 @@ def get_item(self, idx: int, val: bool=False, include_label: bool=False): json.dump(record_jsons, f) # save statistics -execution_stats_dict = execution_stats.to_json() +execution_stats_dict = data_record_collection.execution_stats.to_json() with open(stats_path, "w") as f: json.dump(execution_stats_dict, f) diff --git a/demos/image-demo.py b/demos/image-demo.py index c93e139f..6e02bbbc 100644 --- a/demos/image-demo.py +++ b/demos/image-demo.py @@ -36,8 +36,9 @@ def build_image_plan(dataset_id): if __name__ == "__main__": # parse arguments start_time = time.time() + parser = argparse.ArgumentParser(description="Run a simple demo") - parser.add_argument("--no-cache", action="store_true", help="Do not use cached results") + parser.add_argument("--no-cache", action="store_true", help="Do not use cached results", default=True) args = parser.parse_args() no_cache = args.no_cache @@ -57,11 +58,11 @@ def build_image_plan(dataset_id): verbose=True, processing_strategy="no_sentinel" ) - records, execution_stats = plan.run(config) + data_record_collection = plan.run(config) - print("Obtained records", records) + print("Obtained records", data_record_collection.data_records) imgs, breeds = [], [] - for record in records: + for record in data_record_collection.data_records: print("Trying to open ", record.filename) path = os.path.join("testdata/images-tiny/", record.filename) img = Image.open(path).resize((128, 128)) @@ -78,7 +79,7 @@ def build_image_plan(dataset_id): with gr.Column(): breed_blocks.append(gr.Textbox(value=breed)) - plan_str = list(execution_stats.plan_strs.values())[0] + plan_str = list(data_record_collection.execution_stats.plan_strs.values())[0] gr.Textbox(value=plan_str, info="Query Plan") end_time = time.time() diff --git a/demos/optimizer-demo.py b/demos/optimizer-demo.py index 88ee1000..f84a3d84 100644 --- a/demos/optimizer-demo.py +++ b/demos/optimizer-demo.py @@ -207,6 +207,7 @@ def __init__( self.listings_dir = listings_dir self.split_idx = split_idx self.listings = sorted(os.listdir(self.listings_dir), key=lambda listing: int(listing.split("listing")[-1])) + assert len(self.listings) > split_idx, "split_idx is greater than the number of listings" self.val_listings = self.listings[:split_idx] self.listings = self.listings[split_idx:] @@ -1019,7 +1020,7 @@ def get_item(self, idx: int, val: bool = False, include_label: bool = False): verbose=verbose, ) - records, execution_stats = plan.run( + data_record_collection = plan.run( config=config, k=k, j=j, @@ -1032,6 +1033,8 @@ def get_item(self, idx: int, val: bool = False, include_label: bool = False): exp_name=exp_name ) + print(data_record_collection.to_df()) + # create filepaths for records and stats records_path = ( f"opt-profiling-data/{workload}-{exp_name}-records.json" @@ -1046,7 +1049,7 @@ def get_item(self, idx: int, val: bool = False, include_label: bool = False): # save record outputs record_jsons = [] - for record in records: + for record in data_record_collection.data_records: record_dict = record.to_dict() if workload == "biodex": record_dict = { @@ -1064,6 +1067,6 @@ def get_item(self, idx: int, val: bool = False, include_label: bool = False): json.dump(record_jsons, f) # save statistics - execution_stats_dict = execution_stats.to_json() + execution_stats_dict = data_record_collection.execution_stats.to_json() with open(stats_path, "w") as f: json.dump(execution_stats_dict, f) diff --git a/demos/paper-demo.py b/demos/paper-demo.py index c08f0062..2f72ebc3 100644 --- a/demos/paper-demo.py +++ b/demos/paper-demo.py @@ -262,12 +262,13 @@ def get_item(self, idx: int): # records, execution_stats = processor.execute() # Option 2: Use Dataset.run() to run the plan. - records, execution_stats = plan.run(config) - + data_record_collection = plan.run(config) + print(data_record_collection.to_df()) # save statistics + if profile: stats_path = f"profiling-data/{workload}-profiling.json" - execution_stats_dict = execution_stats.to_json() + execution_stats_dict = data_record_collection.execution_stats.to_json() with open(stats_path, "w") as f: json.dump(execution_stats_dict, f) @@ -275,13 +276,13 @@ def get_item(self, idx: int): if visualize: from palimpzest.utils.demo_helpers import print_table - plan_str = list(execution_stats.plan_strs.values())[-1] + plan_str = list(data_record_collection.execution_stats.plan_strs.values())[-1] if workload == "enron": - print_table(records, cols=["sender", "subject"], plan_str=plan_str) + print_table(data_record_collection.data_records, cols=["sender", "subject"], plan_str=plan_str) elif workload == "real-estate": fst_imgs, snd_imgs, thrd_imgs, addrs, prices = [], [], [], [], [] - for record in records: + for record in data_record_collection.data_records: addrs.append(record.address) prices.append(record.price) for idx, img_name in enumerate(["img1.png", "img2.png", "img3.png"]): @@ -311,7 +312,7 @@ def get_item(self, idx: int): with gr.Column(): price_blocks.append(gr.Textbox(value=price, info="Price")) - plan_str = list(execution_stats.plan_strs.values())[0] + plan_str = list(data_record_collection.execution_stats.plan_strs.values())[0] gr.Textbox(value=plan_str, info="Query Plan") demo.launch() diff --git a/quickstart.ipynb b/quickstart.ipynb index 5e530a39..1dc49164 100644 --- a/quickstart.ipynb +++ b/quickstart.ipynb @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -103,7 +103,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -133,9 +133,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset Dataset(schema=, desc=An email from the Enron dataset, filter=None, udf=None, agg_func=None, limit=None, project_cols=None, uid=06a23b1a60)\n", + "The schema of the dataset is \n" + ] + } + ], "source": [ "print(\"Dataset\", dataset)\n", "print(\"The schema of the dataset is\", dataset.schema)" @@ -155,7 +164,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -193,7 +202,7 @@ " execution_strategy=\"sequential\",\n", " optimizer_strategy=\"pareto\",\n", ")\n", - "results, execution_stats = dataset.run(config)" + "data_record_collection = dataset.run(config)" ] }, { @@ -214,8 +223,9 @@ "source": [ "import pandas as pd\n", "\n", - "output_df = pd.DataFrame([r.to_dict() for r in results])[[\"date\",\"sender\",\"subject\"]]\n", - "display(output_df)\n" + "output_df = data_record_collection.to_df(project_cols=[\"date\", \"sender\", \"subject\"])\n", + "display(output_df)\n", + "\n" ] }, { @@ -234,6 +244,7 @@ "metadata": {}, "outputs": [], "source": [ + "execution_stats = data_record_collection.execution_stats\n", "print(\"Time to find an optimal plan:\", execution_stats.total_optimization_time,\"s\")\n", "print(\"Time to execute the plan:\", execution_stats.total_execution_time, \"s\")\n", "print(\"Total cost:\", execution_stats.total_execution_cost, \"USD\")\n", @@ -267,7 +278,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.11.11" } }, "nbformat": 4, diff --git a/src/palimpzest/core/elements/records.py b/src/palimpzest/core/elements/records.py index 1dab01c6..20fed655 100644 --- a/src/palimpzest/core/elements/records.py +++ b/src/palimpzest/core/elements/records.py @@ -6,7 +6,7 @@ import pandas as pd from palimpzest.constants import FROM_DF_PREFIX -from palimpzest.core.data.dataclasses import RecordOpStats +from palimpzest.core.data.dataclasses import ExecutionStats, PlanStats, RecordOpStats from palimpzest.core.lib.fields import Field from palimpzest.core.lib.schemas import Schema from palimpzest.utils.hash_helpers import hash_for_id @@ -250,13 +250,14 @@ def from_df(df: pd.DataFrame, schema: Schema | None = None, source_id: int | str return records @staticmethod - def to_df(records: list[DataRecord], fields_in_schema: bool = False) -> pd.DataFrame: + def to_df(records: list[DataRecord], project_cols: list[str] | None = None) -> pd.DataFrame: if len(records) == 0: return pd.DataFrame() - if not fields_in_schema: - return pd.DataFrame([record.to_dict() for record in records]) - + fields = records[0].schema.field_names() + if project_cols is not None and len(project_cols) > 0: + fields = [field for field in fields if field in project_cols] + return pd.DataFrame([ {k: record[k] for k in fields} for record in records @@ -324,3 +325,45 @@ def __len__(self): def __iter__(self): yield from self.data_records + + +class DataRecordCollection: + """ + A DataRecordCollection contains a list of DataRecords. + + This is a wrapper class for list[DataRecord] to support more advanced features for output of execute(). + + The difference between DataRecordSet and DataRecordCollection + Goal: + DataRecordSet is a set of DataRecords that share the same schema, same parent_id, and same source_id. + DataRecordCollection is a general wrapper for list[DataRecord]. + + Usage: + DataRecordSet is used for the output of executing an operator. + DataRecordCollection is used for the output of executing a query, we definitely could extend it to support more advanced features for output of execute(). + """ + # TODO(Jun): consider to have stats_manager class to centralize stats management. + def __init__(self, data_records: list[DataRecord], execution_stats: ExecutionStats | None = None, plan_stats: PlanStats | None = None): + self.data_records = data_records + self.execution_stats = execution_stats + self.plan_stats = plan_stats + self.executed_plans = self._get_executed_plans() + + def __iter__(self): + """Allow iterating directly over the data records""" + yield from self.data_records + + def __len__(self): + """Return the number of records in the collection""" + return len(self.data_records) + + def to_df(self, project_cols: list[str] | None = None): + return DataRecord.to_df(self.data_records, project_cols) + + def _get_executed_plans(self): + if self.plan_stats is not None: + return [self.plan_stats.plan_str] + elif self.execution_stats is not None: + return list(self.execution_stats.plan_strs.values()) + else: + return None diff --git a/src/palimpzest/core/lib/schemas.py b/src/palimpzest/core/lib/schemas.py index 7be67b10..f33788c6 100644 --- a/src/palimpzest/core/lib/schemas.py +++ b/src/palimpzest/core/lib/schemas.py @@ -274,6 +274,49 @@ def from_df(df: pd.DataFrame) -> Schema: # Store the schema class globally globals()[schema_name] = new_schema return new_schema + + @classmethod + def add_fields(cls, fields: dict[str, str]) -> Schema: + """Add fields to the schema + + Args: + fields: Dictionary mapping field names to their descriptions + + Returns: + A new Schema with the additional fields + """ + # Construct the new schema name + schema_name = cls.class_name() + new_schema_name = f"{schema_name}Extended" + + # Construct new schema description + new_desc = f"{cls.__doc__}\nExtended with additional fields" + + # Get existing fields + new_field_names = list(cls.field_names()) + new_field_types = list(cls.field_map().values()) + new_field_descs = [field._desc for field in new_field_types] + + # TODO: Users will provide explicit descriptions for the fields, + # details in https://github.com/mitdbg/palimpzest/issues/84 + for field_name, field_desc in fields.items(): + if field_name in new_field_names: + continue + new_field_names.append(field_name) + new_field_types.append(StringField(desc=field_desc)) # Assuming StringField for new fields + new_field_descs.append(field_desc) + + # Generate the schema class dynamically + attributes = {"_desc": new_desc, "__doc__": new_desc} + for field_name, field_type, field_desc in zip(new_field_names, new_field_types, new_field_descs): + attributes[field_name] = ( + field_type.__class__(desc=str(field_desc), element_type=field_type.element_type) + if isinstance(field_type, ListField) + else field_type.__class__(desc=str(field_desc)) + ) + + # Create the class dynamically + return type(new_schema_name, (Schema,), attributes) @classmethod def class_name(cls) -> str: diff --git a/src/palimpzest/query/processor/mab_sentinel_processor.py b/src/palimpzest/query/processor/mab_sentinel_processor.py index 5f8ba658..de3602a3 100644 --- a/src/palimpzest/query/processor/mab_sentinel_processor.py +++ b/src/palimpzest/query/processor/mab_sentinel_processor.py @@ -7,7 +7,7 @@ from palimpzest.constants import PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS from palimpzest.core.data.dataclasses import ExecutionStats, OperatorStats, PlanStats, RecordOpStats -from palimpzest.core.elements.records import DataRecord, DataRecordSet +from palimpzest.core.elements.records import DataRecord, DataRecordCollection, DataRecordSet from palimpzest.core.lib.schemas import SourceRecord from palimpzest.policy import Policy from palimpzest.query.execution.parallel_execution_strategy import PipelinedParallelExecutionStrategy @@ -475,18 +475,22 @@ def pick_highest_quality_output(self, op_set_record_sets: list[tuple[DataRecordS # compute highest quality answer at each index out_records = [] + out_record_op_stats = [] for idx in range(len(idx_to_records)): records_lst, record_op_stats_lst = zip(*idx_to_records[idx]) max_quality_record, max_quality = records_lst[0], record_op_stats_lst[0].quality + max_quality_stats = record_op_stats_lst[0] for record, record_op_stats in zip(records_lst[1:], record_op_stats_lst[1:]): record_quality = record_op_stats.quality if record_quality > max_quality: max_quality_record = record max_quality = record_quality + max_quality_stats = record_op_stats out_records.append(max_quality_record) + out_record_op_stats.append(max_quality_stats) # create and return final DataRecordSet - return DataRecordSet(out_records, []) + return DataRecordSet(out_records, out_record_op_stats) def execute_op_set(self, op_candidate_pairs): @@ -782,7 +786,7 @@ def create_sentinel_plan(self, dataset: Set, policy: Policy) -> SentinelPlan: return sentinel_plan - def execute(self): + def execute(self) -> DataRecordCollection: execution_start_time = time.time() # for now, enforce that we are using validation data; we can relax this after paper submission @@ -829,7 +833,7 @@ def execute(self): plan_strs={plan_id: plan_stats.plan_str for plan_id, plan_stats in aggregate_plan_stats.items()}, ) - return all_records, execution_stats + return DataRecordCollection(all_records, execution_stats = execution_stats) diff --git a/src/palimpzest/query/processor/nosentinel_processor.py b/src/palimpzest/query/processor/nosentinel_processor.py index e795cc0c..e84c05dc 100644 --- a/src/palimpzest/query/processor/nosentinel_processor.py +++ b/src/palimpzest/query/processor/nosentinel_processor.py @@ -1,7 +1,7 @@ import time from palimpzest.core.data.dataclasses import ExecutionStats, OperatorStats, PlanStats -from palimpzest.core.elements.records import DataRecord +from palimpzest.core.elements.records import DataRecord, DataRecordCollection from palimpzest.core.lib.schemas import SourceRecord from palimpzest.query.execution.parallel_execution_strategy import PipelinedParallelExecutionStrategy from palimpzest.query.execution.single_threaded_execution_strategy import ( @@ -15,7 +15,7 @@ from palimpzest.query.optimizer.plan import PhysicalPlan from palimpzest.query.processor.query_processor import QueryProcessor from palimpzest.utils.progress import create_progress_manager - + class NoSentinelQueryProcessor(QueryProcessor): """ @@ -24,7 +24,7 @@ class NoSentinelQueryProcessor(QueryProcessor): """ # TODO: Consider to support dry_run. - def execute(self): + def execute(self) -> DataRecordCollection: execution_start_time = time.time() # if nocache is True, make sure we do not re-use codegen examples @@ -48,7 +48,7 @@ def execute(self): plan_strs={plan_id: plan_stats.plan_str for plan_id, plan_stats in aggregate_plan_stats.items()}, ) - return records, execution_stats + return DataRecordCollection(records, execution_stats=execution_stats) class NoSentinelSequentialSingleThreadProcessor(NoSentinelQueryProcessor, SequentialSingleThreadExecutionStrategy): diff --git a/src/palimpzest/query/processor/query_processor.py b/src/palimpzest/query/processor/query_processor.py index 91cddbb0..d68d11c0 100644 --- a/src/palimpzest/query/processor/query_processor.py +++ b/src/palimpzest/query/processor/query_processor.py @@ -3,7 +3,7 @@ from palimpzest.core.data.dataclasses import PlanStats, RecordOpStats from palimpzest.core.data.datasources import DataSource, ValidationDataSource -from palimpzest.core.elements.records import DataRecord +from palimpzest.core.elements.records import DataRecord, DataRecordCollection from palimpzest.datamanager.datamanager import DataDirectory from palimpzest.policy import Policy from palimpzest.query.optimizer.cost_model import CostModel @@ -271,5 +271,5 @@ def _execute_confidence_interval_strategy( # TODO: consider to support dry_run. @abstractmethod - def execute(self): - raise NotImplementedError("Abstract method to be overwritten by sub-classes") + def execute(self) -> DataRecordCollection: + raise NotImplementedError("Abstract method to be overwritten by sub-classes") \ No newline at end of file diff --git a/src/palimpzest/query/processor/query_processor_factory.py b/src/palimpzest/query/processor/query_processor_factory.py index 1b5f5079..5ad37711 100644 --- a/src/palimpzest/query/processor/query_processor_factory.py +++ b/src/palimpzest/query/processor/query_processor_factory.py @@ -1,5 +1,6 @@ from enum import Enum +from palimpzest.core.elements.records import DataRecordCollection from palimpzest.query.execution.execution_strategy import ExecutionStrategyType from palimpzest.query.optimizer.cost_model import CostModel from palimpzest.query.optimizer.optimizer import Optimizer @@ -95,7 +96,7 @@ def create_processor( return processor_cls(dataset=dataset, optimizer=optimizer, config=config, **kwargs) @classmethod - def create_and_run_processor(cls, dataset: Dataset, config: QueryProcessorConfig, **kwargs): + def create_and_run_processor(cls, dataset: Dataset, config: QueryProcessorConfig, **kwargs) -> DataRecordCollection: # TODO(Jun): Consider to use cache here. processor = cls.create_processor(dataset=dataset, config=config, **kwargs) return processor.execute() diff --git a/src/palimpzest/query/processor/random_sampling_sentinel_processor.py b/src/palimpzest/query/processor/random_sampling_sentinel_processor.py index e04a7802..bf8ffc48 100644 --- a/src/palimpzest/query/processor/random_sampling_sentinel_processor.py +++ b/src/palimpzest/query/processor/random_sampling_sentinel_processor.py @@ -8,7 +8,7 @@ from palimpzest.constants import PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS from palimpzest.core.data.dataclasses import ExecutionStats, OperatorStats, PlanStats, RecordOpStats from palimpzest.core.data.datasources import ValidationDataSource -from palimpzest.core.elements.records import DataRecord, DataRecordSet +from palimpzest.core.elements.records import DataRecord, DataRecordCollection, DataRecordSet from palimpzest.core.lib.schemas import SourceRecord from palimpzest.policy import Policy from palimpzest.query.execution.parallel_execution_strategy import PipelinedParallelExecutionStrategy @@ -525,7 +525,7 @@ def create_sentinel_plan(self, dataset: Set, policy: Policy) -> SentinelPlan: return sentinel_plan - def execute(self): + def execute(self) -> DataRecordCollection: execution_start_time = time.time() # for now, enforce that we are using validation data; we can relax this after paper submission @@ -570,7 +570,7 @@ def execute(self): plan_strs={plan_id: plan_stats.plan_str for plan_id, plan_stats in aggregate_plan_stats.items()}, ) - return all_records, execution_stats + return DataRecordCollection(all_records, execution_stats=execution_stats) class RandomSamplingSentinelSequentialSingleThreadProcessor(RandomSamplingSentinelQueryProcessor, SequentialSingleThreadExecutionStrategy): diff --git a/src/palimpzest/query/processor/streaming_processor.py b/src/palimpzest/query/processor/streaming_processor.py index 793eb6d5..1de0453f 100644 --- a/src/palimpzest/query/processor/streaming_processor.py +++ b/src/palimpzest/query/processor/streaming_processor.py @@ -1,7 +1,7 @@ import time from palimpzest.core.data.dataclasses import OperatorStats, PlanStats -from palimpzest.core.elements.records import DataRecord +from palimpzest.core.elements.records import DataRecord, DataRecordCollection from palimpzest.core.lib.schemas import SourceRecord from palimpzest.policy import Policy from palimpzest.query.operators.aggregate import AggregateOp @@ -85,8 +85,8 @@ def execute(self): if idx == len(input_records) - 1: total_plan_time = time.time() - start_time self.plan_stats.finalize(total_plan_time) - - yield output_records, self.plan, self.plan_stats + self.plan_stats.plan_str = str(self.plan) + yield DataRecordCollection(output_records, plan_stats=self.plan_stats) def get_input_records(self): scan_operator = self.plan.operators[0] diff --git a/src/palimpzest/sets.py b/src/palimpzest/sets.py index b83ead07..fbf67108 100644 --- a/src/palimpzest/sets.py +++ b/src/palimpzest/sets.py @@ -139,11 +139,11 @@ class Dataset(Set): def __init__(self, source: str | list | pd.DataFrame | DataSource, schema: Schema | None = None, *args, **kwargs): # convert source (str) -> source (DataSource) if need be - source = DataDirectory().get_or_register_dataset(source) if isinstance(source, (str, list, pd.DataFrame)) else source + updated_source = DataDirectory().get_or_register_dataset(source) if isinstance(source, (str, list, pd.DataFrame)) else source if schema is None: schema = Schema.from_df(source) if isinstance(source, pd.DataFrame) else DefaultSchema # intialize class - super().__init__(source, schema, *args, **kwargs) + super().__init__(updated_source, schema, *args, **kwargs) def copy(self) -> Dataset: source_copy = self._source.copy() @@ -212,6 +212,11 @@ def convert( desc=desc, nocache=self._nocache, ) + + # This is a convenience for users who like DataFrames-like syntax. + def add_columns(self, columns:dict[str, str], cardinality: Cardinality = Cardinality.ONE_TO_ONE) -> Dataset: + new_output_schema = self.schema.add_fields(columns) + return self.convert(new_output_schema, udf=None, cardinality=cardinality, depends_on=None, desc="Add columns " + str(columns)) def count(self) -> Dataset: """Apply a count aggregation to this set""" diff --git a/tests/pytest/test_dynamicschema.py b/tests/pytest/test_dynamicschema.py index de798a0a..5c6e67b6 100644 --- a/tests/pytest/test_dynamicschema.py +++ b/tests/pytest/test_dynamicschema.py @@ -42,9 +42,9 @@ def test_dynamicschema_json(mocker, enron_workload, enron_convert, enron_filter) execution_strategy="sequential", optimizer_strategy="pareto", ) - records, stats = enron_workload.run(config=config) + data_record_collection = enron_workload.run(config=config) - for rec in records: + for rec in data_record_collection.data_records: print(rec.to_dict()) @@ -71,7 +71,7 @@ def test_dynamicschema_yml(mocker, enron_workload, enron_convert, enron_filter): execution_strategy="sequential", optimizer_strategy="pareto", ) - records, stats = enron_workload.run(config=config) + data_record_collection = enron_workload.run(config=config) - for rec in records: + for rec in data_record_collection.data_records: print(rec.to_dict()) diff --git a/tests/pytest/test_records.py b/tests/pytest/test_records.py index c4113027..4549c9f7 100644 --- a/tests/pytest/test_records.py +++ b/tests/pytest/test_records.py @@ -53,7 +53,13 @@ def test_to_df(self, sample_df): """Test converting records back to DataFrame""" records = DataRecord.from_df(sample_df, schema=TestSchema) df_result = DataRecord.to_df(records) - pd.testing.assert_frame_equal(df_result, sample_df) + assert df_result.equals(sample_df) + + def test_to_df_with_project_cols(self, sample_df): + """Test converting records to DataFrame with project_cols""" + records = DataRecord.from_df(sample_df, schema=TestSchema) + df_result = DataRecord.to_df(records, project_cols=["name"]) + assert df_result.equals(sample_df[["name"]]) def test_derived_schema(self, sample_df): """Test auto-schema generation from DataFrame""" diff --git a/tests/pytest/test_schemas.py b/tests/pytest/test_schemas.py index 30e70e7b..8dff02f9 100644 --- a/tests/pytest/test_schemas.py +++ b/tests/pytest/test_schemas.py @@ -13,3 +13,19 @@ class Cat(Schema): def test_schema_equality(): assert Dog == Dog assert Dog != Cat + + +def test_schema_add_fields(): + dog_extended = Dog.add_fields({"color": "The color of the dog"}) + assert sorted(dog_extended.field_names()) == ["breed", "color", "is_good"] + assert dog_extended.field_map()["color"] == StringField(desc="The color of the dog") + + # Add the same field again, should be skipped + dog_extended2 = dog_extended.add_fields({"color": "The color of the dog"}) + assert sorted(dog_extended2.field_names()) == ["breed", "color", "is_good"] + assert dog_extended2.field_map()["color"] == StringField(desc="The color of the dog") + +def test_schema_add_fields_with_existing_fields(): + dog_extended = Dog.add_fields({"breed": "The breed of the dog"}) + assert sorted(dog_extended.field_names()) == ["breed", "is_good"] + assert dog_extended.field_map()["breed"] == StringField(desc="The breed of the dog") \ No newline at end of file