Skip to content

Commit

Permalink
[gather_columns] Traversing function calls
Browse files Browse the repository at this point in the history
  • Loading branch information
John Bodley committed Mar 26, 2017
1 parent 33f6083 commit 90062f6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
25 changes: 19 additions & 6 deletions examples/gather_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

from lacquer import parser
from lacquer.tree import AliasedRelation
from lacquer.tree import Join
from lacquer.tree import DefaultTraversalVisitor
from lacquer.tree import FunctionCall
from lacquer.tree import Join
from lacquer.tree import QualifiedNameReference
from lacquer.tree import SingleColumn
from lacquer.tree import Table
Expand Down Expand Up @@ -48,13 +49,25 @@ def visit_query_specification(self, node, context):
self.tables.append(node.from_)
self.tables.reverse()

def get_all_qualified(expression, qualified=None):
if qualified is None:
qualified = []

if isinstance(expression, QualifiedNameReference):
qualified.append(expression)
elif isinstance(expression, FunctionCall):
for argument in expression.arguments:
get_all_qualified(argument, qualified)

return qualified

def print_column_resolution_order(columns, tables):
table_columns = []
tables_and_aliases = OrderedDict()
for i in range(len(columns)):
column = columns[i]
if isinstance(column.expression, QualifiedNameReference):
table_columns.append((column, i))
for qualified in get_all_qualified(column.expression):
table_columns.append((qualified, i))

for table in tables:
if isinstance(table, AliasedRelation):
Expand All @@ -67,8 +80,7 @@ def print_column_resolution_order(columns, tables):

print("\nTable Column Resolution:")
for (column, position) in table_columns:
names = column.expression.name.parts
column_name = names[-1]
names = column.name.parts
resolution = []
if len(names) > 1:
qualified_table_name = ".".join(names[:-1])
Expand Down Expand Up @@ -114,7 +126,8 @@ def visit_subquery_expression(self, node, context):
check_extracted_columns("select (select 1 from foo), a "
"from c join d using(foo) join e using (bar)")
check_extracted_columns("select 1, 20+a from c join d using(foo) join e using (bar)")

check_extracted_columns("select sum(foo) from a", True)
check_extracted_columns("select concat(concat(foo)) from a", True)
print("Running subquery checkers\n\n")
check_has_subquery("select a from b")
check_has_subquery("select a from (select a from b)")
Expand Down
2 changes: 1 addition & 1 deletion test/test_presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,4 +332,4 @@ def select_list_with_items(*args):


def simple_query(select, from_=None):
return Query(query_body=QuerySpecification(select=select, from_=from_))
return Query(query_body=QuerySpecification(select=select, from_=from_))

0 comments on commit 90062f6

Please sign in to comment.