forked from go-interpreter/chezgo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmatrix.ss
271 lines (236 loc) · 7.73 KB
/
matrix.ss
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
#|
usage:
> (import (my-matrix))
> (sanity)
#t
> (run-bench)
500 x 500 matrix multiply in Chez took 2472 msec
500 x 500 matrix multiply in Chez took 2474 msec
...
|#
(library (my-matrix (1))
(export run-bench sanity)
(import (chezscheme))
;;; reference: https://www.scheme.com/tspl3/examples.html
;;; make-matrix creates a matrix: a vector of `rows` freshly allocated
;;; row vectors, each of length `columns`.  No two rows share storage.
;;;   (make-matrix rows columns)       every element is 0
;;;   (make-matrix rows columns fill)  every element is `fill`
;;; The 0 default is spelled out because R6RS leaves the contents of
;;; (make-vector n) unspecified.  Callers that accumulate into a fresh
;;; matrix can therefore rely on it starting at zero on any implementation.
(define make-matrix
  (case-lambda
    ((rows columns)
     (make-matrix rows columns 0))
    ((rows columns fill)
     (do ((m (make-vector rows))   ; outer slots are all overwritten below
          (i 0 (+ i 1)))
         ((= i rows) m)
       (vector-set! m i (make-vector columns fill))))))
;;; matrix? is a cheap structural test: x is treated as a matrix when it
;;; is a non-empty vector whose first element is itself a vector.
;;; Only row 0 is inspected, so ragged or partly non-vector inputs can
;;; still pass.  Good enough for dispatch, not a full validation.
(define (matrix? x)
  (if (and (vector? x)
           (not (zero? (vector-length x))))
      (vector? (vector-ref x 0))
      #f))
;; matrix-rows: how many rows matrix x has, i.e. the length of the
;; outer vector.
(define (matrix-rows x)
  (vector-length x))
;; matrix-columns: how many columns matrix x has, taken from the length
;; of row 0 (all rows are assumed to be the same length).
(define (matrix-columns x)
  (let ((first-row (vector-ref x 0)))
    (vector-length first-row)))
;;; matrix-ref: the element of matrix m at row i, column j (both
;;; 0-based, as with vector-ref).
(define (matrix-ref m i j)
  (let ((row (vector-ref m i)))
    (vector-ref row j)))
;;; matrix-set! destructively stores x at row i, column j (0-based) of
;;; matrix m.
(define (matrix-set! m i j x)
  (let ((row (vector-ref m i)))
    (vector-set! row j x)))
;;; mul is the generic matrix/scalar multiplication procedure.
;;;   (mul number number)  -> their product
;;;   (mul number matrix), (mul matrix number)
;;;                        -> a new matrix, each element scaled
;;;   (mul matrix matrix)  -> a new matrix holding the product; the first
;;;                           operand must have as many columns as the
;;;                           second has rows
;;; Any other argument type, or mismatched dimensions, raises an error
;;; (via `error`, with who = mul).
;;; Elements may be any Scheme numbers.  The flonum-only fl* is used just
;;; when both factors are flonums.  fl* rejects exact arguments at the
;;; default safety level and is unchecked at optimize-level 3, so calling
;;; it unconditionally broke exact-integer matrices and scalars.
(define mul
  (lambda (x y)
    ;; num* multiplies two numbers, taking the fl* fast path when both
    ;; are flonums and falling back to generic * otherwise.
    (define num*
      (lambda (a b)
        (if (and (flonum? a) (flonum? b))
            (fl* a b)
            (* a b))))
    ;; mat-sca-mul multiplies a matrix by a scalar.
    (define mat-sca-mul
      (lambda (m x)
        (let* ((nr (matrix-rows m))
               (nc (matrix-columns m))
               (r (make-matrix nr nc)))
          (do ((i 0 (+ i 1)))
              ((= i nr) r)
            (do ((j 0 (+ j 1)))
                ((= j nc))
              (matrix-set! r i j (num* x (matrix-ref m i j))))))))
    ;; mat-mat-mul multiplies one matrix by another, after verifying
    ;; that the first matrix has as many columns as the second has rows.
    ;; The loops run in i-k-j order: for each row i of m1 and each k,
    ;; m1[i][k] * (row k of m2) is added into output row i.  Each output
    ;; row is zeroed first so the result does not depend on what
    ;; make-matrix leaves in a fresh matrix.
    (define mat-mat-mul
      (lambda (m1 m2)
        (let ((nr1 (matrix-rows m1))
              (nr2 (matrix-rows m2))
              (nc2 (matrix-columns m2)))
          (unless (= (matrix-columns m1) nr2)
            (match-error m1 m2))
          (let ((r (make-matrix nr1 nc2)))
            (do ((i 0 (+ i 1)))
                ((= i nr1) r)
              (let ((ith-input-row (vector-ref m1 i))
                    (ith-output-row (vector-ref r i)))
                (vector-fill! ith-output-row 0)
                (do ((k 0 (+ k 1)))
                    ((= k nr2))
                  (let ((a-ik (vector-ref ith-input-row k))
                        (kth-input-row (vector-ref m2 k)))
                    (do ((j 0 (+ j 1)))
                        ((= j nc2))
                      (vector-set! ith-output-row j
                                   (+ (vector-ref ith-output-row j)
                                      (num* a-ik
                                            (vector-ref kth-input-row j)))))))))))))
    ;; type-error is called to complain when mul receives an invalid
    ;; type of argument.
    (define type-error
      (lambda (what)
        (error 'mul
               "~s is not a number or matrix"
               what)))
    ;; match-error is called to complain when mul receives a pair of
    ;; incompatible arguments.
    (define match-error
      (lambda (what1 what2)
        (error 'mul
               "~s and ~s are incompatible operands"
               what1
               what2)))
    ;; body of mul; dispatch based on input types
    (cond
      ((number? x)
       (cond
         ((number? y) (* x y))
         ((matrix? y) (mat-sca-mul y x))
         (else (type-error y))))
      ((matrix? x)
       (cond
         ((number? y) (mat-sca-mul x y))
         ((matrix? y) (mat-mat-mul x y))
         (else (type-error y))))
      (else (type-error x)))))
;;; fill-random overwrites every element of matrix m, in place, with
;;; (/ (random 100) (+ 2.0 (random 100))).  The denominator is a flonum,
;;; so every stored value is a flonum.  The return value is unspecified.
(define (fill-random m)
  (let ((rows (matrix-rows m))
        (cols (matrix-columns m)))
    (let row-loop ((r 0))
      (when (< r rows)
        (let col-loop ((c 0))
          (when (< c cols)
            (matrix-set! m r c (/ (random 100) (+ 2.0 (random 100))))
            (col-loop (+ c 1))))
        (row-loop (+ r 1))))))
;;; bench returns the product of a and b as computed by mul.
;;; (Not exported by this library.)
(define bench
  (lambda (a b)
    (mul a b)))
;;; run-bench times matrix multiplication, printing one line per run:
;;;   "<size> x <size> matrix multiply in Chez took <ms> msec"
;;;   (run-bench)            10 runs on 500 x 500 matrices
;;;   (run-bench runs size)  `runs` runs on size x size matrices
;;; A garbage collection is forced once before the first run.  Each run
;;; fills two fresh matrices with random flonums and times only the
;;; multiply, using the difference of two (real-time) readings.
(define run-bench
  (case-lambda
    (()
     (run-bench 10 500))
    ((runs size)
     (collect)
     (do ((run 0 (+ run 1)))
         ((>= run runs))
       (let ((a (make-matrix size size))
             (b (make-matrix size size)))
         (fill-random a)
         (fill-random b)
         ;; let* sequences the multiply between the two clock readings;
         ;; the product itself is not used.
         (let* ((t0 (real-time))
                (product (mul a b))
                (t1 (real-time)))
           (format #t "~s x ~s matrix multiply in Chez took ~s msec" size size (- t1 t0))
           (newline)))))))
;;; fill-sequential overwrites matrix m, in place and in row-major order,
;;; with start, start+1, start+2, ...  Each value is the previous one
;;; plus 1.  The return value is unspecified.
(define (fill-sequential m start)
  (let ((rows (matrix-rows m))
        (cols (matrix-columns m)))
    (let next-row ((r 0) (v start))
      (when (< r rows)
        (let next-col ((c 0) (v v))
          (if (< c cols)
              (begin
                (matrix-set! m r c v)
                (next-col (+ c 1) (+ 1 v)))
              ;; row finished: carry the running value into the next row
              (next-row (+ r 1) v)))))))
;;; matrix-same returns #t when m1 and m2 have the same number of rows
;;; and columns and every pair of corresponding elements is eqv?.
;;; Otherwise it returns #f.
;;; eqv? is used rather than eq?: R6RS leaves eq? unspecified on numbers,
;;; and equal flonums or bignums may be distinct objects.
;;; The scan stops at the first mismatch.
(define (matrix-same m1 m2)
  (let ((nr (matrix-rows m1))
        (nc (matrix-columns m1)))
    (and (= nr (matrix-rows m2))
         (= nc (matrix-columns m2))
         (let row-loop ((i 0))
           (or (= i nr)
               (and (let col-loop ((j 0))
                      (or (= j nc)
                          (and (eqv? (matrix-ref m1 i j) (matrix-ref m2 i j))
                               (col-loop (+ j 1)))))
                    (row-loop (+ i 1))))))))
;; sanity: a smoke test for mul on a small exact-integer case.
;; The rows of a are filled with 1..4 and the rows of b with 5..8, i.e.
;;   #(#(1 2) #(3 4)) x #(#(5 6) #(7 8))
;; whose product should be #(#(19 22) #(43 50)).
;; Returns the result of matrix-same on mul's output and that expected
;; matrix (#t on success).
(define (sanity)
  (let ((expected (vector (vector 19 22) (vector 43 50)))
        (lhs (make-matrix 2 2))
        (rhs (make-matrix 2 2)))
    (fill-sequential lhs 1)
    (fill-sequential rhs 5)
    (matrix-same (mul lhs rhs) expected)))
#|
Chez Scheme timings on a MacBook Pro, run with optimize-level 3
(about 10% faster than the default level):
scheme --optimize-level 3 ./matrix.ss
Chez Scheme Version 9.5.1
Copyright 1984-2017 Cisco Systems, Inc.
;; top-level bindings, without a library wrapper:
500 x 500 matrix multiply in Chez took 2606 msec
500 x 500 matrix multiply in Chez took 2605 msec
500 x 500 matrix multiply in Chez took 2571 msec
500 x 500 matrix multiply in Chez took 2634 msec
500 x 500 matrix multiply in Chez took 2597 msec
500 x 500 matrix multiply in Chez took 2603 msec
500 x 500 matrix multiply in Chez took 2565 msec
500 x 500 matrix multiply in Chez took 2535 msec
500 x 500 matrix multiply in Chez took 2587 msec
500 x 500 matrix multiply in Chez took 2547 msec
;; inside the my-matrix library:
500 x 500 matrix multiply in Chez took 2435 msec
500 x 500 matrix multiply in Chez took 2498 msec
500 x 500 matrix multiply in Chez took 2486 msec
500 x 500 matrix multiply in Chez took 2496 msec
500 x 500 matrix multiply in Chez took 2465 msec
500 x 500 matrix multiply in Chez took 2499 msec
500 x 500 matrix multiply in Chez took 2492 msec
500 x 500 matrix multiply in Chez took 2532 msec
500 x 500 matrix multiply in Chez took 2463 msec
500 x 500 matrix multiply in Chez took 2526 msec
;; update: switching from (*) to (fl*) in the multiply shaved
;; off about 17%
500 x 500 matrix multiply in Chez took 2075 msec
500 x 500 matrix multiply in Chez took 2040 msec
500 x 500 matrix multiply in Chez took 2054 msec
500 x 500 matrix multiply in Chez took 2059 msec
500 x 500 matrix multiply in Chez took 2066 msec
500 x 500 matrix multiply in Chez took 2048 msec
500 x 500 matrix multiply in Chez took 2053 msec
500 x 500 matrix multiply in Chez took 2112 msec
500 x 500 matrix multiply in Chez took 2064 msec
500 x 500 matrix multiply in Chez took 2060 msec
|#
) ;; end my-matrix library