function [TrainingResults, ts]=calculate_gradient(ThNN,TrainingParams,tiCurrent,tdCurrent,Verbose)
%CALCULATE_GRADIENT Error gradient of a network w.r.t. weights and delays.
%   [TrainingResults, ts] = CALCULATE_GRADIENT(ThNN, TrainingParams, ...
%       tiCurrent, tdCurrent, Verbose)
%   runs the network on the (jittered) input spikes tiCurrent, compares the
%   output neurons' firing times with the desired spikes tdCurrent, and
%   back-propagates the timing error from the output neurons towards the
%   input neurons to obtain dE/dW and dE/dTau for every connection, where
%   E = 0.5 * sum of squared output timing errors.
%
%   Inputs
%     ThNN           - network struct. Fields read here: Weights, Delays,
%                      Neurons (Neurons(j).Io, Neurons(j).Beta),
%                      RelativeInputNeurons, RelativeOutputNeurons.
%     TrainingParams - struct with fields NumericalGradient, GradientStep,
%                      NoiseLevel and DelayEnable. When NumericalGradient is
%                      true, numerical_gradient is used instead of the
%                      analytic expressions.
%     tiCurrent      - input spikes, one per row; column 2 holds the spike
%                      time (column 1 presumably the neuron index).
%     tdCurrent      - desired output spikes, one per row:
%                      [neuron index, desired time].
%     Verbose        - optional (default 0); 1 prints diagnostics, including
%                      a comparison of analytic and numerical gradients.
%
%   Outputs
%     TrainingResults - struct with fields
%                         DEDW, DEDTau  gradient matrices sized like
%                                       ThNN.Weights / ThNN.Delays,
%                         SSE           0.5 * sum of squared timing errors,
%                         NonFireFlag   1 if any neuron did not fire,
%                         NonFireCount  one entry per neuron, 1 if silent.
%                       DEDW, DEDTau and SSE are -1 when the gradient could
%                       not be computed.
%     ts              - cell array of firing times per neuron from
%                       run_network, with the output neurons' spike trains
%                       trimmed or extended to the number of desired spikes;
%                       -1 if fewer than 4 inputs were given.
%
%   See also RUN_NETWORK, NUMERICAL_GRADIENT, RELATE_Q_TO_K.

% Failure values, returned unchanged by every early exit below.
TrainingResults.DEDW=-1; TrainingResults.DEDTau=-1; TrainingResults.SSE=-1;
ts=-1;

if nargin<4
    disp('Error in calculate_gradient: Not enough input arguements');
    disp(['Needed at least 4 inputs but only got ' num2str(nargin)]);
    return;
end
if nargin<5
    Verbose=0;
end

if (Verbose==1)
    disp('*** Entering calculate_gradient ***');
    disp(' ');
end

NoiseLevel=TrainingParams.NoiseLevel;
DEDW=zeros(size(ThNN.Weights));
DEDTau=zeros(size(ThNN.Delays));

% Jitter the input spike times.
% NOTE(review): NoiseLevel*(1+randn) has mean NoiseLevel rather than zero,
% i.e. on average it also delays every input spike by NoiseLevel -- confirm
% that this bias is intended.
tiCurrent(:,2)=tiCurrent(:,2)+NoiseLevel*(1+randn(size(tiCurrent,1),1));

% Sort the desired spikes by time, keeping each neuron index with its time.
[SortedTimes,SortOrder]=sort(tdCurrent(:,2));
tdCurrent=[tdCurrent(SortOrder,1) SortedTimes];

RelativeInputNeurons=ThNN.RelativeInputNeurons;
RelativeOutputNeurons=ThNN.RelativeOutputNeurons;

% Forward pass. GradFlag=1 presumably asks run_network to also produce the
% Y/Z terms used below. Judging from their use in this function,
% Y{j}(i,l) acts as d(spike l of j)/dW(i,j) and Z{j}(q,l) as
% d(spike l of j)/d(arrival of input spike q) -- TODO confirm in run_network.
% K2QAll{i,j}(k) is the row of Z{j} that belongs to spike k of neuron i.
GradFlag=1;
MaxSimTime=1000;
[ts, Y, Z] = run_network(ThNN,tiCurrent,MaxSimTime,GradFlag,Verbose);
K2QAll=relate_q_to_k(ThNN,ts);

NeuronQueue=get_output_neurons(ThNN);

% Record which neurons produced no spikes at all.
NumNeurons=max(size(ThNN.Neurons));
TrainingResults.NonFireFlag=0;
TrainingResults.NonFireCount=zeros(1,NumNeurons);
for j=1:NumNeurons
    if isempty(ts{j})
        TrainingResults.NonFireFlag=1;
        TrainingResults.NonFireCount(j)=1;
    end
end

% A silent output neuron leaves no timing error to back-propagate.
if any(TrainingResults.NonFireCount(NeuronQueue)>0)
    TrainingParams
    TrainingResults
    disp(['Error! Neurons (including output neurons) with the following indices are not firing: ', num2str(find(TrainingResults.NonFireCount>0))]);
    TrainingResults.DEDW=-1;
    TrainingResults.DEDTau=-1;
    return;
end

% Timing error (actual - desired) of every output spike. Surplus actual
% spikes are dropped; missing ones are only tolerated when Io>0 and are then
% extrapolated by repeatedly adding pi/Beta to the last spike time.
del=cell(length(ThNN.Neurons),1);
for j=1:length(NeuronQueue)
    CurrentNeuron=NeuronQueue(j);

    td=sort(tdCurrent(tdCurrent(:,1)==CurrentNeuron,2));

    if length(td)==length(ts{CurrentNeuron})
        del{CurrentNeuron}=ts{CurrentNeuron}-td';
    elseif length(td)<length(ts{CurrentNeuron})
        ts{CurrentNeuron}=ts{CurrentNeuron}(1:length(td));
        del{CurrentNeuron}=ts{CurrentNeuron}(1:length(td))-td';
    elseif length(ts{CurrentNeuron})<length(td) && ThNN.Neurons(CurrentNeuron).Io>0
        for k=(length(ts{CurrentNeuron})+1):length(td)
            ts{CurrentNeuron}(end+1)=ts{CurrentNeuron}(end)+(pi/ThNN.Neurons(CurrentNeuron).Beta);
        end
        del{CurrentNeuron}=ts{CurrentNeuron}-td';
    else
        disp(['Error! Output Layer Io<0 and there are not enough output spikes being generated',...
            ' by output neuron ', num2str(CurrentNeuron)]);
        return;
    end
end

% Accumulate the SSE and make the output neurons' Y/Z have exactly one
% column per timing error (repeat the last column, or truncate).
CurrentSSE=0;
for p=1:length(NeuronQueue)
    CurrentNeuron=NeuronQueue(p);
    CurrentSSE=CurrentSSE+0.5*sum((del{CurrentNeuron}).^2);
    if size(Z{CurrentNeuron},2)<size(del{CurrentNeuron},2)
        for j=(size(Z{CurrentNeuron},2)+1):size(del{CurrentNeuron},2)
            Z{CurrentNeuron}(:,j)=Z{CurrentNeuron}(:,end);
            Y{CurrentNeuron}(:,j)=Y{CurrentNeuron}(:,end);
        end
    end
    if size(Z{CurrentNeuron},2)>size(del{CurrentNeuron},2)
        Z{CurrentNeuron}=Z{CurrentNeuron}(:,1:size(del{CurrentNeuron},2));
        Y{CurrentNeuron}=Y{CurrentNeuron}(:,1:size(del{CurrentNeuron},2));
    end
end

% Gradients of all connections feeding the output neurons.
for p=1:length(NeuronQueue)
    CurrentNeuron=NeuronQueue(p);
    CurrentInputNeurons=RelativeInputNeurons{CurrentNeuron};
    for m=1:length(CurrentInputNeurons)
        InputNeuron=CurrentInputNeurons(m);
        [dEdW, dEdTau]=local_connection_gradient(ThNN,TrainingParams,InputNeuron,CurrentNeuron,...
            del{CurrentNeuron},Y{CurrentNeuron},Z{CurrentNeuron},K2QAll{InputNeuron,CurrentNeuron},...
            tiCurrent,tdCurrent,Verbose);
        DEDW(InputNeuron,CurrentNeuron)=dEdW;
        if TrainingParams.DelayEnable
            DEDTau(InputNeuron,CurrentNeuron)=dEdTau;
        end
    end
end

% Seed the hidden-neuron queue: pop each output neuron (the loop range is
% fixed on entry) and push its presynaptic neurons. push_unique presumably
% skips neurons already queued -- confirm.
for j=1:length(NeuronQueue)
    [NeuronQueue, CurrentNeuron]=pop(NeuronQueue);
    NeuronQueue=push_unique(NeuronQueue,RelativeInputNeurons{CurrentNeuron});
end

% Input neurons have no incoming connections to train; never queue them.
InputNeurons=get_input_neurons(ThNN);
NeuronQueue=NeuronQueue(~ismember(NeuronQueue,InputNeurons));

% Back-propagate the timing error towards the inputs, one neuron at a time.
while ~isempty(NeuronQueue)
    CurrentNeuron=NeuronQueue(1);
    CurrentOutputNeurons=RelativeOutputNeurons{CurrentNeuron};
    NumberOfSpikes=length(ts{CurrentNeuron});

    % Chain rule: the delta of spike k of this neuron is the SUM over every
    % downstream neuron of that neuron's deltas weighted by its Z row for
    % the arrival of spike k. (Previously each downstream neuron overwrote
    % the value, so only the last one contributed.) Start from zero on every
    % visit so that re-processing a neuron does not double count.
    Delta=zeros(1,NumberOfSpikes);
    for p=1:length(CurrentOutputNeurons)
        OutNeuron=CurrentOutputNeurons(p);
        K2Q=K2QAll{CurrentNeuron,OutNeuron};
        DownstreamDel=del{OutNeuron};
        DownstreamZ=Z{OutNeuron};
        for k=1:NumberOfSpikes
            Delta(k)=Delta(k)+sum(DownstreamDel.*DownstreamZ(K2Q(k),:));
        end
    end
    del{CurrentNeuron}=Delta;

    % Gradients of the connections feeding this neuron (skipped for a
    % neuron with no spikes, i.e. no deltas).
    CurrentInputNeurons=RelativeInputNeurons{CurrentNeuron};
    if ~isempty(del{CurrentNeuron})
        for p=1:length(CurrentInputNeurons)
            InputNeuron=CurrentInputNeurons(p);
            [dEdW, dEdTau]=local_connection_gradient(ThNN,TrainingParams,InputNeuron,CurrentNeuron,...
                del{CurrentNeuron},Y{CurrentNeuron},Z{CurrentNeuron},K2QAll{InputNeuron,CurrentNeuron},...
                tiCurrent,tdCurrent,Verbose);
            DEDW(InputNeuron,CurrentNeuron)=dEdW;
            if TrainingParams.DelayEnable
                DEDTau(InputNeuron,CurrentNeuron)=dEdTau;
            end
        end
    end

    % Done with this neuron; queue its non-input presynaptic neurons.
    NeuronQueue=pop(NeuronQueue);
    NeuronQueue=push_unique(NeuronQueue,...
        CurrentInputNeurons(~ismember(CurrentInputNeurons,InputNeurons)));
end

TrainingResults.DEDW=DEDW;
TrainingResults.DEDTau=DEDTau;
TrainingResults.SSE=CurrentSSE;

if (Verbose==1)
    disp(' ');
    disp('Gradient Calculation Results:');
    del
    DEDW
    DEDTau
    disp(' ');
end

return;


function [dEdW, dEdTau]=local_connection_gradient(ThNN,TrainingParams,InputNeuron,CurrentNeuron,delCurrent,YCurrent,ZCurrent,K2Q,tiCurrent,tdCurrent,Verbose)
%LOCAL_CONNECTION_GRADIENT dE/dW and dE/dTau of connection InputNeuron -> CurrentNeuron.
%   delCurrent holds the error deltas of CurrentNeuron's spikes, YCurrent and
%   ZCurrent are Y{CurrentNeuron} and Z{CurrentNeuron}, and K2Q lists the
%   rows of ZCurrent that belong to InputNeuron's spikes. With
%   TrainingParams.NumericalGradient set, numerical_gradient is used
%   instead. dEdTau is 0 and must be ignored when DelayEnable is false.
%   With Verbose==1 both analytic and numerical values are printed.
GradientStep=TrainingParams.GradientStep;
dEdTau=0;

if TrainingParams.NumericalGradient
    dEdW=numerical_gradient(ThNN,...
        InputNeuron,CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Weights');
else
    % Sum over this neuron's spikes l of del(l) * Y(InputNeuron,l).
    dEdW=sum(delCurrent.*YCurrent(InputNeuron,:));
end
if TrainingParams.NumericalGradient && TrainingParams.DelayEnable
    dEdTau=numerical_gradient(ThNN,...
        InputNeuron,CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Delays');
elseif TrainingParams.DelayEnable
    % Sum the Z rows of all spikes arriving over this connection, presumably
    % because one delay shifts all of them together.
    dEdTau=sum(delCurrent.*sum(ZCurrent(K2Q,:),1));
end

if (Verbose==1)
    disp(['Gradient Between Neurons ' num2str(InputNeuron) ' and ' num2str(CurrentNeuron)]);
    disp(['Weight Calculated: ' num2str(dot(delCurrent,YCurrent(InputNeuron,:)))]);
    disp(['Weight Numerical: ' num2str(numerical_gradient(ThNN,InputNeuron,CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Weights'))]);
    if TrainingParams.DelayEnable
        disp(['Delay Calculated: ' num2str(dot(delCurrent,sum(ZCurrent(K2Q,:),1)))]);
        disp(['Delay Numerical: ' num2str(numerical_gradient(ThNN,InputNeuron,CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Delays'))]);
    end
end