The PyTorch implementation of the NeurIPS 2021 paper SADGA: Structure-Aware Dual Graph Aggregation Network for Text-to-SQL.
If you use SADGA in your work, please cite it as follows:
@article{cai2021sadga,
title={SADGA: Structure-Aware Dual Graph Aggregation Network for Text-to-SQL},
author={Cai, Ruichu and Yuan, Jinjie and Xu, Boyan and Hao, Zhifeng},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
mkdir -p dataset third_party plm/gap
Download the Spider dataset, then unzip spider.zip into the directory dataset:
└── dataset
├── database
│ ├── academic
│ │ ├── academic.sqlite
│ │ └── schema.sql
│ ├── ...
├── dev_gold.sql
├── dev.json
├── README.txt
├── tables.json
├── train_gold.sql
├── train_others.json
└── train_spider.json
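As a quick sanity check (not part of the repo), a small Python helper can verify that the unzipped files landed in the expected places; the entry names below mirror the tree above:

```python
import os

# Top-level entries expected under dataset/ after unzipping Spider
# (mirrors the directory tree shown above).
EXPECTED = ["database", "dev_gold.sql", "dev.json", "README.txt",
            "tables.json", "train_gold.sql", "train_others.json",
            "train_spider.json"]

def check_spider_layout(root="dataset"):
    """Return the expected Spider entries that are missing under `root`."""
    return [name for name in EXPECTED
            if not os.path.exists(os.path.join(root, name))]

# An empty return list means the layout matches the tree above:
# missing = check_spider_layout("dataset")
```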
Download and unzip Stanford CoreNLP into the directory third_party. Note that Stanford CoreNLP requires a JVM to run.
└── third_party
└── stanford-corenlp-full-2018-10-05
├── ...
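Because CoreNLP needs a JVM, a quick check (an illustrative helper, not part of the repo) can catch a missing `java` binary or CoreNLP directory before preprocessing:

```python
import os
import shutil

def corenlp_ready(third_party="third_party"):
    """Check that the CoreNLP directory exists and a JVM (`java`) is on PATH."""
    corenlp_dir = os.path.join(third_party, "stanford-corenlp-full-2018-10-05")
    return os.path.isdir(corenlp_dir) and shutil.which("java") is not None
```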
To use the pretrained model GAP (Learning Contextual Representations for Semantic Parsing with Generation-Augmented Pre-Training), download the pretrained model from pretrained-checkpoint into the directory plm/gap:
└── plm
└── gap
└── pretrained-checkpoint
We trained our models on a server with a single NVIDIA RTX 3090 GPU (24GB GPU memory). In our experiments, we use Python 3.7 and torch 1.7.1 with CUDA 11.0. Create the conda environment sadgasql:
conda create -n sadgasql python=3.7
source activate sadgasql
pip install torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
python -c "import nltk; nltk.download('stopwords'); nltk.download('punkt')"
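After installation, the installed torch build can be checked against the expected CUDA version by inspecting `torch.__version__` (e.g. `1.7.1+cu110`); a tiny illustrative helper, not part of the repo:

```python
def cuda_tag(version):
    """Extract the CUDA suffix from a torch version string.

    e.g. "1.7.1+cu110" -> "cu110"; returns None for a CPU-only build.
    """
    return version.split("+", 1)[1] if "+" in version else None

# Usage (assumes torch is installed):
# import torch
# cuda_tag(torch.__version__)  # should be "cu110" for this setup
```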
All configs of the experiments and models are in the files sadga-glove-run.jsonnet, sadga-bert-run.jsonnet, and sadga-gap-run.jsonnet.
python run.py --mode preprocess --config sadga-[glove|bert|gap]-run.jsonnet
- Our preprocessed dataset can be downloaded here (5.0M).
python run.py --mode train --config sadga-[glove|bert|gap]-run.jsonnet
- After training, model checkpoints are saved in the directory {logdir}/{model_name}/, e.g., logdir/sadga_glove_bs=20_lr=7.4e-04/model_checkpoint-00020100.
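Checkpoint and result filenames encode the training step (e.g., model_checkpoint-00020100 corresponds to step 20100); a hypothetical helper (not part of the repo) to recover it:

```python
import re

def step_of(filename):
    """Extract the training step from a checkpoint or result filename.

    e.g. "model_checkpoint-00020100" -> 20100,
         "predict_sql_step20100.txt" -> 20100.
    Returns None if no trailing step number is found.
    """
    m = re.search(r"(\d+)(?:\.txt)?$", filename)
    return int(m.group(1)) if m else None
```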
python run.py --mode infer --config sadga-[glove|bert|gap]-run.jsonnet
- The inference phase outputs the predicted SQL file predict_sql_step{xxx}.txt (in the same format as the official Spider evaluation input) to the directory {logdir}/{model_name}/{res_dir} for each saved model, e.g., logdir/sadga_glove_bs=20_lr=7.4e-04/res/predict_sql_step20100.txt.
python run.py --mode eval --config sadga-[glove|bert|gap]-run.jsonnet
- Spider's official evaluation. It produces a detailed accuracy result file acc_res_step{xxx}.txt for each saved model, e.g., logdir/sadga_glove_bs=20_lr=7.4e-04/res/acc_res_step20100.txt, and prints the results of all inferred steps as:
STEP ACCURACY
10100 0.544
11100 0.560
...
40000 0.652
Best Result:
38100 0.656
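The printed step/accuracy table can also be parsed programmatically; a small illustrative parser (not part of the repo) that picks the best step:

```python
def best_step(lines):
    """Parse 'STEP ACCURACY' rows and return the (step, accuracy)
    pair with the highest accuracy, skipping non-numeric rows."""
    rows = []
    for line in lines:
        parts = line.split()
        if len(parts) == 2 and parts[0].isdigit():
            rows.append((int(parts[0]), float(parts[1])))
    return max(rows, key=lambda r: r[1])

# Example with the output shown above:
# best_step(["STEP ACCURACY", "10100 0.544", "38100 0.656"])
```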
Our best trained checkpoints, predicted SQL files, and accuracy result files can be downloaded here:
[logdir.zip] (GloVe: 65.6) [logdir.zip] (Bert-large: 71.6) [logdir.zip] (GAP: 73.9)
| Model | Exact Match Acc (Dev) | Exact Match Acc (Test) |
|---|---|---|
| SADGA + GloVe | 64.7 (65.6 this repo) | - |
| SADGA + Bert-large | 71.6 | 66.7 |
| SADGA + GAP | 73.1 (73.9 this repo) | 70.1 |
Detailed results can be found in the paper. Note that the Spider test set has not been publicly released; results on the test set are available only by submitting the model to the official evaluation.
This implementation is based on the ACL 2020 paper RAT-SQL: Relation-Aware Schema Encoding and Linking for Text-to-SQL Parsers (code). We thank the authors for their open-source project, Tao Yu and Yusen Zhang for evaluating our work on the Spider Challenge, and the anonymous reviewers for their helpful comments.