function [ThNN, TrainingResults, TestingResults] = train(ThNN, TrainingParams, varargin)
%TRAIN Train the network ThNN epoch by epoch, reporting and saving progress.
%
%   [ThNN, TrainingResults, TestingResults] = TRAIN(ThNN, TrainingParams, ...
%       'TrainingData', {tiAll, tdAll}, ...
%       'TestingData',  {tiAllTest, tdAllTest}, ...
%       Verbose)
%
%   Inputs
%     ThNN           - network structure; fields read here are Weights,
%                      Neurons and OutputNeurons. Output neurons are assumed
%                      to be the last NumOutputs entries of Neurons.
%     TrainingParams - options structure. Fields read here: Type
%                      ('Classification' enables class-error reporting),
%                      DataName, NumEpochs, MaxError, DisplayFrequency,
%                      TestingFrequency, WeightLearningRate, FromGUI and,
%                      when FromGUI is true, Handles, WeightPerturbPercent
%                      and MakeMovie.
%     'TrainingData' - required. {tiAll, tdAll} starts a new run;
%                      {tiAll, tdAll, TrainingResults} resumes a previous run
%                      from its saved TrainingResults structure.
%     'TestingData'  - optional, same two forms as 'TrainingData'.
%     Verbose        - optional trailing scalar, forwarded to train_epoch.
%
%   tdAll{k} (and tdAllTest{k}) is a matrix whose column 1 holds a neuron
%   index and column 2 the desired firing time(s) for that neuron.
%
%   Outputs
%     ThNN            - the trained network.
%     TrainingResults - SSE/RMSE/ClassErr histories, the DEDW/DEDTau fields
%                       maintained by train_epoch, and bookkeeping (DT,
%                       NonFireFlag, NonFireCount, tsCurrent, ...).
%                       -1 when the inputs are invalid.
%     TestingResults  - the same for the testing set; [] when no testing
%                       data was given, -1 when the inputs are invalid.
%
%   Training stops after TrainingParams.NumEpochs epochs, when RMSE drops
%   below TrainingParams.MaxError, or (GUI only) when Flags.StopFlag is set.
%
%   Side effects
%     Creates <pwd>/Results/<DataName>/<DataName>_<DT>/ and saves a .mat
%     snapshot there at every reporting epoch and at the end of the run.
%     When TrainingParams.FromGUI is true, GUI controls are updated and the
%     global Flags structure (StopFlag, PerturbFlag, TrainFlag) is read and
%     written.

NumPatterns = 0; NumTests = 0; Verbose = 0;
TrainingResults = []; TestingResults = [];
tiAllTest = []; tdAllTest = [];   % defined so snapshot saving never sees undefined names
HaveTrainingData = false;
BadInput = false;

% Parse 'Name', {cell} pairs. A scalar in the last position is Verbose.
NumArgs = length(varargin);
for a = 1:NumArgs
    if a == NumArgs && isscalar(varargin{a})
        Verbose = varargin{a};
        break;
    end
    if mod(a, 2) == 0
        continue;   % values are consumed together with their names
    end
    % Every name must be a string followed by a 2- or 3-element cell array.
    if ~ischar(varargin{a}) || a == NumArgs || ~iscell(varargin{a+1}) ...
            || ~any(length(varargin{a+1}) == [2 3])
        BadInput = true;
        break;
    end
    Value = varargin{a+1};
    switch varargin{a}
        case 'TrainingData'
            tiAll = Value{1};
            tdAll = Value{2};
            if length(Value) == 3
                TrainingResults = Value{3};   % resume a previous run
            else
                TrainingResults = train_new_results(ThNN);
                % Timestamp used to name the output folder and files.
                DT = datestr(clock);
                DT(DT == '-' | DT == ' ' | DT == ':') = '_';
                TrainingResults.DT = DT;
                [TrainingResults.SSE, TrainingResults.RMSE, TrainingResults.ClassErr] = get_error(ThNN, tiAll, tdAll, 2);
            end
            NumPatterns = length(tiAll);
            HaveTrainingData = true;
        case 'TestingData'
            tiAllTest = Value{1};
            tdAllTest = Value{2};
            if length(Value) == 3
                TestingResults = Value{3};
            else
                TestingResults = train_new_results(ThNN);
                [TestingResults.SSE, TestingResults.RMSE, TestingResults.ClassErr] = get_error(ThNN, tiAllTest, tdAllTest, 2);
            end
            NumTests = length(tiAllTest);
        otherwise
            BadInput = true;
            break;
    end
end

if BadInput || ~HaveTrainingData
    TrainingResults = -1; TestingResults = -1;
    disp('Error in inputs to train function');
    return;
end

IsClassification = strcmp(TrainingParams.Type, 'Classification');

% For classification, collect the sorted set of desired output values
% (column 2 of every tdAll{k}); each value is treated as a class label.
if IsClassification
    Classes = [];
    for k = 1:length(tdAll)
        if size(tdAll{k}, 1) > 0
            Classes = push_unique(Classes, tdAll{k}(:,2)');
        end
    end
    Classes = sort(Classes);
    OutputNeurons = get_output_neurons(ThNN);
    if length(OutputNeurons) > 1
        disp('Problem A!');   % only the first output neuron is used for classification
    end
end

% Output folder: <pwd>/Results/<DataName>/<DataName>_<DT>/
SaveFolder = fullfile(pwd, 'Results', TrainingParams.DataName, ...
    [TrainingParams.DataName, '_', num2str(TrainingResults.DT)]);
if ~exist(SaveFolder, 'dir')
    mkdir(SaveFolder);
end
SaveDir = [SaveFolder, filesep];

global Flags;

NumOutputs = length(ThNN.OutputNeurons);
NumNeurons = length(ThNN.Neurons);

% One SSE entry exists per completed epoch plus the initial evaluation,
% so this is also the index of the next epoch to run.
InitialEpoch = length(TrainingResults.SSE);

if TrainingParams.FromGUI
    set(TrainingParams.Handles.Current_Epoch, 'String', num2str(InitialEpoch));
    set(TrainingParams.Handles.MSE, 'String', num2str(TrainingResults.RMSE(end)));
    set(TrainingParams.Handles.Current_Learning_Rate, 'String', num2str(TrainingParams.WeightLearningRate));
end

tiOrder = [];                    % set by train_epoch; defined so an early stop can still be saved
LastEpoch = InitialEpoch - 1;    % last epoch actually completed
SaveFile = '';

for Epoch = InitialEpoch:TrainingParams.NumEpochs

    % Lets pending callbacks run; presumably these are what set the
    % Flags checked below.
    pause(0.01);

    % Stop requested from the GUI: save a snapshot and return.
    if TrainingParams.FromGUI && Flags.StopFlag
        Flags.StopFlag = 0;
        Flags.TrainFlag = 0;
        set(TrainingParams.Handles.Simulation_Status, 'String', ['Stopped at epoch ' num2str(Epoch)]);
        pause(2);
        set(TrainingParams.Handles.Simulation_Status, 'String', 'Waiting...');
        SaveFile = train_save_state(SaveDir, Epoch, NumTests, ThNN, TrainingParams, TrainingResults, ...
            TestingResults, tiOrder, tiAll, tiAllTest, tdAll, tdAllTest);
        set(TrainingParams.Handles.Continue, 'UserData', SaveFile);
        enable_network_parameters(TrainingParams.Handles, 'on');
        return;
    end

    % Weight perturbation requested from the GUI: scale each weight by a
    % random factor in [1, 1 + WeightPerturbPercent/100).
    if TrainingParams.FromGUI && Flags.PerturbFlag
        Flags.PerturbFlag = 0;
        SimStatus = get(TrainingParams.Handles.Simulation_Status, 'String');
        set(TrainingParams.Handles.Simulation_Status, 'String', ['Perturbing Weights by ' num2str(TrainingParams.WeightPerturbPercent) '%']);
        ThNN.Weights = ThNN.Weights .* (1 + 0.01*TrainingParams.WeightPerturbPercent*rand(size(ThNN.Weights)));
        pause(2);
        set(TrainingParams.Handles.Simulation_Status, 'String', SimStatus);
    end

    [ThNN, TrainingResults, tiOrder] = train_epoch(ThNN, TrainingParams, tiAll, tdAll, TrainingResults, Verbose);

    % train_epoch signals non-firing output neuron(s) with DEDW == -1.
    if isequal(TrainingResults.DEDW, -1)
        disp(['Breaking at epoch ' num2str(Epoch)]);
        if TrainingParams.FromGUI
            train_abort_gui(TrainingParams.Handles, 'Output Neuron(s) Not Firing!');
        end
        return;
    end
    LastEpoch = Epoch;

    if TrainingParams.FromGUI && (Epoch == 1 || mod(Epoch, TrainingParams.DisplayFrequency) == 0)
        set(TrainingParams.Handles.Current_Epoch, 'String', num2str(Epoch));
        set(TrainingParams.Handles.MSE, 'String', num2str(TrainingResults.RMSE(end)));
        set(TrainingParams.Handles.Current_Learning_Rate, 'String', num2str(TrainingParams.WeightLearningRate));
        drawnow;
    end

    % Report on the first epoch, then every DisplayFrequency epochs (no
    % testing data) or every TestingFrequency epochs (with testing data).
    IsReportEpoch = Epoch == 1 ...
        || (NumTests == 0 && mod(Epoch, TrainingParams.DisplayFrequency) == 0) ...
        || (NumTests ~= 0 && mod(Epoch, TrainingParams.TestingFrequency) == 0);

    if IsReportEpoch
        disp(' ');
        disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%');
        disp(['EPOCH ', num2str(Epoch)]);
        disp('Training Results:');

        % RMSE/SSE hold Epoch+1 entries here, so this index lies
        % DisplayFrequency entries before the newest one.
        TrainingIndex = max(1, Epoch - TrainingParams.DisplayFrequency + 1);
        TrainSpan = num2str(min(Epoch, TrainingParams.DisplayFrequency));
        disp(['RMSE: ', num2str(TrainingResults.RMSE(end)), ...
            ' with a change of ', num2str(TrainingResults.RMSE(end) - TrainingResults.RMSE(TrainingIndex)), ...
            ' over ', TrainSpan, ' epochs']);
        disp(['SSE: ', num2str(TrainingResults.SSE(end)), ...
            ' with a change of ', num2str(TrainingResults.SSE(end) - TrainingResults.SSE(TrainingIndex)), ...
            ' over ', TrainSpan, ' epochs']);

        % Desired vs. actual output firing times for up to 10 patterns.
        for k = 1:min(10, length(tdAll))
            Index = train_pattern_index(tiOrder, k);
            for p = 1:NumOutputs
                Neuron = NumNeurons - NumOutputs + p;
                td = tdAll{k}(tdAll{k}(:,1) == Neuron, 2)';
                ts = TrainingResults.tsCurrent{Index}{Neuron};
                if train_show_neuron(Neuron, td, ts, IsClassification)
                    disp('Error: The number of desired and actual output firing times do not match!');
                    disp('Either an output neuron is not firing or their is an error!');
                    if TrainingParams.FromGUI
                        train_abort_gui(TrainingParams.Handles, 'Output Neuron(s) Not Firing!');
                    end
                    return;
                end
            end
        end

        if IsClassification
            ClassErr = train_class_error(TrainingResults.tsCurrent, tiOrder, tdAll, ...
                NumPatterns, OutputNeurons(1), Classes);
            TrainingResults.ClassErr(end+1) = ClassErr;
            disp(['Classification Error: ', num2str(100*ClassErr), '%']);
        end

        if NumTests > 0
            disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%');
            disp('Testing Results:');
            [TestingResults.SSE(end+1), TestingResults.RMSE(end+1), TestingResults.ClassErr(end+1), TestingResults.tsCurrent] = get_error(ThNN, tiAllTest, tdAllTest, 2);

            % Best-effort report: invalid test output must not stop training.
            try
                % Compare against the previous testing evaluation.
                PrevRMSE = max(1, numel(TestingResults.RMSE) - 1);
                PrevSSE = max(1, numel(TestingResults.SSE) - 1);
                TestSpan = num2str(min(Epoch, TrainingParams.TestingFrequency));
                disp(['RMSE: ', num2str(TestingResults.RMSE(end)), ...
                    ' with a change of ', num2str(TestingResults.RMSE(end) - TestingResults.RMSE(PrevRMSE)), ...
                    ' over ', TestSpan, ' epochs']);
                disp(['SSE: ', num2str(TestingResults.SSE(end)), ...
                    ' with a change of ', num2str(TestingResults.SSE(end) - TestingResults.SSE(PrevSSE)), ...
                    ' over ', TestSpan, ' epochs']);

                for k = 1:min(10, length(tdAllTest))
                    for p = 1:NumOutputs
                        Neuron = NumNeurons - NumOutputs + p;
                        td = tdAllTest{k}(tdAllTest{k}(:,1) == Neuron, 2)';
                        ts = TestingResults.tsCurrent{k}{Neuron};
                        if train_show_neuron(Neuron, td, ts, IsClassification)
                            disp('Warning: Not all tests produced output spikes!');
                            disp('Testing Error is not valid during this epoch!');
                        end
                    end
                end

                disp(['Classification Error: ', num2str(100*TestingResults.ClassErr(end)), '%']);
            catch
                disp('Not all tests produced output spikes! Testing Error is not valid during this epoch!');
            end
        end

        if TrainingParams.FromGUI
            train_update_plots(TrainingParams, tiOrder, tiAll, TrainingResults, NumTests, tiAllTest, TestingResults);
            if ~isempty(strfind(TrainingParams.Type, 'Classification'))
                set(TrainingParams.Handles.CrispCE, 'String', num2str(TrainingResults.ClassErr(end)));
            else
                set(TrainingParams.Handles.CrispCE, 'String', 'N/A');
            end
            drawnow;

            % Movie frames accumulate in TrainingResults.M across reports.
            if TrainingParams.MakeMovie == 1
                if isfield(TrainingResults, 'M')
                    TrainingResults.M{1}(end+1) = getframe(TrainingParams.Handles.axes4);
                    TrainingResults.M{2}(end+1) = getframe(TrainingParams.Handles.axes1);
                else
                    TrainingResults.M{1}(1) = getframe(TrainingParams.Handles.axes4);
                    TrainingResults.M{2}(1) = getframe(TrainingParams.Handles.axes1);
                end
            end
        end

        if TrainingResults.NonFireFlag
            disp('Warning! One or more neurons are not firing!');
            disp('The count of non-firing instances by neuron index for an "epoch":');
            disp(TrainingResults.NonFireCount);
            disp(' ');
        end
        disp(' ');

        disp('Saving File...');
        SaveFile = train_save_state(SaveDir, Epoch, NumTests, ThNN, TrainingParams, TrainingResults, ...
            TestingResults, tiOrder, tiAll, tiAllTest, tdAll, tdAllTest);
        disp('Done Saving File!');
    end

    if TrainingResults.RMSE(end) < TrainingParams.MaxError
        disp(['Desired error acheived! Training has completed in ', num2str(Epoch), ' epochs']);
        break;
    end
end

if TrainingParams.FromGUI
    train_update_plots(TrainingParams, tiOrder, tiAll, TrainingResults, NumTests, tiAllTest, TestingResults);
end

SaveFile = train_save_state(SaveDir, LastEpoch, NumTests, ThNN, TrainingParams, TrainingResults, ...
    TestingResults, tiOrder, tiAll, tiAllTest, tdAll, tdAllTest);

if TrainingParams.FromGUI
    set(TrainingParams.Handles.Continue, 'UserData', SaveFile);
    set(TrainingParams.Handles.Simulation_Status, 'String', ['Training completed at epoch ' num2str(LastEpoch)]);
    pause(5);
    Flags.TrainFlag = 0;
    set(TrainingParams.Handles.Simulation_Status, 'String', 'Waiting...');
    enable_network_parameters(TrainingParams.Handles, 'on');
end

return;


function Results = train_new_results(ThNN)
% Fresh results structure whose DEDW/DEDTau fields are zeros shaped like
% ThNN.Weights and whose non-firing bookkeeping is cleared.
Zero = 0*ThNN.Weights;
Results.DEDW = Zero;
Results.DEDTau = Zero;
Results.DEDWOld = Zero;
Results.DEDTauOld = Zero;
Results.NonFireFlag = 0;
Results.NonFireCount = [];


function Index = train_pattern_index(tiOrder, k)
% Position of training pattern k inside tsCurrent. tiOrder is presumably
% the presentation order returned by train_epoch; a scalar tiOrder maps
% everything to position 1.
if isscalar(tiOrder)
    Index = 1;
else
    Index = find(tiOrder == k);
end


function Mismatch = train_show_neuron(Neuron, td, ts, IsClassification)
% Print desired (td) vs. actual (ts) firing times for one output neuron.
% Returns true when td and ts hold different numbers of times; in that
% case the difference is not computed. For classification, only
% differences of 2 or more are printed.
td = td(:)';
ts = ts(:)';
Mismatch = numel(td) ~= numel(ts);
if Mismatch
    disp(['Neuron: ', num2str(Neuron), ' Desired: ', num2str(td), ' Actual: ', num2str(ts)]);
    return;
end
Diff = td - ts;
if ~IsClassification || any(abs(Diff) >= 2)
    disp(['Neuron: ', num2str(Neuron), ' Desired: ', num2str(td), ' Actual: ', num2str(ts), ' Diff: ', num2str(Diff)]);
end


function ClassErr = train_class_error(tsCurrent, tiOrder, tdAll, NumPatterns, OutputNeuron, Classes)
% Fraction of the NumPatterns training patterns misclassified. A pattern's
% class is the entry of the ascending Classes nearest to the first firing
% time of OutputNeuron; a tie between two classes resolves to the larger
% one. The target class is tdAll{k}(1,2). A pattern whose output neuron
% did not fire counts as misclassified.
Errors = 0;
for k = 1:NumPatterns
    ts = tsCurrent{train_pattern_index(tiOrder, k)}{OutputNeuron};
    if length(ts) > 1
        disp('Problem B!');   % only the first output spike is used
    end
    if isempty(ts)
        Errors = Errors + 1;
        continue;
    end
    Dist = abs(ts(1) - Classes);
    Nearest = find(Dist == min(Dist), 1, 'last');
    if isempty(Nearest) || Classes(Nearest) ~= tdAll{k}(1,2)
        Errors = Errors + 1;
    end
end
ClassErr = Errors / max(NumPatterns, 1);


function SaveFile = train_save_state(SaveDir, Epoch, NumTests, ThNN, TrainingParams, TrainingResults, TestingResults, tiOrder, tiAll, tiAllTest, tdAll, tdAllTest)
% Save a snapshot named <DataName>_<DT>_<Epoch> in SaveDir and return its
% path. The variables are saved under these argument names, so snapshots
% load with the same variable names as before. Testing data is included
% only when NumTests > 0.
SaveFile = [SaveDir, TrainingParams.DataName, '_', num2str(TrainingResults.DT), '_', num2str(Epoch)];
if NumTests > 0
    save(SaveFile, 'ThNN', 'TrainingParams', 'TrainingResults', 'TestingResults', 'tiOrder', 'tiAll', 'tiAllTest', 'tdAll', 'tdAllTest');
else
    save(SaveFile, 'ThNN', 'TrainingParams', 'TrainingResults', 'tiOrder', 'tiAll', 'tdAll');
end


function train_update_plots(TrainingParams, tiOrder, tiAll, TrainingResults, NumTests, tiAllTest, TestingResults)
% Refresh the GUI plots, including testing curves when testing data exists.
if NumTests > 0
    plot_rmse(TrainingParams, TrainingResults, TestingResults);
    plot_extra(TrainingParams, tiOrder, tiAll, TrainingResults, tiAllTest, TestingResults);
else
    plot_rmse(TrainingParams, TrainingResults);
    plot_extra(TrainingParams, tiOrder, tiAll, TrainingResults);
end


function train_abort_gui(Handles, Message)
% Show Message in the GUI status line, then return the GUI to its idle
% state. Training is marked as no longer running, as the stop and
% completion paths of train do.
global Flags;
set(Handles.Simulation_Status, 'String', Message);
pause(4);
set(Handles.Simulation_Status, 'String', 'Waiting...');
enable_network_parameters(Handles, 'on');
Flags.TrainFlag = 0;