forked from gorgonia/gorgonia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexecution.go
158 lines (141 loc) · 3.89 KB
/
execution.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
package gorgonia
import (
"github.com/pkg/errors"
"gorgonia.org/tensor"
)
// Arena is a representation of a pool of tensor.Memory that can be borrowed
// from and returned to, per Device.
type Arena interface {
	// Get returns memory of the given size on dev. When it cannot get memory it
	// returns a NoOpError; the caller is then expected to allocate the memory
	// itself (NOTE(review): inferred from the original "Please allocate" note).
	Get(dev Device, size int64) (tensor.Memory, error)

	// GetFromValue gets a memory on dev, copies the values of v into it and
	// returns it.
	GetFromValue(dev Device, v Value) (tensor.Memory, error)

	// Put puts mem, of the given size, back into the arena for dev.
	Put(dev Device, mem tensor.Memory, size int64)

	// PutValue puts the memory of v back into the arena for dev.
	PutValue(dev Device, v Value)

	// Transfer transfers v from fromDev to toDev and returns the transferred
	// value. synchronous presumably selects whether the call blocks until the
	// transfer completes — confirm against implementations.
	Transfer(toDev, fromDev Device, v Value, synchronous bool) (Value, error)
}
// External is a representation of an external device (cuda/cgo/openCL), conceptually modelled as a machine.
// Besides the Arena memory-pool methods it embeds, it can be signalled to do
// work and exposes a channel for synchronisation.
type External interface {
Arena
// Signal signals the machine to do work.
Signal()
// Sync returns a channel for synchronising with the machine.
// NOTE(review): whether it is sent on or closed, and when, is not visible here — confirm against implementations.
Sync() chan struct{}
}
// ExecutionContext informs how an op should be executed: it pairs the
// External machine that device-specific ops (e.g. CUDADoer.CUDADo) are handed
// with the Device the op is meant to run on.
type ExecutionContext struct {
External
Device
}
// ExternalOp is an op that contains an external context. This allows for ops to be run without needing a VM.
//
// The optional fields Incr, Prealloc and UseUnsafe select how Do runs the
// wrapped Op; on the CPU, Incr takes precedence over Prealloc, which takes
// precedence over UseUnsafe.
type ExternalOp struct {
Op
ExecutionContext
Prealloc Value // if non-nil, the value the result is written/copied into
Incr Value // if non-nil, the value the result is added into. IncrDoers have higher precedence over PreallocDo
UseUnsafe bool // whether to run the op unsafely. Lowest of all "special" Dos
}
// NewExternalOp creates a new *ExternalOp that runs op under ctx.
//
// prealloc may be nil; it becomes the op's Prealloc value. Incr is left nil
// and UseUnsafe false, and callers may set them on the returned value.
func NewExternalOp(op Op, ctx ExecutionContext, prealloc Value) *ExternalOp {
	eo := new(ExternalOp) // zero value already has UseUnsafe == false
	eo.Op = op
	eo.ExecutionContext = ctx
	eo.Prealloc = prealloc
	return eo
}
// DetermineDevice sets op.Device to the device holding output's data, after
// checking that every input lives on that same device.
//
// It returns an error and leaves op.Device unchanged if the devices disagree.
// "Cannot automatically determine device." means an input is on neither the
// first input's device nor the output's device. "Not all the same devices"
// means some input is on a device other than the output's. With no inputs,
// the output's device is used as-is. output must be non-nil.
func (op *ExternalOp) DetermineDevice(inputs Nodes, output *Node) error {
	dev := output.dataOn

	var (
		inDev     Device // device of the first input
		haveInDev bool   // whether inDev has been recorded yet
	)
	// The loop can only ever clear allSame, so it must start out true;
	// otherwise every call would be rejected.
	allSame := true
	for _, in := range inputs {
		if in.dataOn != dev {
			allSame = false
		}
		if !haveInDev {
			inDev, haveInDev = in.dataOn, true
			continue
		}
		if in.dataOn != inDev && in.dataOn != dev {
			return errors.Errorf("Cannot automatically determine device.")
		}
	}
	if !allSame {
		return errors.Errorf("Not all the same devices")
	}
	op.Device = dev
	return nil
}
// Do performs the op on vals, honouring the execution options set on op.
//
// When op.Device is CPU, the options are applied in this order of precedence:
//   - Incr != nil: the op's result is added into op.Incr. An IncrDoer does
//     this itself via IncrDo. For any other op, Do is run and its result is
//     added into op.Incr.
//   - Prealloc != nil: a UsePreallocDoer computes directly via UsePreallocDo.
//     For any other op, Do is run and its result is copied into op.Prealloc.
//   - UseUnsafe: an UnsafeDoer runs via UnsafeDo; other ops fall back to Do.
//   - otherwise the op's plain Do is run.
//
// On any other device, the op is dispatched on the interfaces it implements,
// checked in the order CUDADoer, CLDoer, IncrDoer, UsePreallocDoer,
// UnsafeDoer. OpenCL (CLDoer) execution is not implemented and yields an
// error.
//
// If IncrDo fails with an error that also implements Valuer, the value it
// carries is returned as the result with a nil error.
func (op *ExternalOp) Do(vals ...Value) (Value, error) {
	if op.Device == CPU {
		switch {
		case op.Incr != nil:
			if id, ok := op.Op.(IncrDoer); ok {
				return op.runIncrDo(id, vals...)
			}
			// The op cannot increment in place: compute normally, then add the
			// result into op.Incr, just as the CUDA path below does.
			v, err := op.Op.Do(vals...)
			if err != nil {
				return nil, err
			}
			return op.addIntoIncr(v)
		case op.Prealloc != nil:
			if pd, ok := op.Op.(UsePreallocDoer); ok {
				// UsePreallocDo already produces the result. Running Do again
				// would compute it twice and lose this call's error.
				return pd.UsePreallocDo(op.Prealloc, vals...)
			}
			retVal, err := op.Op.Do(vals...)
			if err != nil {
				return retVal, err
			}
			return Copy(op.Prealloc, retVal)
		case op.UseUnsafe:
			if ud, ok := op.Op.(UnsafeDoer); ok {
				return ud.UnsafeDo(vals...)
			}
			return op.Op.Do(vals...)
		default:
			return op.Op.Do(vals...)
		}
	}

	switch o := op.Op.(type) {
	case CUDADoer:
		if op.Incr == nil {
			return o.CUDADo(op.External, op.Device, op.Prealloc, vals...)
		}
		v, err := o.CUDADo(op.External, op.Device, op.Prealloc, vals...)
		if err != nil {
			return nil, err
		}
		return op.addIntoIncr(v)
	case CLDoer:
		// There is no OpenCL execution path here. Report it instead of
		// panicking, which is what the previously empty case led to.
		return nil, errors.Errorf("ExternalOp: OpenCL execution is not implemented for %v", op.Op)
	case IncrDoer:
		if op.Incr != nil {
			return op.runIncrDo(o, vals...)
		}
		return op.Op.Do(vals...)
	case UsePreallocDoer:
		if op.Prealloc != nil {
			return o.UsePreallocDo(op.Prealloc, vals...)
		}
		return op.Op.Do(vals...)
	case UnsafeDoer:
		if op.UseUnsafe {
			return o.UnsafeDo(vals...)
		}
		return op.Op.Do(vals...)
	default:
		return o.Do(vals...)
	}
}

// runIncrDo runs id.IncrDo into op.Incr and returns op.Incr on success.
// If IncrDo fails with an error that also implements Valuer, that error's
// Value is returned as the result with a nil error.
func (op *ExternalOp) runIncrDo(id IncrDoer, vals ...Value) (Value, error) {
	if err := id.IncrDo(op.Incr, vals...); err != nil {
		if ver, ok := err.(Valuer); ok {
			return ver.Value(), nil
		}
		return nil, err
	}
	return op.Incr, nil
}

// addIntoIncr adds v to op.Incr by building the add binary op for their types
// (via newEBOByType) and running it, in op's execution context, with
// UseUnsafe requested.
func (op *ExternalOp) addIntoIncr(v Value) (Value, error) {
	add := newEBOByType(addOpType, TypeOf(op.Incr), TypeOf(v))
	addOp := NewExternalOp(add, op.ExecutionContext, nil)
	addOp.UseUnsafe = true
	return addOp.Do(op.Incr, v)
}
// String reports the wrapped Op's String, so an ExternalOp prints exactly like
// the op it executes.
func (op *ExternalOp) String() string {
	inner := op.Op
	return inner.String()
}