From 90062f61338ec572a1c440b42c9ad3fb9d46ebc9 Mon Sep 17 00:00:00 2001 From: John Bodley Date: Sat, 25 Mar 2017 12:10:31 -0700 Subject: [PATCH] [gather_columns] Traversing function calls --- examples/gather_columns.py | 25 +++++++++++++++++++------ test/test_presto_tests.py | 2 +- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/examples/gather_columns.py b/examples/gather_columns.py index 9e98021..85fb617 100644 --- a/examples/gather_columns.py +++ b/examples/gather_columns.py @@ -16,8 +16,9 @@ from lacquer import parser from lacquer.tree import AliasedRelation -from lacquer.tree import Join from lacquer.tree import DefaultTraversalVisitor +from lacquer.tree import FunctionCall +from lacquer.tree import Join from lacquer.tree import QualifiedNameReference from lacquer.tree import SingleColumn from lacquer.tree import Table @@ -48,13 +49,25 @@ def visit_query_specification(self, node, context): self.tables.append(node.from_) self.tables.reverse() + def get_all_qualified(expression, qualified=None): + if qualified is None: + qualified = [] + + if isinstance(expression, QualifiedNameReference): + qualified.append(expression) + elif isinstance(expression, FunctionCall): + for argument in expression.arguments: + get_all_qualified(argument, qualified) + + return qualified + def print_column_resolution_order(columns, tables): table_columns = [] tables_and_aliases = OrderedDict() for i in range(len(columns)): column = columns[i] - if isinstance(column.expression, QualifiedNameReference): - table_columns.append((column, i)) + for qualified in get_all_qualified(column.expression): + table_columns.append((qualified, i)) for table in tables: if isinstance(table, AliasedRelation): @@ -67,8 +80,7 @@ def print_column_resolution_order(columns, tables): print("\nTable Column Resolution:") for (column, position) in table_columns: - names = column.expression.name.parts - column_name = names[-1] + names = column.name.parts resolution = [] if len(names) > 1: qualified_table_name = ".".join(names[:-1]) @@ -114,7 +126,8 @@ def visit_subquery_expression(self, node, context): check_extracted_columns("select (select 1 from foo), a " "from c join d using(foo) join e using (bar)") check_extracted_columns("select 1, 20+a from c join d using(foo) join e using (bar)") - + check_extracted_columns("select sum(foo) from a", True) + check_extracted_columns("select concat(concat(foo)) from a", True) print("Running subquery checkers\n\n") check_has_subquery("select a from b") check_has_subquery("select a from (select a from b)") diff --git a/test/test_presto_tests.py b/test/test_presto_tests.py index 8286f57..9c0ed3c 100644 --- a/test/test_presto_tests.py +++ b/test/test_presto_tests.py @@ -332,4 +332,4 @@ def select_list_with_items(*args): def simple_query(select, from_=None): - return Query(query_body=QuerySpecification(select=select, from_=from_)) \ No newline at end of file + return Query(query_body=QuerySpecification(select=select, from_=from_))