diff --git a/README.md b/README.md
index b2cf6b3..8e6277c 100644
--- a/README.md
+++ b/README.md
@@ -6,17 +6,19 @@
-
-
+
-
+
+
+
+
@@ -62,6 +64,7 @@ for easily modeling your partially-observed time-series datasets.
👉 Click here to see the example 👀
``` python
+# pip install pypots>=0.4
import numpy as np
from sklearn.preprocessing import StandardScaler
from pygrinder import mcar
@@ -82,7 +85,7 @@ dataset = {"X": X} # X for model input
print(X.shape) # (11988, 48, 37), 11988 samples and each sample has 48 time steps, 37 features
# Model training. This is PyPOTS showtime.
-saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, d_inner=128, n_heads=4, d_k=64, d_v=64, dropout=0.1, epochs=10)
+saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, d_ffn=128, n_heads=4, d_k=64, d_v=64, dropout=0.1, epochs=10)
# Here I use the whole dataset as the training set because ground truth is not visible to the model, you can also split it into train/val/test sets
saits.fit(dataset)
imputation = saits.impute(dataset) # impute the originally-missing values and artificially-missing values