Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] Add OKT model #31

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions EduKTM/OKT/OKT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import math
import logging
import torch
import torch.nn as nn
import numpy as np
import tqdm
from sklearn import metrics
from scipy.stats import pearsonr

from EduKTM import KTM
from .OKTNet import OKTNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def binary_entropy(target, pred):
loss = target * np.log(np.maximum(1e-10, pred)) + (1.0 - target) * np.log(np.maximum(1e-10, 1.0 - pred))
return np.average(loss) * -1.0


def compute_auc(all_target, all_pred):
return metrics.roc_auc_score(all_target, all_pred)


def compute_accuracy(all_target, all_pred):
all_pred = all_pred.copy()
all_pred[all_pred > 0.5] = 1.0
all_pred[all_pred <= 0.5] = 0.0
return metrics.accuracy_score(all_target, all_pred)


def compute_rmse(all_target, all_pred):
return np.sqrt(metrics.mean_squared_error(all_target, all_pred))


def compute_r2(all_target, all_pred):
return np.power(pearsonr(all_target, all_pred)[0], 2)


def train_one_epoch(net, optimizer, criterion, batch_size, q_data, a_data, e_data, it_data, at_data=None):
net.train()
n = int(math.ceil(len(e_data) / batch_size))
shuffled_ind = np.arange(e_data.shape[0])
np.random.shuffle(shuffled_ind)
q_data = q_data[shuffled_ind]
e_data = e_data[shuffled_ind]
if at_data is not None:
at_data = at_data[shuffled_ind]
a_data = a_data[shuffled_ind]
it_data = it_data[shuffled_ind]

pred_list = []
target_list = []
for idx in tqdm.tqdm(range(n), 'Training'):
optimizer.zero_grad()

q_one_seq = q_data[idx * batch_size: (idx + 1) * batch_size, :]
e_one_seq = e_data[idx * batch_size: (idx + 1) * batch_size, :]
a_one_seq = a_data[idx * batch_size: (idx + 1) * batch_size, :]
it_one_seq = it_data[idx * batch_size: (idx + 1) * batch_size, :]

input_q = torch.from_numpy(q_one_seq).long().to(device)
input_e = torch.from_numpy(e_one_seq).long().to(device)
input_a = torch.from_numpy(a_one_seq).long().to(device)
input_it = torch.from_numpy(it_one_seq).long().to(device)
target = torch.from_numpy(a_one_seq).float().to(device)

input_at = None
if at_data is not None:
at_one_seq = at_data[idx * batch_size: (idx + 1) * batch_size, :]
input_at = torch.from_numpy(at_one_seq).long().to(device)

pred = net(input_q, input_a, input_e, input_it, input_at)

mask = input_e[:, 1:] > 0
masked_pred = pred[:, 1:][mask]
masked_truth = target[:, 1:][mask]

loss = criterion(masked_pred, masked_truth)

loss.backward()

nn.utils.clip_grad_norm_(net.parameters(), max_norm=10)
optimizer.step()

masked_pred = masked_pred.detach().cpu().numpy()
masked_truth = masked_truth.detach().cpu().numpy()
pred_list.append(masked_pred)
target_list.append(masked_truth)

all_pred = np.concatenate(pred_list, axis=0)
all_target = np.concatenate(target_list, axis=0)

loss = binary_entropy(all_target, all_pred)
r2 = compute_r2(all_target, all_pred)
auc = compute_auc(all_target, all_pred)
accuracy = compute_accuracy(all_target, all_pred)

return loss, r2, auc, accuracy


def test_one_epoch(net, batch_size, q_data, a_data, e_data, it_data, at_data=None):
net.eval()
n = int(math.ceil(len(e_data) / batch_size))

pred_list = []
target_list = []
mask_list = []

for idx in tqdm.tqdm(range(n), 'Testing'):
q_one_seq = q_data[idx * batch_size: (idx + 1) * batch_size, :]
e_one_seq = e_data[idx * batch_size: (idx + 1) * batch_size, :]
a_one_seq = a_data[idx * batch_size: (idx + 1) * batch_size, :]
it_one_seq = it_data[idx * batch_size: (idx + 1) * batch_size, :]

input_q = torch.from_numpy(q_one_seq).long().to(device)
input_e = torch.from_numpy(e_one_seq).long().to(device)
input_a = torch.from_numpy(a_one_seq).long().to(device)
input_it = torch.from_numpy(it_one_seq).long().to(device)
target = torch.from_numpy(a_one_seq).float().to(device)

input_at = None
if at_data is not None:
at_one_seq = at_data[idx * batch_size: (idx + 1) * batch_size, :]
input_at = torch.from_numpy(at_one_seq).long().to(device)

with torch.no_grad():
pred = net(input_q, input_a, input_e, input_it, input_at)

mask = input_e[:, 1:] > 0
masked_pred = pred[:, 1:][mask].detach().cpu().numpy()
masked_truth = target[:, 1:][mask].detach().cpu().numpy()

pred_list.append(masked_pred)
target_list.append(masked_truth)
mask_list.append(mask.long().cpu().numpy())

all_pred = np.concatenate(pred_list, axis=0)
all_target = np.concatenate(target_list, axis=0)
mask_list = np.concatenate(mask_list, axis=0)

loss = binary_entropy(all_target, all_pred)
r2 = compute_r2(all_target, all_pred)
auc = compute_auc(all_target, all_pred)
accuracy = compute_accuracy(all_target, all_pred)
rmse = compute_rmse(all_target, all_pred)

return loss, rmse, r2, auc, accuracy


class OKT(KTM):
def __init__(self, n_at, n_it, n_exercise, n_question, d_e, d_q, d_a, d_at, d_p, d_h, batch_size=64, dropout=0.2):
super(OKT, self).__init__()

self.okt_net = OKTNet(n_question, n_exercise, n_it, n_at, d_e, d_q, d_a, d_at, d_p, d_h,
dropout=dropout).to(device)
self.batch_size = batch_size

def train(self, train_data, test_data=None, *, epoch: int, lr=0.002, lr_decay_step=15, lr_decay_rate=0.5,
filepath=None) -> ...:
optimizer = torch.optim.Adam(self.okt_net.parameters(), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_decay_step, gamma=lr_decay_rate)
criterion = nn.BCELoss()
best_train_auc, best_test_auc = .0, .0

for idx in range(epoch):
train_loss, train_r2, train_auc, train_accuracy = train_one_epoch(self.okt_net, optimizer, criterion,
self.batch_size, *train_data)
print("[Epoch %d] LogisticLoss: %.6f" % (idx, train_loss))
if train_auc > best_train_auc:
best_train_auc = train_auc

if test_data is not None:
_, _, test_r2, test_auc, test_accuracy = self.eval(test_data)
print("[Epoch %d] r2: %.6f, auc: %.6f, accuracy: %.6f" % (idx, test_r2, test_auc, test_accuracy))
scheduler.step()
if test_auc > best_test_auc:
best_test_auc = test_auc
if filepath is not None:
self.save(filepath)

return best_train_auc, best_test_auc

def eval(self, test_data) -> ...:
return test_one_epoch(self.okt_net, self.batch_size, *test_data)

def save(self, filepath) -> ...:
torch.save(self.okt_net.state_dict(), filepath)
logging.info("save parameters to %s" % filepath)

def load(self, filepath) -> ...:
self.okt_net.load_state_dict(torch.load(filepath, map_location='cpu'))
logging.info("load parameters from %s" % filepath)
89 changes: 89 additions & 0 deletions EduKTM/OKT/OKTNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_

from .modules import UKSE, KSE, OTE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class OKTNet(nn.Module):
def __init__(self, n_skill, n_exercise, n_it, n_at, d_e, d_q, d_a, d_at, d_p, d_h, dropout=0.05):
super(OKTNet, self).__init__()
self.device = device

self.n_skill = n_skill
self.n_at = n_at
self.d_h = d_h
self.d_a = d_a

d_it = d_h
self.it_embed = nn.Embedding(n_it + 1, d_it)
xavier_uniform_(self.it_embed.weight)
self.at_embed = nn.Embedding(n_at + 1, d_at)
xavier_uniform_(self.at_embed.weight)
self.answer_embed = nn.Embedding(2, d_a)
xavier_uniform_(self.answer_embed.weight)
self.exercise_embed = nn.Embedding(n_exercise + 1, d_e)
xavier_uniform_(self.exercise_embed.weight)
self.skill_embed = nn.Embedding(n_skill + 1, d_q)
xavier_uniform_(self.skill_embed.weight)

self.linear_q = nn.Linear(d_e + d_q, d_p)
xavier_uniform_(self.linear_q.weight)
if n_at == 0:
self.linear_x = nn.Linear(d_p + d_a, d_h)
else:
self.linear_x = nn.Linear(d_at + d_p + d_a, d_h)
xavier_uniform_(self.linear_x.weight)
self.ukse = UKSE(d_h)
self.kse = KSE(d_h)
self.ote = OTE(d_it, d_h, d_h)

self.sig = nn.Sigmoid()
self.tanh = nn.Tanh()

self.dropout = nn.Dropout(dropout)

self.predict = nn.Sequential(
nn.Linear(d_h + d_p, 256),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(256, 1)
)

def forward(self, kc_data, a_data, e_data, it_data, at_data):
# prepare data
batch_size, seq_len = kc_data.size(0), kc_data.size(1)

E = self.exercise_embed(e_data)
KC = self.skill_embed(kc_data)
IT = self.it_embed(it_data)
Ans = self.answer_embed(a_data)
Q = self.linear_q(torch.cat((E, KC), 2))
if self.n_at == 0:
X = self.linear_x(torch.cat((Q, Ans), 2))
else:
AT = self.at_embed(at_data)
X = self.linear_x(torch.cat((Q, Ans, AT), 2))

previous_h = xavier_uniform_(torch.zeros(1, self.d_h)).repeat(batch_size, 1).to(self.device)
v = xavier_uniform_(torch.empty(1, self.d_h)).repeat(batch_size, 1).to(self.device)
pred = torch.zeros(batch_size, seq_len, 1).to(self.device)

for t in range(seq_len):
it_embed = IT[:, t]
q = Q[:, t]
x = X[:, t]

# predict
updated_h = self.ukse(previous_h, v, it_embed)
pred[:, t] = self.sig(self.predict(torch.cat((updated_h, q), 1)))

# update
h = self.kse(x, updated_h)
v = self.ote(previous_h, h, it_embed, v)

previous_h = h

return pred.squeeze(-1)
4 changes: 4 additions & 0 deletions EduKTM/OKT/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# coding: utf-8
# 2022/9/27 @ sone

from .OKT import OKT
49 changes: 49 additions & 0 deletions EduKTM/OKT/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_


class OTE(nn.Module):
def __init__(self, d_h, d_it, d_v):
super(OTE, self).__init__()
self.linear_v = nn.Linear(2 * d_h + d_it, d_v)
self.linear_p = nn.Linear(d_it + d_v, d_v)
self.sig = nn.Sigmoid()
self.tanh = nn.Tanh()

def forward(self, previous_h, h, it, v):
delta = torch.cat((previous_h, h), 1)
v_prime = self.tanh(self.linear_v(torch.cat((delta, it), 1)))
p = self.sig(self.linear_p(torch.cat((v, it), 1)))
v = (1 - p) * v + p * v_prime
return v


class UKSE(nn.Module):
def __init__(self, d_h):
super(UKSE, self).__init__()
self.linear_h = nn.Linear(3 * d_h, d_h)
self.linear_p = nn.Linear(3 * d_h, d_h)
self.sig = nn.Sigmoid()
self.tanh = nn.Tanh()

def forward(self, h, v, it):
h_prime = self.tanh(self.linear_h(torch.cat((h, v, it), 1)))
p = self.sig(self.linear_p(torch.cat((h, v, it), 1)))
return (1 - p) * h + p * h_prime


class KSE(nn.Module):
def __init__(self, d_h):
super(KSE, self).__init__()
self.linear_h = nn.Linear(2 * d_h, d_h)
self.linear_p = nn.Linear(2 * d_h, d_h)

self.sig = nn.Sigmoid()
self.tanh = nn.Tanh()

def forward(self, x, hr):
h_tilde = self.tanh(self.linear_h(torch.cat((x, hr), 1)))
p = self.sig(self.linear_p(torch.cat((x, hr), 1)))
hx = (1 - p) * hr + p * h_tilde
return hx
1 change: 1 addition & 0 deletions EduKTM/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from .LPKT import LPKT
from .GKT import GKT
from .DKVMN import DKVMN
from .OKT import OKT
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Knowledge Tracing (KT), which aims to monitor students’ evolving knowledge sta
* [GKT](EduKTM/GKT)[[doc]](docs/GKT.md) [[example]](examples/GKT)
* [AKT](EduKTM/AKT) [[doc]](docs/AKT.md) [[example]](examples/AKT)
* [LPKT](EduKTM/LPKT) [[doc]](docs/LPKT.md) [[example]](examples/LPKT)
* [OKT](EduKTM/OKT) [[doc]](docs/OKT.md) [[example]](examples/OKT)

## Contribute

Expand Down
3 changes: 3 additions & 0 deletions docs/OKT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Offline-aware Knowledge Tracing(OKT)

The details of OKT will be given after the paper is published.
Loading