Skip to content

YawenTsou/Open-IE-stock-prediction

Folders and files

Name
Last commit message
Last commit date

Latest commit

 

History

12 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Event Extraction (SVO)\n",
    "#### parameters: news(.pkl) / output_path(.pkl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>title</th>\n",
       "      <th>date</th>\n",
       "      <th>source</th>\n",
       "      <th>document_body</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>Railcars in North Dakota crude train crash old...</td>\n",
       "      <td>20140101</td>\n",
       "      <td>Reuters News</td>\n",
       "      <td>* Railcars don't meet industry's latest safety...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>S.Korea Dec crude oil imports down 8.4 pct y/y...</td>\n",
       "      <td>20140101</td>\n",
       "      <td>Reuters News</td>\n",
       "      <td>SEOUL, Jan 1 (Reuters) - South Korea's crude o...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>Doctors, hospitals expect some confusion as Ob...</td>\n",
       "      <td>20140101</td>\n",
       "      <td>Reuters News</td>\n",
       "      <td>Jan 1 (Reuters) - Hospitals and medical practi...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>Justice Blocks Contraception Mandate on Insura...</td>\n",
       "      <td>20140101</td>\n",
       "      <td>The New York Times</td>\n",
       "      <td>WASHINGTON -- Justice Sonia Sotomayor on Tuesd...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>For Stocks, an Amazingly Good Year</td>\n",
       "      <td>20140101</td>\n",
       "      <td>The New York Times</td>\n",
       "      <td>It was the market rally that defied gravity an...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                                              title      date  \\\n",
       "0   0  Railcars in North Dakota crude train crash old...  20140101   \n",
       "1   1  S.Korea Dec crude oil imports down 8.4 pct y/y...  20140101   \n",
       "2   2  Doctors, hospitals expect some confusion as Ob...  20140101   \n",
       "3   3  Justice Blocks Contraception Mandate on Insura...  20140101   \n",
       "4   4                 For Stocks, an Amazingly Good Year  20140101   \n",
       "\n",
       "               source                                      document_body  \n",
       "0        Reuters News  * Railcars don't meet industry's latest safety...  \n",
       "1        Reuters News  SEOUL, Jan 1 (Reuters) - South Korea's crude o...  \n",
       "2        Reuters News  Jan 1 (Reuters) - Hospitals and medical practi...  \n",
       "3  The New York Times  WASHINGTON -- Justice Sonia Sotomayor on Tuesd...  \n",
       "4  The New York Times  It was the market rally that defied gravity an...  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# input\n",
    "import pickle\n",
    "with open('./dataset/test.pkl', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████| 2/2 [00:00<00:00, 1193.77it/s]\n",
      "100%|█████████████████████████████████████████████| 2/2 [00:28<00:00, 14.29s/it]\n",
      "100%|█████████████████████████████████████████████| 2/2 [00:29<00:00, 14.71s/it]\n",
      "100%|█████████████████████████████████████████████| 2/2 [00:29<00:00, 14.70s/it]\n",
      "100%|█████████████████████████████████████████████| 2/2 [00:29<00:00, 14.74s/it]\n",
      "100%|█████████████████████████████████████████████| 2/2 [00:29<00:00, 14.80s/it]\n",
      "100%|█████████████████████████████████████████████| 2/2 [00:29<00:00, 14.81s/it]\n",
      "100%|█████████████████████████████████████████████| 6/6 [00:30<00:00,  5.16s/it]\n",
      "100%|███████████████████████████████████████████| 4/4 [00:00<00:00, 7093.96it/s]\n",
      "100%|███████████████████████████████████████████| 4/4 [00:00<00:00, 3444.31it/s]\n",
      "100%|██████████████████████████████████████████| 4/4 [00:00<00:00, 13025.79it/s]\n",
      "100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 18.21it/s]\n"
     ]
    }
   ],
   "source": [
    "!python3 Event_Extraction.py ./dataset/test.pkl ./dataset/integrate_SVO.pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'id': 0,\n",
       " 'date': '20140101',\n",
       " 'title': 'Railcars in North Dakota crude train crash older, less safe -investigators',\n",
       " 'title_SVO': defaultdict(list,\n",
       "             {'main': [{'subject': [('Railcars',\n",
       "                  [defaultdict(list,\n",
       "                               {'predicate': [('in', [])],\n",
       "                                'object': [('North Dakota crude train crash',\n",
       "                                  [])]})])],\n",
       "                'predicate': [],\n",
       "                'object': []}]}),\n",
       " 'integrate_SVO': [[('Railcars', ''),\n",
       "   ('in', ''),\n",
       "   ('North Dakota crude train crash', '')]]}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# output\n",
    "with open('./dataset/integrate_SVO.pkl', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert to Token (Word_Embedding)\n",
    "#### parameters: integrate_SVO(.pkl) / output_path(.pkl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████| 2/2 [00:00<00:00, 5841.65it/s]\n",
      "0it [00:00, ?it/s]                                        | 0/2 [00:00<?, ?it/s]\n",
      "0it [00:00, ?it/s]\n",
      "0it [00:00, ?it/s]\n",
      "100%|████████████████████████████████████████████| 2/2 [00:00<00:00, 199.42it/s]\n",
      "100%|█████████████████████████████████████████████| 2/2 [00:02<00:00,  1.35s/it]\n"
     ]
    }
   ],
   "source": [
    "!python3 Word_Embedding.py ./dataset/integrate_SVO.pkl ./dataset/news_embedding.pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'date': '20140101',\n",
       " 'SVO': [tensor([[ 0.1677,  0.5718,  0.4815,  ..., -0.0256,  0.6719,  0.4798],\n",
       "          [ 0.3100,  0.7907,  0.0308,  ...,  0.0147,  0.3137, -0.0888],\n",
       "          [-0.4872,  0.1933,  0.1532,  ...,  0.0936,  0.2514,  0.2345]]),\n",
       "  tensor([[-0.3281,  0.1918,  0.2351,  ...,  0.2562,  0.4863,  0.1667],\n",
       "          [-0.2758, -0.2907,  0.3800,  ..., -0.0280,  0.5327, -0.3334],\n",
       "          [-0.4366, -0.0357,  0.0170,  ...,  0.3978,  0.2332,  0.1756]]),\n",
       "  tensor([[-0.2246,  0.3442,  0.3954,  ...,  0.0136,  0.6432,  0.4880],\n",
       "          [-0.7752, -0.5799,  0.0325,  ...,  0.0741, -0.3688,  0.1206],\n",
       "          [-0.0034,  0.1318, -0.0187,  ..., -0.0985,  0.8148, -0.0674]]),\n",
       "  tensor([[ 0.2384,  0.5718,  0.6301,  ...,  0.4288,  0.2098,  0.3832],\n",
       "          [-0.3787, -0.0049, -0.3886,  ..., -0.1581,  0.3159, -0.0910],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]),\n",
       "  tensor([[-0.4717,  0.1754,  0.1051,  ..., -0.0825,  0.1905, -0.0989],\n",
       "          [ 0.2616,  0.6423,  0.3200,  ..., -0.1873,  0.2177,  0.3495],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]),\n",
       "  tensor([[ 0.3165,  0.3254,  0.5185,  ...,  0.4871,  1.0901, -0.1416],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])]}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# output\n",
    "# Each Event: tensor(3, 768)\n",
    "with open('./dataset/news_embedding.pkl', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Event Embedding (AutoEncoder)\n",
    "#### parameters: news_embedding(.pkl) / AutoEncoder model(.pth) / output_path(.pkl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.36it/s]\n"
     ]
    }
   ],
   "source": [
    "!python3 Event_Embedding.py ./dataset/news_embedding.pkl autoencoder_model.pth ./dataset/event_embedding.pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['20140101', '20140102'])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# output\n",
    "# dictionary - key: date / value: array(768)\n",
    "with open('./dataset/event_embedding.pkl', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AutoEncoder Training\n",
    "#### parameters: news_embedding(.pkl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1it [00:01,  1.11s/it]\n",
      "Epoch 000: Loss : 0.50690\n",
      "model saved to autoencoder_1.pth\n"
     ]
    }
   ],
   "source": [
    "!python3 AutoEncoder_train.py ./dataset/news_embedding.pkl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stock Prediction (LSTM)\n",
    "#### parameters: price(.csv) / event_embedding(.pkl) / LSTM model(.pth) / output_path(.csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 140/140 [00:00<00:00, 1949.46it/s]\n",
      "acc:  0.6928571428571428\n",
      "F1:  0.6323092677931388\n",
      "precision 0.6988095238095239\n",
      "recall 0.6928571428571428\n"
     ]
    }
   ],
   "source": [
    "!python3 LSTM_test.py ./dataset/price.csv ./dataset/Event_embedding.pkl './model/LSTM(3day_test).pth' ./result/test.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>date</th>\n",
       "      <th>true</th>\n",
       "      <th>predict</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>20190506</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>20190507</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>20190508</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>20190509</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>20190510</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       date  true  predict\n",
       "0  20190506     1        0\n",
       "1  20190507     0        0\n",
       "2  20190508     0        0\n",
       "3  20190509     0        0\n",
       "4  20190510     0        0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "data = pd.read_csv('./result/test.csv')\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSTM Training\n",
    "#### parameters: price(.csv) / event_embedding(.pkl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████| 164/164 [00:00<00:00, 635.63it/s]\n",
      "Epoch: 1, train Loss: 0.9098, train accuracy: 0.4069\n",
      "Epoch: 1, valid loss: 0.7837, valid accuracy: 0.1875\n",
      "Epoch: 2, train Loss: 0.7055, train accuracy: 0.6033\n",
      "Epoch: 2, valid loss: 0.8202, valid accuracy: 0.1875\n"
     ]
    }
   ],
   "source": [
    "!python3 LSTM_train.py ./dataset/price.csv ./dataset/Event_embedding.pkl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Market Simulation\n",
    "#### parameters: price(.csv) / predict_ans(.csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Return:  1.1105337761774103\r\n"
     ]
    }
   ],
   "source": [
    "!python3 Market_Simulation.py ./dataset/price.csv ./result/test.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Baseline Word Embedding\n",
    "#### parameters: news(.pkl) / output_path(.pkl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████| 2/2 [00:00<00:00, 14217.98it/s]\n",
      "0it [00:00, ?it/s]\n",
      "0it [00:00, ?it/s]                                        | 0/2 [00:00<?, ?it/s]\n",
      "0it [00:00, ?it/s]\n",
      "100%|████████████████████████████████████████████| 2/2 [00:00<00:00, 161.91it/s]\n",
      "100%|█████████████████████████████████████████████| 2/2 [00:01<00:00,  1.27it/s]\n"
     ]
    }
   ],
   "source": [
    "!python3 Baseline_preprocess.py ./dataset/test.pkl ./dataset/baseline_test.pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['20140102', '20140101'])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# output\n",
    "# dictionary - key: date / value: array(768)\n",
    "with open('./dataset/baseline_test.pkl', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RTF to Dataframe\n",
    "#### parameters: list(input_path) / output_path(.pkl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "paths = []\n",
    "!python3 RTF_to_Dataframe.py paths ./dataset/news.pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}