-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsynthetic2.m
executable file
·138 lines (96 loc) · 3.92 KB
/
synthetic2.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
%fixed number of source samples=50, Linear model
% Synthetic transfer-learning experiment.
%   Source data: y = M_S  * x + noise, x ~ N(0, sigma_source)
%   Target data: y = M_T* * x + noise, x ~ N(0, sigma_target)
% Two targets are compared: M_T2 is a tiny perturbation of M_S ("small
% Delta") and M_T5 is a large perturbation of M_S ("large Delta").
rng('default')                 %reset the RNG to default settings so runs are reproducible
np=50;                         %np stands for number of source samples
nq_train=500;                  %nq_train stands for number of target training samples (largest n_q swept)
%nq_crossval=100;
nq_test=200;                   %number of test target points
mu=.0001;                      %base step size of gradient descent
k=30;                          %output dimension (rows of the parameter matrices)
lambda=1;                      %NOTE(review): not referenced anywhere in the visible script
number_epo=1000;               %number of iterations of gradient descent per fit
d=200;                         %input dimension (columns of the parameter matrices)
sigma_noise=1;                 %noise level: noise covariance is sigma_noise*eye(k), so this is a variance
sigma_source=2*eye(d);         %covariance matrix of source inputs
sigma_target=eye(d);           %covariance matrix of target inputs
M_S = normrnd(0,10,[k,d]);             %M_S is the source parameter (entries have std 10)
M_T2=M_S+normrnd(0,.001,[k,d]);        %M_T2 is the target_1 parameter (small shift from M_S)
M_T5=M_S+60000*normrnd(0,.0001,[k,d]); %M_T5 is the target_2 parameter (shift entries have std 60000*1e-4 = 6)
%Spectral-norm distance between the source and each target parameter.
n2=norm(M_S-M_T2,2);
n5=norm(M_S-M_T5,2);
trial=10;                      %number of independent trials to average over
%Accumulators for the test error at n_q = 0..nq_train (column n_q+1 holds n_q).
%Sized from nq_train so changing the sweep length cannot desynchronize them.
error_target2_trial=zeros(1,nq_train+1);
error_target5_trial=zeros(1,nq_train+1);
% Run `trial` independent trials. Each trial draws fresh source/target data;
% then, for every n_q = 0..nq_train, both target models are fit by gradient
% descent on the first n_q target training samples plus all np source samples
% (starting from the same random init), and scored on held-out target data.
% The per-n_q test errors are averaged over the trials.
num_trials=trial;              %trial count from setup; the loop index below no longer overwrites it
n_q_values=0:nq_train;         %swept numbers of target training samples
Ze=zeros(1,d);                 %zero mean for input draws
Ze1=zeros(1,k);                %zero mean for noise draws
noise_cov=sigma_noise*eye(k);  %noise covariance passed to mvnrnd
%Mean over test rows of the squared Euclidean prediction error ||y_i - M*x_i||^2.
mean_sq_err=@(Y,X,M) mean(sum((Y-X*M.').^2,2));
for trial_idx=1:num_trials
    %Generating target and source data. The draw order matches the original
    %script so the random stream is consumed identically.
    X_source=mvnrnd(Ze,sigma_source,np);
    y_source=X_source*M_S.'+mvnrnd(Ze1,noise_cov,np);
    X_train_target2=mvnrnd(Ze,sigma_target,nq_train);
    y_train_target2=X_train_target2*M_T2.'+mvnrnd(Ze1,noise_cov,nq_train);
    X_train_target5=mvnrnd(Ze,sigma_target,nq_train);
    y_train_target5=X_train_target5*M_T5.'+mvnrnd(Ze1,noise_cov,nq_train);
    X_test_target2=mvnrnd(Ze,sigma_target,nq_test);
    y_test_target2=X_test_target2*M_T2.'+mvnrnd(Ze1,noise_cov,nq_test);
    X_test_target5=mvnrnd(Ze,sigma_target,nq_test);
    y_test_target5=X_test_target5*M_T5.'+mvnrnd(Ze1,noise_cov,nq_test);
    M_T_init=normrnd(0,1,[k,d]);

    %The gradient of (1/n)*||Y - X*M.'||_F^2 w.r.t. M is (2/n)*(M*(X.'*X) - Y.'*X).
    %X.'*X and Y.'*X do not change across epochs, so they are computed once
    %here (source) or once per n_q (target) instead of in every epoch.
    G_source=X_source.'*X_source;   %d x d
    C_source=y_source.'*X_source;   %k x d

    test_loss_target2=zeros(1,nq_train+1);
    test_loss_target5=zeros(1,nq_train+1);
    for number_q=n_q_values %sweeping number of target samples
        if number_q>0
            Xq2=X_train_target2(1:number_q,:);
            Xq5=X_train_target5(1:number_q,:);
            %Gradient = M*H - B with H, B combining target and source terms.
            H2=(2/number_q)*(Xq2.'*Xq2)+(2/np)*G_source;
            B2=(2/number_q)*(y_train_target2(1:number_q,:).'*Xq2)+(2/np)*C_source;
            %NOTE(review): the large-Delta fit weights the source term by
            %.001/np (not 2/np) when n_q > 0 -- confirm this is intentional.
            H5=(2/number_q)*(Xq5.'*Xq5)+(.001/np)*G_source;
            B5=(2/number_q)*(y_train_target5(1:number_q,:).'*Xq5)+(.001/np)*C_source;
        else
            %No target samples: both fits use only the source term.
            H2=(2/np)*G_source;
            B2=(2/np)*C_source;
            H5=H2;
            B5=B2;
        end
        %Step size grows linearly with n_q; the large-Delta fit uses twice
        %the step of the small-Delta fit (NOTE(review): confirm intentional).
        step2=(number_q+1)*mu;
        step5=2*(number_q+1)*mu;
        M_T2_est=M_T_init;
        M_T5_est=M_T_init;
        for epoch=1:number_epo %loop for gradient descent
            M_T2_est=M_T2_est-step2*(M_T2_est*H2-B2);
            M_T5_est=M_T5_est-step5*(M_T5_est*H5-B5);
        end
        %test error (no console echo: the original printed every value):
        test_loss_target2(number_q+1)=mean_sq_err(y_test_target2,X_test_target2,M_T2_est);
        test_loss_target5(number_q+1)=mean_sq_err(y_test_target5,X_test_target5,M_T5_est);
    end
    error_target2_trial=error_target2_trial+test_loss_target2;
    error_target5_trial=error_target5_trial+test_loss_target5;
end
%Average over the trials actually run (not a hard-coded 10).
error_target2_trial=error_target2_trial/num_trials;
error_target5_trial=error_target5_trial/num_trials;
t=n_q_values;                  %x-axis for plotting: n_q = 0..nq_train
%Plot the trial-averaged test error against the number of target samples n_q.
figure(1)
clf                            %clear figure 1 so re-running the script does not overlay old curves
plot(t,error_target2_trial)
hold on
plot(t,error_target5_trial)
hold off                       %release hold so later plots into this figure replace, not append
legend("Small Delta","Large Delta")
title("fixed np")
xlabel("n_q")
ylabel("generalization error")
%figure(2)
%plot(t,test_loss_target5)
%legend("fixed n_p and Delta and Delta is large")
%xlabel("n_p")
%ylabel("generalization error")