-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsynthetic5.m
executable file
·175 lines (123 loc) · 3.97 KB
/
synthetic5.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
%Synthetic transfer-learning experiment: fixed number of source samples
%(np=300) and target training samples (nq=20), linear model y = M*x + noise.
rng('default')                 %reproducible random draws
np=300;                        %number of source samples
nq_train=20;                   %number of target training samples
nq_test=200;                   %number of target test samples
nq_val=50;                     %number of target validation samples
mu=.038;                       %step size of gradient descent
k=30;                          %output dimension of the linear map
number_epo=6500;               %NOTE(review): unused -- the descent loop below hard-codes 650 iterations
d=50;                          %input dimension
Ze=zeros(1,d);                 %zero mean vector for d-dimensional input Gaussians
Ze1=zeros(1,k);                %zero mean vector for k-dimensional noise
sigma_noise=.3;                %noise level (scales the k-by-k noise covariance)
sigma_source=2*eye(d);         %covariance matrix of source inputs
sigma_target=eye(d);           %covariance matrix of target inputs
M_T5 = normrnd(0,10,[k,d]);    %M_T5 is the target parameter
MM=normrnd(0,.0001,[k,d]);     %source parameter = target parameter + c*MM
ff=zeros([350,2]);             %col 1: transfer distance, col 2: averaged test error
ll=0;                          %row counter into ff
%Generating target data (train / validation / test splits); rows are samples
X_train_target5=mvnrnd(Ze,sigma_target,nq_train);
y_train_target5=(M_T5*X_train_target5.'+mvnrnd(Ze1,sigma_noise*eye(k),nq_train).').';
X_val_target5=mvnrnd(Ze,sigma_target,nq_val);
y_val_target5=(M_T5*X_val_target5.'+mvnrnd(Ze1,sigma_noise*eye(k),nq_val).').';
X_test_target5=mvnrnd(Ze,sigma_target,nq_test);
y_test_target5=(M_T5*X_test_target5.'+mvnrnd(Ze1,sigma_noise*eye(k),nq_test).').';
n_trials=20;                   %Monte-Carlo repetitions per transfer distance
cc_grid=[1 .75 .5 .25 0];      %candidate weights for the source empirical risk
n_gd_iters=650;                %gradient-descent iterations (hard-coded; independent of number_epo)
for kk=1:400:140000 %sweeping transfer distance
ll=ll+1;
M_S=M_T5+kk*MM; % M_S is the source parameter
ff(ll,1)=norm(M_S-M_T5,2); %transfer distance ||M_S - M_T5||_2
for trial=1:n_trials
%generating source data
X_source=mvnrnd(Ze,sigma_source,np);
y_source=(M_S*X_source.'+mvnrnd(Ze1,sigma_noise*eye(k),np).').';
M_T_init= zeros([k,d]);
gg=zeros(1,numel(cc_grid));    %validation loss for each candidate weight
EST=zeros(k,d,numel(cc_grid)); %estimate learned with each candidate weight
for yy=1:numel(cc_grid) %tuning the weight of the source risk using validation data
cc=cc_grid(yy);
M_T5_est=M_T_init;
for epoch=1:n_gd_iters %gradient descent on the weighted target+source empirical risk
g5=(2/nq_train)*(M_T5_est*X_train_target5.'*X_train_target5-y_train_target5.'*X_train_target5)+(2*cc/np)*(M_T5_est*X_source.'*X_source-y_source.'*X_source);
M_T5_est=M_T5_est-mu*g5;
end
%validation error of this candidate weight
for i=1:nq_val
gg(1,yy)=gg(1,yy)+(1/nq_val)*norm(y_val_target5(i,:).'-M_T5_est*X_val_target5(i,:).',2)^2;
end
EST(:,:,yy)=M_T5_est;
end
%pick the weight with the smallest validation error
%(on an exact tie the original if-chain kept the cc=0 estimate; min picks
%the first minimizer instead -- ties have probability zero with this data)
[~,best]=min(gg);
M_T5_est=EST(:,:,best);
%test error, averaged over test points and accumulated over trials
for i=1:nq_test
ff(ll,2)=ff(ll,2)+(1/nq_test)*norm(y_test_target5(i,:).'-M_T5_est*X_test_target5(i,:).',2)^2;
end
end
ff(ll,2)=ff(ll,2)/n_trials; %average accumulated test error over the trials
end
%Plot average test error against transfer distance; rows of ff are sorted
%by distance first so the curve is drawn left-to-right along Delta.
sorted_results=sortrows(ff,1);
figure(6)
plot(sorted_results(:,1),sorted_results(:,2))
hold on
title("np=300 and nq=20")
xlabel("Delta")
ylabel("generalization error")