From 747956e4179ee8fdca67c18e02946b790551f2f2 Mon Sep 17 00:00:00 2001 From: JJ Date: Tue, 7 Feb 2023 03:19:04 +0900 Subject: [PATCH] [feat] implement 'svm_loss_vectorized' function #1 --- .../cs231n/classifiers/linear_svm.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/cs231n_2022/assignment1/cs231n/classifiers/linear_svm.py b/cs231n_2022/assignment1/cs231n/classifiers/linear_svm.py index 83613fd..72b0404 100644 --- a/cs231n_2022/assignment1/cs231n/classifiers/linear_svm.py +++ b/cs231n_2022/assignment1/cs231n/classifiers/linear_svm.py @@ -84,7 +84,15 @@ def svm_loss_vectorized(W, X, y, reg): ############################################################################# # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)***** - pass + num_train = X.shape[0] + + scores = np.matmul(X, W) + correct_class_scores = scores[range(num_train), y].reshape(-1, 1) + + margins = np.maximum(0, scores - correct_class_scores + 1) + margins[range(num_train), y] = 0 + + loss = np.sum(margins) / num_train + reg * np.sum(W * W) # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)***** @@ -98,8 +106,13 @@ def svm_loss_vectorized(W, X, y, reg): # loss. # ############################################################################# # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)***** + + # ref: https://mainpower4309.tistory.com/28 + dScores = np.zeros(scores.shape) + dScores[margins > 0] = 1 + dScores[range(num_train), y] -= np.sum(dScores, axis=1) - pass + dW = np.matmul(X.T, dScores) / num_train + 2 * reg * W # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****