# stable_diffusion.py
from __future__ import annotations
# https://arxiv.org/pdf/2112.10752.pdf
# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
# Blatantly copied from:
# https://github.com/geohot/tinygrad/blob/master/examples/stable_diffusion.py
# https://github.com/geohot/tinygrad/blob/master/LICENSE
import os
import gzip
import argparse
import math
import re
import pickle
import zipfile
import io
import struct
import sys
import weakref
import operator
import itertools
import functools
from copy import copy
from functools import lru_cache
from collections import namedtuple
from collections import defaultdict
from enum import Enum
from typing import Optional, Tuple, Union, List, Dict
from typing import Type, NamedTuple, Any, Callable
from tqdm import tqdm
from PIL import Image
import numpy as np
DEBUG = int(os.getenv("DEBUG", "0"))
# these are the llops your accelerator must implement, along with toCPU and fromCPU
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN", "RECIPROCAL"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "EXPAND", "FLIP", "STRIDED", "PAD", "SHRINK"])
ProcessingOps = Enum("ProcessingOps", ["CONV"])
LoadOps = Enum("LoadOps", ["FROMCPU", "CONTIGUOUS"])
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[ProcessingOps], Type[LoadOps]]
GRAPH = int(os.getenv("GRAPH", "0"))
# **** debugging and graphing ****
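# log_op's GRAPH=1 path below expects a module-level graph G (G.add_edge / G.nodes / G.add_node).
# A minimal sketch of that setup, assuming networkx is available; rendering or saving the collected
# graph on exit is omitted here.
if GRAPH:
  import networkx as nx  # type: ignore
  G = nx.DiGraph()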
cnts : Dict[OpType, int] = defaultdict(int)
global_num_max = 0
def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[DeviceBuffer]):
cnts[optype] += 1
if DEBUG >= 3:
print(f"{op} : {', '.join([str(x.shape) for x in inp])} -> {ret.shape}")
if GRAPH:
def nm(x):
global global_num_max
if not hasattr(x, 'global_num'):
setattr(x, 'global_num', global_num_max)
global_num_max += 1
return f"<<< {x.global_num} >>>"
top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", ProcessingOps: "#ff8080"}
dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore
for x in inp:
if len(op) <= 2:
sop = '.'.join([str(y).split(".")[1] for y in op][::-1])
elif len(op) <= 4:
sop = '.'.join([str(y).split(".")[1][0:2] for y in op][::-1])
else:
sop = str(len(op))
G.add_edge(nm(x), nm(ret), label=sop)
if 'label' not in G.nodes[nm(x)]:
G.nodes[nm(x)]['label'] = str(x.shape)
if nm(ret) not in G.nodes:
G.add_node(nm(ret))
if optype == ReduceOps:
G.nodes[nm(ret)]['label'] = str(set(x.shape for x in inp))+"\n"+str(ret.shape)
else:
G.nodes[nm(ret)]['label'] = str(ret.shape)
G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if dashed else '')) if optype in top_colors else "#ffffff"
G.nodes[nm(ret)]['style'] = 'filled, dashed' if dashed else 'filled'
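# a LazyOp is one node of the lazy op AST: the op itself, its sources (nested LazyOps or buffers), and an optional arg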
class LazyOp(NamedTuple):
op: Op
# Any == Union[LazyOp, LazyBuffer, DeviceBuffer]
src: Tuple[Any, ...] # type: ignore
arg: Any = None
# TODO: add dest to support multiple outputs
# Any == Union[LazyBuffer, DeviceBuffer]
def get_buffers(op:LazyOp) -> List[Any]: return functools.reduce(operator.add, [get_buffers(x) if isinstance(x, LazyOp) else [x] for x in op.src], [])
def get_lazyops(op:LazyOp) -> List[LazyOp]: return functools.reduce(operator.add, [get_lazyops(x) for x in op.src if isinstance(x, LazyOp)], [op])
# a placeholder class for the exec classes to extend
class DeviceBuffer:
shape: Any # should be Tuple[int, ...] but ndarray and torch.tensor have incompatible types
# extend this if you don't have an exec_ast function
# used in CPUBuffer and TorchBuffer
class GenericExecAST(DeviceBuffer):
@classmethod
def exec_ast(cls, ast:LazyOp, preprocess=lambda x: x):
srcs = [cls.exec_ast(x, preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src]
if ast.op in UnaryOps:
ret = srcs[0].unary_op(ast.op)
elif ast.op in BinaryOps:
assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
ret = srcs[0].binary_op(ast.op, srcs[1])
elif ast.op in ReduceOps:
assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
ret = srcs[0].reduce_op(ast.op, ast.arg)
elif ast.op in MovementOps:
ret = srcs[0].movement_op(ast.op, ast.arg)
elif ast.op in ProcessingOps:
ret = srcs[0].processing_op(ast.op, srcs[1], ast.arg)
else:
raise Exception("unknown op")
return ret
class CPUBuffer(np.ndarray, GenericExecAST):
fxn_for_op = {
UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.RELU: lambda x: x.relu(),
UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.SIGN: lambda x: x.sign(), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
ReduceOps.MAX: lambda x, new_shape: x.amax(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)]
}
def relu(x): return np.maximum(x, 0)
def exp(x): return np.exp(x)
def log(x): return np.log(x)
def sign(x): return np.sign(x)
def float(x): return x.astype(np.float32)
def flip(x, axis): return np.flip(x, axis)
def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
def permute(x, order): return x.transpose(order)
def pad(x, padding): return np.pad(x, padding).view(CPUBuffer)
def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
def strided(x, arg): return np.lib.stride_tricks.as_strided(x.ravel().reshape(x.shape), shape=[y[0] for y in arg], strides=[y[1]*x.dtype.itemsize for y in arg]).view(CPUBuffer)
@staticmethod
def fromCPU(x): return x.view(CPUBuffer)
def toCPU(x): return x
def unary_op(x, op): return CPUBuffer.fxn_for_op[op](x)
def binary_op(x, op, y): return CPUBuffer.fxn_for_op[op](x, y)
def reduce_op(x, op, new_shape): return CPUBuffer.fxn_for_op[op](x, new_shape)
def movement_op(x, op, arg=None): return CPUBuffer.fxn_for_op[op](x, arg) if op in CPUBuffer.fxn_for_op else getattr(x, op.name.lower())(arg)
def processing_op(x,op,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
tx = x.movement_op(MovementOps.STRIDED, (
(C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
(C.oy, C.sy*x.shape[3]), (C.ox, C.sx), (C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
out = np.einsum("nGhwCHW, GkCHW -> nGkhw", tx.ravel().reshape(tx.shape), tw.ravel().reshape(tw.shape))
return out.reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)
class GenericShape(GenericExecAST):
def __init__(self, shape, flops=0): self.shape, self.flops = shape, flops
def unary_op(self, op:UnaryOps): return GenericShape(self.shape, self.flops + prod(self.shape))
def binary_op(self, op:BinaryOps, y): return GenericShape(self.shape, self.flops + y.flops + prod(self.shape))
def reduce_op(self, op:ReduceOps, new_shape:Tuple[int, ...]): return GenericShape(new_shape, self.flops + prod(self.shape))
def movement_op(self, op:MovementOps, arg): return GenericShape(ShapeTracker(self.shape).movement_op(op, arg).shape, self.flops)
def processing_op(self, op:ProcessingOps, w, C): return GenericShape(C.out_shape, float("nan")) # TODO: add flops for this
def get_lazyop_info(ast:LazyOp): return GenericShape.exec_ast(ast, lambda x: GenericShape(x.shape))
# assumes you are using ShapeTracker
# used in GPUBuffer, OpenCLBuffer, and LLVMBuffer
class ExplicitExecAST(DeviceBuffer):
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None):
self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
self.shape = self.st.shape
@classmethod
def exec_ast(cls, ast:LazyOp): raise NotImplementedError("must be implemented")
# universal
def unary_op(self, op:UnaryOps): return type(self)(self.shape).exec_ast(LazyOp(op=op, src=(self,)))
def binary_op(self, op:BinaryOps, y): return type(self)(self.shape).exec_ast(LazyOp(op=op, src=(self, y)))
def reduce_op(self, op:ReduceOps, new_shape:Tuple[int, ...]): return type(self)(new_shape).exec_ast(LazyOp(op=op, src=(self,), arg=new_shape))
# universal for shape tracked
def movement_op(self, op:MovementOps, arg): return type(self)(ShapeTracker(self.st).movement_op(op, arg), self)
def contiguous_op(self): return self if self.st.contiguous else self.unary_op(UnaryOps.NOOP)
def divmodidx(acc, d, mod=True):
lr = f"(idx//{acc})" if acc != 1 else "idx"
return f"({lr}%{d})" if mod else lr # don't mod the top shape dimension
@functools.lru_cache(maxsize=None)
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tuple[int, int]]:
assert len(shape) == len(strides)
ret = [(shape[0], strides[0])]
for i in range(1, len(shape)):
if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or ret[-1][0] == 1 or (strides[i] == 0 and ret[-1][1] == 0):
ret[-1] = (ret[-1][0] * shape[i], strides[i])
else:
ret.append((shape[i], strides[i]))
return ret
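# a View is a single strided view over a flat buffer (shape, strides, offset) and can render the flat-index expression for it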
class View:
def __init__(self, shape, strides, offset:int=0):
self.shape, self.strides, self.offset = tuple(shape), tuple(strides), offset
self.shape_strides = to_shape_strides(self.shape, self.strides)
def __repr__(self): return f"View<{self.shape}, {self.strides}, {self.offset}>"
@functools.cached_property
def contiguous(self):
return self.offset == 0 and all(s1 == s2 or s == 1 for s,s1,s2 in zip(self.shape, self.strides, strides_for_shape(self.shape)))
@functools.cached_property
def expr(self):
ret = [f"{self.offset}"] if self.offset != 0 else []
acc = 1
for i,(d,s) in enumerate(self.shape_strides[::-1]):
if d != 1 and s != 0:
lr = divmodidx(acc, d, i != len(self.shape_strides)-1 and d != prod(self.shape))
lr = f"({lr}*{s})" if s != 1 else lr
ret.append(lr)
acc *= d
return 'idx=' + ('+'.join(ret) if len(ret) > 0 else "0")
class ZeroView:
def __init__(self, old_shape, arg):
self.old_shape, self.arg, self.shape = old_shape, arg, []
expr, acc = ['valid'], 1
for s,(x,y) in list(zip(old_shape, arg))[::-1]:
self.shape = [y-x] + self.shape
base = divmodidx(acc, self.shape[0], len(self.shape) != len(old_shape)) + f"+{x}"
expr += ([f"(({base}) >= 0)"] if x < 0 else []) + ([f"(({base}) < {s})"] if y > s else [])
acc *= self.shape[0]
self.expr = 'valid=' + ' && '.join(expr)
ViewTypes = Union[View, ZeroView]
@functools.lru_cache(maxsize=None)
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
strides = [1]
for d in shape[::-1][:-1]:
strides = [d*strides[0]] + strides
return tuple(strides)
@functools.lru_cache(maxsize=None)
def view_from_shape(shape:Tuple[int, ...]) -> View:
assert all(isinstance(x, int) for x in shape) and len(shape) != 0
return View(tuple(shape), strides_for_shape(shape))
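# ShapeTracker represents a tensor's shape as a stack of Views (plus ZeroViews for out-of-bounds padding),
# so movement ops only rewrite the index/valid expressions and never copy data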
class ShapeTracker:
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]]):
self.views : List[ViewTypes] = shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)]
def __repr__(self): return f"{'Complex' if len(self.views) > 1 else ''}ShapeTracker<{self.shape}, {self.views}>"
@property
def contiguous(self): return len(self.views) == 1 and self.views[-1].contiguous
@property
def shape(self): return self.views[-1].shape
@property
def strides(self): return self.views[-1].strides
@property
def offset(self): return self.views[-1].offset
def expr(self): return ';'.join([v.expr for v in self.views[::-1] if v.expr != 'idx=idx' and v.expr != 'valid=valid'])
def movement_op(self, op, arg):
getattr(self, str(op).split(".")[1].lower())(*arg)
return self
def needs_valid(self): return any(isinstance(v, ZeroView) for v in self.views)
# TODO: do we really need this for conv?
# if we replace, confirm the ops taken fold into one view
def strided(self, *arg):
view = View([x[0] for x in arg], [x[1] for x in arg])
# TODO: this does not always require a new view if non contiguous
if self.contiguous:
self.views[-1] = view
else:
self.views.append(view)
def reshape(self, *new_shape):
assert all(isinstance(x, int) for x in new_shape)
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
# check if this is adding or removing 1s (only)
if tuple([x for x in self.shape if x != 1]) == tuple([x for x in new_shape if x != 1]):
old_strides = [y for x,y in zip(self.shape, self.strides) if x != 1]
new_strides = [0 if x == 1 else old_strides.pop(0) for x in new_shape]
self.views[-1] = View(new_shape, new_strides, self.offset)
return
view = View(new_shape, strides_for_shape(new_shape))
if self.contiguous:
self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
else:
self.views.append(view)
def permute(self, *axis):
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis)
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
self.views[-1] = View([self.shape[a] for a in axis], [self.strides[a] for a in axis], self.offset)
# TODO: this is a special case of slice with strides, remove it
# though it's nice that it can't change size
def flip(self, *axis): self.stride(*[-1 if i in axis else 1 for i in range(len((self.shape)))])
# *** under this line are not invertible ***
# TODO: take this functionality out of slice
def pad(self, *arg):
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
return self.shrink(*[(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
# TODO: take the pad functionality out of shrink
def shrink(self, *arg):
assert len(arg) == len(self.shape)
offset = sum([self.strides[i]*x for i,(x,_) in enumerate(arg)])
zeroview = ZeroView(self.shape, arg)
self.views[-1] = View([y-x for x,y in arg], self.strides, self.offset+offset)
if zeroview.expr != "valid=valid":
# if we add a ZeroView, we add another (stock) view also for modding
self.views += [zeroview, View(self.shape, strides_for_shape(self.shape))]
def expand(self, *new_shape):
assert all(isinstance(x, int) for x in new_shape)
assert all(x == y or x == 1 for x,y in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
strides = [s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape))]
self.views[-1] = View(new_shape, strides, self.offset)
# TODO: combine with slice? this doesn't require a ZeroView, though slice shouldn't always either
def stride(self, *mul):
assert all(isinstance(x, int) for x in mul)
strides = [z*m for z,m in zip(self.strides, mul)]
new_shape = [(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)]
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
self.views[-1] = View(new_shape, strides, self.offset + offset)
def dedup(x): return list(dict.fromkeys(x)) # retains list order
def prod(x): return math.prod(x)
def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], tuple) or isinstance(x[0], list) else tuple(x)
def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def reduce_shape(shape, axis): return tuple(1 if i in axis else shape[i] for i in range(len(shape)))
def shape_to_axis(old_shape, new_shape):
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
return tuple([i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b])
ConvArgs = namedtuple('ConvArgs', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'sy', 'sx', 'bs', 'cout', 'py', 'py_', 'px', 'px_', 'dy', 'dx', 'out_shape'])
def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1, out_shape=None):
# TODO: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout
cout,cin,H,W = w_shape
sy,sx = (stride, stride) if isinstance(stride, int) else stride
if not isinstance(padding, int) and len(padding) == 4:
px,px_,py,py_ = padding
else:
py,px = (padding, padding) if isinstance(padding, int) else padding
py_, px_ = py, px
dy,dx = (dilation, dilation) if isinstance(dilation, int) else dilation
bs,cin_,iy,ix = x_shape
# this can change px_ and py_ to make the out_shape right
# TODO: copy padding names from http://nvdla.org/hw/v1/ias/unit_description.html
if out_shape is not None:
py_ = (out_shape[2] - 1) * sy + 1 + dy * (H-1) - iy - py
px_ = (out_shape[3] - 1) * sx + 1 + dx * (W-1) - ix - px
# TODO: should be easy to support asymmetric padding by changing output size
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html describes these sizes well
oy = (iy + py + py_ - dy * (H-1) - 1)//sy + 1
ox = (ix + px + px_ - dx * (W-1) - 1)//sx + 1
if cin*groups != cin_:
raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
assert cout % groups == 0 and (out_shape is None or out_shape == (bs, cout, oy, ox))
return ConvArgs(H, W, groups, cout//groups, cin, oy, ox, iy, ix, sy, sx, bs, cout, py, py_, px, px_, dy, dx, (bs, cout, oy, ox))
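# e.g. a 3x3 conv with stride 1 and padding 1 keeps the spatial size:
#   get_conv_args((1,16,32,32), (32,16,3,3), stride=1, padding=1).out_shape == (1, 32, 32, 32)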
def get_available_llops():
_buffers, DEFAULT = {}, "CPU"
_buffers["CPU"] = CPUBuffer
return _buffers, DEFAULT
# lazy can recurse a lot
sys.setrecursionlimit(10000)
OPT = int(os.getenv("OPT", "1"))
NOCONV = int(os.getenv("NOCONV", "0"))
# TODO: movement ops that only change shape are really nops. treat them as such
REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
MERGE_ELEMENTWISE_OPS, MERGE_ONE_REDUCE_INTO_ELEMENTWISE = OPT>=2, OPT>=2
SHUFFLE_PAD_OPS = OPT>=3 # NOTE: 0/0 is NaN if you pad, so this can change the output
# **** enumerate supported devices ****
class Device:
_buffers, DEFAULT = get_available_llops()
for name in _buffers.keys():
vars()[name] = name
# **** realize helpers ****
def realize_buffers(real_srcs, x):
if x in real_srcs:
return realize_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
return LazyOp(x.op, tuple(realize_buffers(real_srcs, y) for y in x.src), x.arg)
# **** realize functions ****
# TODO: make all _realize functions return an AST, perhaps unrealized
def _realize_loadops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], Optional[OpType]]:
if self.op.op == LoadOps.FROMCPU:
return Device._buffers[self.device].fromCPU(self.op.arg), [], LoadOps
elif self.op.op == LoadOps.CONTIGUOUS:
real_src = self.op.src[0].realize(self.device)
ret = real_src.contiguous_op()
return ret, [real_src], LoadOps if ret != real_src else None
else:
raise NotImplementedError(f"unknown LoadOp {self.op.op}")
# TODO: these two are generic, replace them?
def _realize_movementops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
real_src = self.op.src[0].realize(self.device)
return real_src.movement_op(self.op.op, self.op.arg), [real_src], MovementOps
def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
real_src_x, real_src_w = [x.realize(self.device) for x in self.op.src]
return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w], ProcessingOps
# this supports late merging an upstream Elementwise op
def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
# TODO: this can also corealize a binary op after the reduce, not just before
src = self.op.src[0]
if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
# this is the new version, deprecate _processing_op
real_srcs : Dict[LazyBuffer, DeviceBuffer] = {x:x.realize(self.device) for x in get_buffers(src.op)}
ast = LazyOp(self.op.op, (realize_buffers(real_srcs, src.op),), self.op.arg)
return self.dbuffer.exec_ast(ast), list(real_srcs.values()), ReduceOps
else:
real_src = src.realize(self.device)
return real_src.reduce_op(self.op.op, self.op.arg), [real_src], ReduceOps
# this supports late merging an upstream Reduce op and even an Elementwise op above that
def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
real_srcs : Dict[LazyBuffer, Union[None, LazyOp, DeviceBuffer]] = {x:None for x in get_buffers(self.op)}
op_type : OpType = BinaryOps
psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and len(x.children) <= 1 and len(k.children) <= 1]
intermediate_shape = self.shape
if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
if psrcs[0][1].optype == ProcessingOps:
real_srcs[psrcs[0][0]] = psrcs[0][1].op
for x in psrcs[0][1].op.src:
real_srcs[x] = x.realize(self.device)
op_type = ProcessingOps
elif psrcs[0][1].optype == ReduceOps:
src = psrcs[0][1].op.src[0]
if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
src = src.op
real_srcs[psrcs[0][0]] = LazyOp(psrcs[0][1].op.op, (src,), psrcs[0][1].op.arg)
for x in get_buffers(real_srcs[psrcs[0][0]]): # type: ignore
# these are the early buffers
real_srcs[x] = x.realize(self.device)
op_type = ReduceOps
# if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
if psrcs[0][0].shape != psrcs[0][1].shape:
intermediate_shape = psrcs[0][1].shape
assert psrcs[0][0].shape == self.shape, f"shape mismatch {psrcs[0][0].shape} != {self.shape}"
# NOTE: these RESHAPEs will return self if they don't change the shape
for x in real_srcs.keys():
if real_srcs[x] is None:
real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape).realize(self.device)
ret = self.dbuffer.exec_ast(realize_buffers(real_srcs, self.op))
return ret.movement_op(MovementOps.RESHAPE, self.shape), [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
_realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}
# **** lazy operations ****
def get_weakop(op:LazyOp) -> LazyOp: return LazyOp(op.op, tuple(get_weakop(x) if isinstance(x, LazyOp) else weakref.ref(x) for x in op.src), op.arg)
def get_movementroot(root:LazyBuffer) -> LazyBuffer: return get_movementroot(root.op.src[0]) if root.optype == MovementOps and root.realized is None else root
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot(x) if x.optype == MovementOps and x.st.contiguous else x
LAZY = int(os.getenv("LAZY", "1"))
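# LazyBuffer is the core of lazy evaluation: ops only build a LazyOp graph, and nothing runs on the
# device until realize() is called, which gives the merge/remove rewrites above a chance to fire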
class LazyBuffer:
lazycache : weakref.WeakValueDictionary[LazyOp, LazyBuffer] = weakref.WeakValueDictionary()
def __new__(cls, device, shape, optype, op):
# loadops aren't cached
if optype == LoadOps:
return super().__new__(cls)
wop = (device, optype, get_weakop(op)) # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker
# NOTE: we need "ret" to prevent the new buffer from being immediately deleted
if wop not in LazyBuffer.lazycache:
LazyBuffer.lazycache[wop] = ret = super().__new__(cls) # noqa: F841, pylint: disable=W0612
return LazyBuffer.lazycache[wop]
def __init__(self, device, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
if hasattr(self, 'device'):
return # cache hit, we return and don't reinit
self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
self.shape, self.optype, self.op = self.st.shape, optype, op
self.realized : Optional[DeviceBuffer] = None
self.device, self.dbuffer = device, Device._buffers[device]
self.children : weakref.WeakSet[LazyBuffer] = weakref.WeakSet()
# NOTE: op should be read only after construction of LazyBuffer
for x in get_buffers(op):
x.children.add(self)
if not LAZY:
self.realize()
def __repr__(self): return f"<LB {self.shape} op:{self.op.op if self.realized is None else 'realized'}>"
# this produces a device buffer
def realize(self:LazyBuffer, required_device=None) -> DeviceBuffer:
if required_device is not None:
assert required_device == self.device
if self.realized is None:
# we haven't realized the Buffer yet
self.realized, real_srcs, real_type = _realize[self.optype](self)
# in lazy mode, we don't log until we realize
if real_type is not None:
log_op(real_type, [x.op for x in get_lazyops(self.op)], self.realized, real_srcs)
# no need to keep the op after realization
del self.op
assert self.realized.shape == self.shape
assert isinstance(self.realized, Device._buffers[self.device])
return self.realized
@staticmethod
def fromCPU(x, device): return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()))
def toCPU(self): return self.realize().toCPU()
def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
def contiguous_op(self:LazyBuffer) -> LazyBuffer: return LazyBuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)))
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape):
return self
reduce = list(enumerate(zip(self.shape, new_shape)))
# move the reduce axes to the end
x = self.movement_op(MovementOps.PERMUTE, [i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
new_tmp_shape = tuple([n for _,(s,n) in reduce if s == n] + [n for _,(s,n) in reduce if s != n])
# NOTE: this reshape can only move around 1s
return LazyBuffer(x.device, new_tmp_shape, ReduceOps, LazyOp(op, (x,), new_tmp_shape)).movement_op(MovementOps.RESHAPE, new_shape)
# syntactic sugar around PAD and SHRINK
# TODO: turn RESHAPE into EXPAND and CONTRACT (current EXPAND should be REPEAT)
def slice(self:LazyBuffer, arg):
padding = [(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg)]
return self.movement_op(MovementOps.PAD, padding).movement_op(MovementOps.SHRINK, tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)))
def movement_op(self:LazyBuffer, op:MovementOps, arg) -> LazyBuffer:
# TODO: look into why that copy is needed
arg = tuple(copy(arg))
local_st = ShapeTracker(self.shape).movement_op(op, arg)
# instant nops
if local_st.contiguous and self.shape == local_st.shape and op != MovementOps.STRIDED:
return self
# two ops in a row is one op. merge them if unresolved
if self.realized is None and self.op.op == op:
if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK]:
return self.op.src[0].movement_op(op, arg)
if op == MovementOps.PERMUTE:
return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
if op == MovementOps.PAD:
return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
# TODO: MovementOps.FLIP / MovementOps.STRIDED?
# some permutes are actually just reshapes
if op == MovementOps.PERMUTE and local_st.contiguous:
return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))
# some strideds are actually just reshapes
# NOTE: due to how strided works, we have to check the parent to be contiguous also
if op == MovementOps.STRIDED and local_st.contiguous and self.st.contiguous:
return self.movement_op(MovementOps.RESHAPE, tuple(i for i,_ in arg))
# if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]:
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
if isinstance(y, LazyBuffer):
return y.movement_op(op, arg)
assert y.op in BinaryOps or y.op in UnaryOps
return elementwise_op(y.op, *[replace_with_movement_op(z) for z in y.src]) # type: ignore
return replace_with_movement_op(self.op)
# create the buffer
ret = LazyBuffer(self.device, ShapeTracker(self.st).movement_op(op, arg), MovementOps, LazyOp(op, (self,), arg))
# if the ShapeTracker becomes contiguous, replace the whole thing with a reshape (or nothing if shapes match)
# NOTE: if ret is in the cache, it can already be realized
if REMOVE_MOVEMENT_NOPS and ret.realized is None and self.realized is None and ret.st.contiguous:
# MovementOps aren't stacked any more, they each have one parent, find the root
root = get_movementroot(self)
if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape):
return root.movement_op(MovementOps.RESHAPE, ret.st.shape) if ret.st.shape != root.shape else root
return ret
def processing_op(self:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
x = self
# TODO: fixup C?
if NOCONV or not getattr(x.dbuffer, "SUPPORTS_PADDING", False):
x = x.slice(((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
if NOCONV or not getattr(x.dbuffer, "processing_op", False):
# universal conv, just mul and reduce
# TODO: is there any way to replace strided with other movement ops?
x = x.movement_op(MovementOps.STRIDED, (
(C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
(1, 1), (C.oy, C.sy*x.shape[3]), (C.ox, C.sx),
(C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
#if C.H <= 3 and C.W <= 3: # max 9x the RAM overhead, this is im2col
# x = x.contiguous_op()
x = x.movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
w = w.movement_op(MovementOps.RESHAPE, (1, C.groups, C.rcout, 1, 1, C.cin, C.H, C.W)) \
.movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
return x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.groups, C.rcout, C.oy, C.ox, 1, 1, 1)) \
.movement_op(MovementOps.RESHAPE, (C.bs, C.cout, C.oy, C.ox))
elif x.device == "OPENCL":
# TODO: these can be properties on the device buffer
from accel.opencl.preprocessing import preprocessing_op, postprocessing_op # type: ignore
x,w,Cn = preprocessing_op(x, w, C)
ret = LazyBuffer(x.device, Cn.out_shape, ProcessingOps, LazyOp(op, (x, w), Cn))
return postprocessing_op(ret, Cn, C)
else:
return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffer:
out_device, out_shape = srcs[0].device, srcs[0].shape
if MERGE_ELEMENTWISE_OPS or (MERGE_UNARY_OPS and len(set(srcs)) == 1):
# remove the buffers from any (childless) BinaryOps that feed into this
srcs = tuple(x.op if x.optype == BinaryOps and len(x.children) == 0 and x.realized is None else x for x in srcs) # type: ignore
return LazyBuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs))
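# Tensor is the user-facing autograd type: it wraps a LazyBuffer, tracks grad/requires_grad, and
# records the graph through Function contexts (_ctx) for backward().
# A minimal usage sketch (assumes the mlops such as _sum and relu are registered onto Tensor later in this file):
#   x = Tensor.randn(2, 3, requires_grad=True)
#   y = x.relu().sum()   # shape (1,)
#   y.backward()         # populates x.grad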
class Tensor:
training, no_grad = False, False
def __init__(self, data, device=Device.DEFAULT, requires_grad=None):
if isinstance(data, list):
data = np.array(data, dtype=np.float32)
elif isinstance(data, LazyBuffer) and data.device != device:
# TODO: this has to realize, it shouldn't have to
data = data.realize().toCPU()
if isinstance(data, np.ndarray):
if data.shape == tuple():
data = data.reshape((1,))
self.lazydata = LazyBuffer.fromCPU(data.astype(np.float32), device)
elif isinstance(data, LazyBuffer):
self.lazydata = data
else:
raise Exception(f"can't create Tensor from {data}")
# tensors have gradients, buffers do not
self.grad : Optional[Tensor] = None
# NOTE: this can be in three states. False and None: no gradient, True: gradient
# None (the default) will be updated to True if it's put in an optimizer
self.requires_grad : Optional[bool] = requires_grad
# internal variables used for autograd graph construction
self._ctx : Optional[Function] = None
def __repr__(self):
return f"<Tensor {self.lazydata if self.lazydata.realized is None else self.lazydata.realized!r} with grad {(self.grad.lazydata if self.grad else None)!r}>"
@property
def shape(self): return self.lazydata.shape
# dtype handling was very broken. it's always float32 now
@property
def dtype(self): return np.float32
@property
def device(self): return self.lazydata.device
# ***** data handlers ****
def realize(self):
self.lazydata.realize()
return self
def assign(self, x):
if not isinstance(x, Tensor):
x = Tensor(x)
assert self.shape == x.shape
self.lazydata = x.lazydata
return x
def detach(self): return Tensor(self.lazydata, device=self.device, requires_grad=False)
def numpy(self): return np.array(self.lazydata.toCPU())
# TODO: this keeps the legacy behavior working, remove it after refactor
@property
def data(self): return self.numpy()
# TODO: if things are realized this won't work
def to_(self, device:str):
assert self.lazydata.realized is None
self.lazydata.device = device
if self.grad:
self.grad.lazydata.device = device
def to(self, device:str):
ret = Tensor(self.lazydata, device)
if self.grad:
ret.grad = self.grad.to(device)
return ret
# ***** creation helper functions *****
# TODO: remove use of numpy here
@classmethod
def zeros(cls, *shape, **kwargs): return cls(np.zeros(shape, dtype=np.float32), **kwargs)
@classmethod
def ones(cls, *shape, **kwargs): return cls(np.ones(shape, dtype=np.float32), **kwargs)
@classmethod
def empty(cls, *shape, **kwargs): return cls(np.empty(shape, dtype=np.float32), **kwargs)
@classmethod
def randn(cls, *shape, **kwargs): return cls(np.random.default_rng().standard_normal(size=shape, dtype=np.float32), **kwargs)
@classmethod
def arange(cls, stop, start=0, **kwargs): return cls(np.arange(start=start, stop=stop, dtype=np.float32), **kwargs)
  # TODO: uniform should be a late binding thing
  # returns values uniformly distributed in [-1, 1)
  # NOTE: this behavior changed from depending on the shape to not depending on it
@classmethod
def uniform(cls, *shape, **kwargs): return cls((np.random.default_rng().random(size=shape, dtype=np.float32) * 2 - 1), **kwargs)
@classmethod
def scaled_uniform(cls, *shape, **kwargs): return cls((np.random.default_rng().random(size=shape, dtype=np.float32) * 2 - 1) * (prod(shape)**-0.5), **kwargs)
@classmethod
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
def glorot_uniform(cls, *shape, **kwargs): return cls((np.random.default_rng().random(size=shape, dtype=np.float32) * 2 - 1) * ((6/(shape[0]+prod(shape[1:])))**0.5), **kwargs)
@classmethod
def eye(cls, dim, **kwargs): return cls(np.eye(dim, dtype=np.float32), **kwargs)
# ***** toposort and backward pass *****
def deepwalk(self):
def _deepwalk(node, visited, nodes):
visited.add(node)
if node._ctx:
[_deepwalk(i, visited, nodes) for i in node._ctx.parents if i not in visited]
nodes.append(node)
return nodes
return _deepwalk(self, set(), [])
def backward(self):
assert self.shape == (1,)
# fill in the first grad with one
# this is "implicit gradient creation"
self.grad = Tensor.ones(*self.shape, device=self.device, requires_grad=False)
for t0 in reversed(self.deepwalk()):
if not any(x.requires_grad for x in t0._ctx.parents):
continue
assert (t0.grad is not None)
grads = t0._ctx.backward(t0.grad.lazydata)
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
for t, g in zip(t0._ctx.parents, grads):
if g is not None and t.requires_grad:
assert g.shape == t.shape, f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
t.grad = g if t.grad is None else (t.grad + g)
del t0._ctx
# ***** non first class ops (hlops) *****
def __getitem__(self, val):
arg, new_shape = [], []
for i, rs in enumerate(val if isinstance(val, (list, tuple)) else [val]) if val is not None else []:
s = slice(rs, rs+1, None) if isinstance(rs, int) else rs
arg.append((s.start if s.start is not None else 0, (s.stop if s.stop>=0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
assert s.step is None or s.step == 1
if not isinstance(rs, int): # don't include in shape if it's an int
new_shape.append(arg[-1][1] - arg[-1][0])
new_shape += [self.shape[i] for i in range(len(arg), len(self.shape))]
return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))]).reshape(new_shape if len(new_shape) else (1,))
def cat(self, *args, dim=0):
dim = (dim + len(self.shape)) if dim < 0 else dim
for y in args:
assert len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim)
args = [self] + list(args)
shape_cumsum = [0, *itertools.accumulate(y.shape[dim] for y in args)]
slc = [[(0, s) for s in self.shape] for _ in args]
for s,k in zip(slc, shape_cumsum):
s[dim] = (-k, shape_cumsum[-1]-k)
return functools.reduce(Tensor.__iadd__, [arg.slice(arg=s) for arg,s in zip(args, slc)])
# TODO: make this nicer with syntactic sugar in slice
def chunk(self, num, dim):
slice_params = [[(0, s) for s in self.shape] for _ in range(num)]
for i,k in enumerate(range(0, self.shape[dim], self.shape[dim]//num)):
slice_params[i][dim] = (k, min(self.shape[dim], k+self.shape[dim]//num))
return [self.slice(arg=p) for p in slice_params]
def matmul(self:Tensor, w:Tensor):
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
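    # e.g. a (3,4) @ (4,5) matmul: cx becomes (1,4,3,1), cw becomes (5,4,1,1), the conv output is
    # (1,5,3,1), which is reshaped to (5,3) and transposed back to (3,5)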
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
out_shape_t = tuple(list(self.shape[0:-2])+[cout,-1])
if len(self.shape) > 1:
order = tuple(list(range(len(self.shape)-2))+[len(self.shape)-1, len(self.shape)-2])
else:
order, out_shape_t = (0,), (cout, )
worder = tuple(list(range(len(w.shape)-2))+[len(w.shape)-1, len(w.shape)-2])
# NOTE: with NHWC we can remove the transposes
# bs x groups*cin x H x W
cx = self.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
# groups*cout x cin x H, W
cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)
# TODO: what's the difference between dot and matmul?
dot = matmul
# (padding_left, padding_right, padding_top, padding_bottom)
def pad2d(self, padding:Tuple[int, ...]): return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
# TODO: this is totally not transpose
def transpose(self, order=(1,0)): return self.permute(order=order)
def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
def _reduce(self, fxn, axis=None, keepdim=False):
if axis is None:
axis = range(len(self.shape))
if isinstance(axis, int):
axis = [axis]
axis = tuple([x if x >= 0 else x+len(self.shape) for x in axis])
shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis]
ret = fxn(self, axis=axis)
return ret if keepdim else ret.reshape(shape=[1] if shape == [] else shape)
def sum(self, axis=None, keepdim=False): return self._reduce(Tensor._sum, axis, keepdim)
def max(self, axis=None, keepdim=False): return self._reduce(Tensor._max, axis, keepdim)
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
def mean(self, axis=None, keepdim=False):
out = self.sum(axis=axis, keepdim=keepdim)
return out * (prod(out.shape)/prod(self.shape))
def _softmax(self):
m = self - self.max(axis=len(self.shape)-1, keepdim=True)
e = m.exp()
return m, e, e.sum(axis=len(self.shape)-1, keepdim=True)
def softmax(self):
_, e, ss = self._softmax()
return e.div(ss)
def logsoftmax(self):
m, _, ss = self._softmax()
return m - ss.log()
def dropout(self, p=0.5):
if not Tensor.training:
return self
_mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
return self * Tensor(_mask, requires_grad=False, device=self.device) * (1/(1.0 - p))
# TODO: support arbitrary strides
def _pool2d(self, py, px):
xup = self[:, :, :self.shape[2]-self.shape[2]%py, :self.shape[3]-self.shape[3]%px] if (self.shape[2]%py != 0) or (self.shape[3]%px != 0) else self
return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//py, py, xup.shape[3]//px, px))
def avg_pool2d(self, kernel_size=(2,2)): return self._pool2d(*kernel_size).mean(axis=(3,5))
def max_pool2d(self, kernel_size=(2,2)): return self._pool2d(*kernel_size).max(axis=(3,5))
def conv2d(self, weight, bias=None, **kwargs):
ret = self._conv2d(weight, **kwargs)
return ret if bias is None else ret.add(bias.reshape(shape=[1, -1, 1, 1]))
# ***** math functions (unary) *****
def __neg__(self): return 0.0-self
def sqrt(self): return self.pow(0.5)
def square(self): return self*self
def clip(self, min_, max_): return ((self-min_).relu()+min_) - (self-max_).relu()
def abs(self): return self.relu() + (-self).relu()
def sign(self): return self / (self.abs() + 1e-10)
# ***** activation functions (unary) *****
def sigmoid(self): return (1.0 + (-self).exp()).reciprocal()
def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
def swish(self): return self * self.sigmoid()
silu = swish # The SiLU function is also known as the swish function.
def relu6(self): return self.relu() - (self-6).relu()
def hardswish(self): return self * (self+3).relu6() * (1/6)
def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
def quick_gelu(self): return self * (self * 1.702).sigmoid()
def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
def mish(self): return self * self.softplus().tanh()
def softplus(self, limit=20, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
# ***** broadcasted binary ops *****
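  # broadcasted pads the smaller operand's shape with leading 1s and expands both operands to the
  # elementwise max shape, e.g. (2,3).add(1.0): the scalar becomes a (1,) Tensor, is reshaped to (1,1),
  # and both sides are expanded to (2,3) before the underlying op runs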
@staticmethod
def broadcasted(fxn, x, y):
tt = [arg for arg in [x,y] if isinstance(arg, Tensor)][0] # this is the prototype tensor
x,y = [Tensor([t], device=tt.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in [x,y]]
x,y = [t.reshape([1]*(max(len(x.shape), len(y.shape))-len(t.shape)) + list(t.shape)) for t in [x,y]]
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
return fxn(x.expand(shape_ret), y.expand(shape_ret))
# TODO: are these the only ones that can take number arguments?
def add(self, x): return Tensor.broadcasted(Tensor._add, self, x)
def sub(self, x): return Tensor.broadcasted(Tensor._sub, self, x)
def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)
def div(self, y): return self * (y.reciprocal() if isinstance(y, Tensor) else (1/y))
# ***** functional nn ops *****
# TODO: fix the kwargs problem, then remove these (or not, since they now fix tuples)
def reshape(self, shape, *args): return self._reshape(shape=argfix(shape, *args))
def expand(self, shape, *args): return self._expand(shape=argfix(shape, *args))
def permute(self, order, *args): return self._permute(order=argfix(order, *args))
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
return x.add(bias) if bias is not None else x
def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)
def layernorm(self, axis=-1, eps=1e-5):
y = (self - self.mean(axis=axis, keepdim=True))
return y.div((y*y).mean(axis=axis, keepdim=True).add(eps).sqrt())
# An instantiation of the Function is the Context
class Function:
def __init__(self, device:str, *tensors:Tensor):
self.device, self.parents = device, tensors
self.needs_input_grad = [t.requires_grad for t in self.parents]
self.requires_grad = True if any(self.needs_input_grad) else (None if any(x is None for x in self.needs_input_grad) else False)
self.saved_tensors : List[Tensor] = []
def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
def backward(self, *args, **kwargs): raise NotImplementedError(f"backward not implemented for {type(self)}")
# NOTE: it doesn't hurt to save this since the ctx will be freed fast without grad
def save_for_backward(self, *x): self.saved_tensors.extend(x)
@classmethod
def apply(cls, *x:Tensor, **kwargs):
ctx = cls(x[0].device, *x)
ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
if ctx.requires_grad and not Tensor.no_grad:
ret._ctx = ctx # used by autograd engine
return ret
class ReLU(Function):
def forward(self, x):
ret = x.unary_op(UnaryOps.RELU)
self.save_for_backward(ret)
return ret
def backward(self, grad_output):
return self.saved_tensors[0].unary_op(UnaryOps.SIGN).binary_op(BinaryOps.MUL, grad_output)
class Log(Function):
def forward(self, x):
self.save_for_backward(x)
return x.unary_op(UnaryOps.LOG)