-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathactrf.go
164 lines (143 loc) · 4.91 KB
/
actrf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package actrf
//go:generate core generate -add-types
import (
"slices"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/tensor"
)
// RF is used for computing an activation-based receptive field.
// It simply computes the activation-weighted average of other
// *source* patterns of activation -- i.e., sum(act * src) / sum(src)
// which then shows you the patterns of source activity for which
// a given unit was active.
// You must call Init to initialize everything, Reset to restart the accumulation of the data,
// and Avg to compute the resulting averages based on the accumulated data.
// Avg does not erase the accumulated data so it can continue beyond that point.
type RF struct {
	// name of this RF -- used for managing multiple RFs
	Name string
	// computed receptive field, as SumProd / SumSrc -- only valid after Avg has been called.
	// Shaped [actY, actX, srcY, srcX] by InitShape (act and src each viewed as 2D).
	RF tensor.Float32 `display:"no-inline"`
	// unit normalized version of RF, computed by Norm -- good for display.
	// Same shape as RF.
	// NOTE(review): intended to be normalized per source (inner 2D dimensions);
	// confirm that stats.UnitNormOut normalizes along those dimensions.
	NormRF tensor.Float32 `display:"no-inline"`
	// normalized version of SumSrc (divided by its largest value, computed in Avg) --
	// good for viewing the completeness and uniformity of the sampling of the source space.
	// Shaped [srcY, srcX].
	NormSrc tensor.Float32 `display:"no-inline"`
	// sum of the products of act * src, accumulated by Add. Same shape as RF.
	SumProd tensor.Float32 `display:"no-inline"`
	// sum of the source values (denominator), accumulated by Add. Shaped [srcY, srcX].
	SumSrc tensor.Float32 `display:"no-inline"`
	// temporary destination sum for MPI -- only used when MPISum called.
	// Not shaped by InitShape.
	MPITmp tensor.Float32 `display:"no-inline"`
}
// Init sets the name of this RF and prepares it for accumulation.
// It makes sure the RF tensors are shaped for the given activation (act)
// and source (src) tensors (via InitShape) and then zeroes the
// accumulators (via Reset).
func (af *RF) Init(name string, act, src tensor.Tensor) {
	af.Name = name
	// The shape is recorded on the tensors themselves; the returned sizes are not needed here.
	_ = af.InitShape(act, src)
	af.Reset()
}
// InitShape sizes the RF tensors to match the given activation (act) and
// source (src) tensors, each viewed in 2D via tensor.Projection2DShape.
// RF, NormRF and SumProd become [actY, actX, srcY, srcX], and NormSrc and
// SumSrc become [srcY, srcX]. If RF already has the 4D target shape, nothing
// is changed and the other tensors are assumed to be consistent with it.
// The sizes are not reset by this call; use Reset for that.
// Returns the 4D shape sizes [actY, actX, srcY, srcX].
func (af *RF) InitShape(act, src tensor.Tensor) []int {
	actY, actX, _, _ := tensor.Projection2DShape(act.Shape(), false)
	srcY, srcX, _, _ := tensor.Projection2DShape(src.Shape(), false)
	full := []int{actY, actX, srcY, srcX}
	if !slices.Equal(af.RF.Shape().Sizes, full) {
		srcShape := []int{srcY, srcX}
		// Tensors indexed by (activation, source) position.
		for _, t := range []*tensor.Float32{&af.RF, &af.NormRF, &af.SumProd} {
			t.SetShapeSizes(full...)
		}
		// Tensors indexed by source position only.
		for _, t := range []*tensor.Float32{&af.NormSrc, &af.SumSrc} {
			t.SetShapeSizes(srcShape...)
		}
		for _, t := range []*tensor.Float32{&af.RF, &af.NormRF, &af.SumProd, &af.NormSrc, &af.SumSrc} {
			af.ConfigView(t)
		}
	}
	return full
}
// ConfigView configures the view (display) params on the given tensor.
// It is currently a no-op: the metadata settings below are commented out
// (presumably pending metadata support, per the todo), so tsr is left unchanged.
func (af *RF) ConfigView(tsr *tensor.Float32) {
	// todo:meta
	// tsr.SetMetaData("colormap", "Viridis")
	// tsr.SetMetaData("grid-fill", "1") // remove extra lines
	// tsr.SetMetaData("fix-min", "true")
	// tsr.SetMetaData("min", "0")
}
// Reset zeroes the SumProd and SumSrc accumulators so that a fresh
// accumulation can begin. RF, NormRF and NormSrc are left untouched.
// Init (or InitShape) must have been called first so the accumulators are sized.
func (af *RF) Reset() {
	for _, acc := range []*tensor.Float32{&af.SumProd, &af.SumSrc} {
		acc.SetZeros()
	}
}
// Add adds one sample based on activation and source tensor values.
// Both tensors are read as 2D via tensor.Projection2DValue, and they should
// have the same shapes as used when Init was called. InitShape is called to
// ensure the accumulators are sized; it does not reset them.
// NOTE(review): if the shapes differ, whether earlier sums survive depends on
// tensor.SetShapeSizes -- confirm before relying on mixed shapes.
// thr is a threshold value on sources below which values are not added (prevents
// numerical issues with very small numbers): such source positions update
// neither SumSrc nor SumProd.
func (af *RF) Add(act, src tensor.Tensor, thr float32) {
	shp := af.InitShape(act, src) // ensure
	aNy, aNx, sNy, sNx := shp[0], shp[1], shp[2], shp[3]
	// Activation values do not depend on the source position, so read them
	// once per sample rather than once per (source, activation) pair.
	acts := make([]float32, aNy*aNx)
	for ay := 0; ay < aNy; ay++ {
		for ax := 0; ax < aNx; ax++ {
			acts[ay*aNx+ax] = float32(tensor.Projection2DValue(act, false, ay, ax))
		}
	}
	for sy := 0; sy < sNy; sy++ {
		for sx := 0; sx < sNx; sx++ {
			tv := float32(tensor.Projection2DValue(src, false, sy, sx))
			if tv < thr {
				continue
			}
			af.SumSrc.SetAdd(tv, sy, sx)
			for ay := 0; ay < aNy; ay++ {
				for ax := 0; ax < aNx; ax++ {
					af.SumProd.SetAdd(acts[ay*aNx+ax]*tv, ay, ax, sy, sx)
				}
			}
		}
	}
}
// Avg computes RF as SumProd / SumSrc. Does not Reset sums.
// Source positions whose SumSrc is exactly 0 have no data, so their RF
// entries are set to 0. This keeps values from an earlier Avg call from
// surviving a Reset.
// NormSrc is set to SumSrc divided by the largest positive SumSrc value
// (or by 1 if no value is positive).
func (af *RF) Avg() {
	aNy := af.SumProd.DimSize(0)
	aNx := af.SumProd.DimSize(1)
	sNy := af.SumProd.DimSize(2)
	sNx := af.SumProd.DimSize(3)
	var maxSrc float32
	for sy := 0; sy < sNy; sy++ {
		for sx := 0; sx < sNx; sx++ {
			src := af.SumSrc.Value(sy, sx)
			if src > maxSrc {
				maxSrc = src
			}
			for ay := 0; ay < aNy; ay++ {
				for ax := 0; ax < aNx; ax++ {
					oo := af.SumProd.Shape().IndexTo1D(ay, ax, sy, sx)
					if src == 0 {
						af.RF.Values[oo] = 0
						continue
					}
					af.RF.Values[oo] = af.SumProd.Values[oo] / src
				}
			}
		}
	}
	if maxSrc == 0 {
		maxSrc = 1
	}
	for i, v := range af.SumSrc.Values {
		af.NormSrc.Values[i] = v / maxSrc
	}
}
// Norm writes a unit-normalized copy of RF into NormRF using
// stats.UnitNormOut (see that function for the exact normalization).
// Avg must have been called first so that RF is up to date.
func (af *RF) Norm() {
	in, out := &af.RF, &af.NormRF
	stats.UnitNormOut(in, out)
}
// AvgNorm computes RF as SumProd / SumSrc (via Avg) and then
// the normalized NormRF (via Norm).
// This is what you typically want to call before viewing RFs.
// Does not Reset sums.
func (af *RF) AvgNorm() {
	af.Avg()
	af.Norm()
}