train (TNNT_1_07/@theta_neuron_network/train.m)

PURPOSE

TRAIN trains a theta neuron network given spike time inputs

SYNOPSIS

function [ThNN, TrainingResults, TestingResults] = train(ThNN, TrainingParams, varargin)

DESCRIPTION

Description:
Function to train a theta neuron network over multiple epochs. Inputs are
given as spike times. Intermediate training results may be passed in to
resume a previous run. The function also includes optional hooks for the
training GUI.
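The epoch loop can be pictured with the following minimal MATLAB sketch. It is not the toolbox code. It assumes, as in the source listing, that training starts at epoch length(TrainingResults.SSE) and stops once the RMSE falls below TrainingParams.MaxError. The function handle fake_epoch is a hypothetical stand-in for train_epoch that simply shrinks the error by 20% per epoch.

```matlab
% Sketch of the epoch loop in train: resume from a stored error history and
% stop early once the RMSE target is met. Not the toolbox implementation.
NumEpochs  = 50;                 % plays the role of TrainingParams.NumEpochs
MaxError   = 0.5;                % plays the role of TrainingParams.MaxError
fake_epoch = @(rmse) 0.8*rmse;   % hypothetical stand-in for train_epoch

% A fresh run holds only the initial error; a resumed run would hold the
% full history passed back in through TrainingResults.
Results.RMSE = 10;
InitialEpoch = length(Results.RMSE);   % 1 for a fresh run

for j = InitialEpoch:NumEpochs
    Results.RMSE(end+1) = fake_epoch(Results.RMSE(end));
    if Results.RMSE(end) < MaxError
        fprintf('Desired error achieved after %d epochs\n', j);
        break;
    end
end
```

When the stored history is passed back in as the third element of TrainingData, InitialEpoch picks up from where the earlier run stopped.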

Syntax:
TRAIN(ThNN,TrainingParams,'TrainingData',{tiAll,tdAll,[TrainingResults]});
TRAIN(...,'TestingData',{tiAllTest,tdAllTest,[TestingResults]});
TRAIN(...,Verbose);
[ThNN, TrainingResults, TestingResults] = TRAIN(...);
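The trailing arguments in the syntax above are name/value pairs followed by an optional scalar Verbose. The simplified parser below illustrates that convention; it is not the parser inside train. Here args stands in for varargin, and the error messages are hypothetical. Unlike the source, which only tests isscalar, this sketch also requires Verbose to be numeric, so a one-character property name is never mistaken for a verbosity level.

```matlab
% Sketch of train's argument convention (not the toolbox parser).
% "args" stands in for varargin; the data reuse the inverter example.
args = {'TrainingData', {{[2 3], [2 6]}, {[3 25], [3 20]}}, ...
        'TestingData',  {{[2 3]}, {[3 25]}}, ...
        1};                                   % trailing Verbose level

Verbose = 0;            % default when no trailing scalar is given
TrainingData = {};
TestingData = {};
k = 1;
while k <= numel(args)
    % A numeric scalar in the last position is the verbosity level
    if k == numel(args) && isnumeric(args{k}) && isscalar(args{k})
        Verbose = args{k};
        break;
    end
    if k == numel(args)
        error('Property ''%s'' has no value', args{k});
    end
    switch args{k}
        case 'TrainingData'
            TrainingData = args{k+1};   % {tiAll, tdAll} or {tiAll, tdAll, TrainingResults}
        case 'TestingData'
            TestingData = args{k+1};    % {tiAllTest, tdAllTest} or {..., TestingResults}
        otherwise
            error('Unknown property ''%s''', args{k});
    end
    % Both properties take a cell with 2 or 3 elements
    if ~any(numel(args{k+1}) == [2 3])
        error('%s must be a cell with 2 or 3 elements', args{k});
    end
    k = k + 2;
end
```

MATLAB's inputParser class is a more complete option for validating name/value pairs. The hand-written loop keeps the sketch close to the structure of train.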

Input Parameters:
o ThNN: An object of the theta neuron network class
o TrainingParams: A structure that contains information for training a
    theta neuron network, such as the learning method. This structure is
    generated by get_training_params.
o tiAll: A cell array of length a, the number of input patterns. Each
    cell is a (q x 2) array whose first column holds neuron indices and
    whose second column holds the corresponding input spike times. q may
    vary from cell to cell.
o tdAll: A cell array of length a, the number of input patterns. Each
    cell is a (qo x 2) array whose first column holds output neuron
    indices and whose second column holds the corresponding desired
    output spike times. qo may vary from cell to cell.
o TrainingResults: A structure containing the error gradients with
    respect to the weights (DEDW) and delays (DEDTau). The structure also
    includes the SSE, NonFireFlag to indicate if any neurons are not
    firing, and NonFireCount, an array that counts how many times each
    neuron has not fired over multiple input patterns. This structure is
    generated by calculate_gradient. Pass it in to append new results to
    those of a previous run.
o tiAllTest: A cell array of input test patterns in the same format as 
    tiAll.
o tdAllTest: A cell array of desired output test patterns in the same
    format as tdAll.
o TestingResults: A structure containing results generated on the testing
    data in the same format as TrainingResults. Pass it in to append new
    results to those of a previous run.
o Verbose: Optional flag that controls how much extra information is
    displayed on the screen. A value of 0 (the default) displays no
    additional information, a value of 1 displays all information, and
    values greater than 1 display partial information. See verbose for
    more details.
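The following sketch builds spike-time data in the tiAll and tdAll formats and checks it. The first two patterns are the inverter data from the Examples section. The third pattern is hypothetical and only shows that q may differ between cells. The last line extracts the desired spike times of one output neuron using the same logical indexing that train uses in its console display.

```matlab
% Sketch: build and sanity-check spike-time data in the format train expects.
% Each cell is one pattern; each row is [neuronIndex spikeTime].
tiAll = {[2 3], [2 6]};      % inverter inputs: neuron 2 fires at t=3 or t=6
tdAll = {[3 25], [3 20]};    % desired outputs: neuron 3 fires at t=25 or t=20

% Hypothetical extra pattern: q may differ between cells (two input spikes here)
tiAll{end+1} = [2 4; 2 9];
tdAll{end+1} = [3 22];

assert(numel(tiAll) == numel(tdAll), 'tiAll and tdAll need one cell per pattern');
for k = 1:numel(tiAll)
    assert(size(tiAll{k}, 2) == 2 && size(tdAll{k}, 2) == 2, ...
        'Pattern %d: each row must be [neuronIndex spikeTime]', k);
end

% Desired spike times of output neuron 3 in pattern 1
OutIdx = 3;
tdNeuron = tdAll{1}(tdAll{1}(:, 1) == OutIdx, 2)';
disp(tdNeuron);
```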

Output Parameters:
o ThNN: An object of the theta neuron network class, now with weights and
    delays updated by gradient-based results in TrainingResults.
o TrainingResults: A structure containing the error gradients with
    respect to the weights (DEDW) and delays (DEDTau). The structure also
    includes the SSE, NonFireFlag to indicate if any neurons are not
    firing, and NonFireCount, an array that counts how many times each
    neuron has not fired over multiple input patterns. This structure is
    generated by calculate_gradient. If it is passed in as an input,
    certain results are appended to it.
o TestingResults: A structure containing results generated on the testing
    data in the same format as TrainingResults.
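Besides its return values, train saves checkpoint .mat files under the current folder. The sketch below reproduces the path construction of SaveDir and SaveFile from the source code, using example values. DataName is a hypothetical TrainingParams.DataName. The fixed date string stands in for datestr(clock) so that the result is reproducible.

```matlab
% Sketch of where train writes its checkpoint files (path logic from the source).
DataName = 'Inverter';            % hypothetical TrainingParams.DataName
DT = '02-Apr-2008 15:16:32';      % fixed example of datestr(clock) output
DT(DT == '-' | DT == ' ' | DT == ':') = '_';   % make the timestamp filename-safe

C = filesep;
SaveDir = [pwd, C, 'Results', C, DataName, C, DataName, '_', DT, C];
Epoch = 100;                      % train appends the current epoch number
SaveFile = [SaveDir, DataName, '_', DT, '_', num2str(Epoch)];
disp(SaveFile);                   % save() adds the .mat extension
```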

Examples:
>> % Train a simple inverter (early input spike -> late output spike) to the default target RMSE
>> ThNN=theta_neuron_network;
>> TrainingParams=get_training_params;
>> [ThNN, TrainingResults]=train(ThNN,TrainingParams,'TrainingData',...
     {{[2 3],[2 6]},{[3 25],[3 20]}});
>> TrainingResults.RMSE(end)
>> figure;
>> plot(TrainingResults.RMSE);
>> xlabel('Epochs'); ylabel('RMSE');
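At each display epoch, train prints the change in RMSE over the last DisplayFrequency epochs. The sketch below repeats that calculation on a made-up RMSE history; the numbers are for illustration only and are not real training output. It assumes that the history holds the initial error followed by one value per epoch.

```matlab
% Sketch of train's progress report: change in RMSE over the last
% DisplayFrequency epochs. The history below is hypothetical.
RMSE = [4.0 3.1 2.5 2.2 2.0 1.9];   % initial error + 5 epochs (made up)
DisplayFrequency = 3;               % plays the role of TrainingParams.DisplayFrequency
j = numel(RMSE) - 1;                % current epoch number

StartIdx = max(1, j - DisplayFrequency + 1);
Change = RMSE(end) - RMSE(StartIdx);
fprintf('RMSE: %g with a change of %g over %d epochs\n', ...
    RMSE(end), Change, min(j, DisplayFrequency));
```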

See also theta_neuron_network, verbose

SOURCE CODE

function [ThNN, TrainingResults, TestingResults] = train(ThNN, TrainingParams, varargin)
%TRAIN trains a theta neuron network given spike time inputs
%
%Description:
%Function to train a theta neuron network over multiple epochs. Inputs are
%given as spike times. Intermediate training results may be passed in to
%resume a previous run. The function also includes optional hooks for the
%training GUI.
%
%Syntax:
%TRAIN(ThNN,TrainingParams,'TrainingData',{tiAll,tdAll,[TrainingResults]});
%TRAIN(...,'TestingData',{tiAllTest,tdAllTest,[TestingResults]});
%TRAIN(...,Verbose);
%[ThNN, TrainingResults, TestingResults] = TRAIN(...);
%
%Input Parameters:
%o ThNN: An object of the theta neuron network class
%o TrainingParams: A structure that contains information for training a
%    theta neuron network, such as the learning method. This structure is
%    generated by get_training_params.
%o tiAll: A cell array of length a, the number of input patterns. Each
%    cell is a (q x 2) array whose first column holds neuron indices and
%    whose second column holds the corresponding input spike times. q may
%    vary from cell to cell.
%o tdAll: A cell array of length a, the number of input patterns. Each
%    cell is a (qo x 2) array whose first column holds output neuron
%    indices and whose second column holds the corresponding desired
%    output spike times. qo may vary from cell to cell.
%o TrainingResults: A structure containing the error gradients with
%    respect to the weights (DEDW) and delays (DEDTau). The structure also
%    includes the SSE, NonFireFlag to indicate if any neurons are not
%    firing, and NonFireCount, an array that counts how many times each
%    neuron has not fired over multiple input patterns. This structure is
%    generated by calculate_gradient. Pass it in to append new results to
%    those of a previous run.
%o tiAllTest: A cell array of input test patterns in the same format as
%    tiAll.
%o tdAllTest: A cell array of desired output test patterns in the same
%    format as tdAll.
%o TestingResults: A structure containing results generated on the testing
%    data in the same format as TrainingResults. Pass it in to append new
%    results to those of a previous run.
%o Verbose: Optional flag that controls how much extra information is
%    displayed on the screen. A value of 0 (the default) displays no
%    additional information, a value of 1 displays all information, and
%    values greater than 1 display partial information. See verbose for
%    more details.
%
%Output Parameters:
%o ThNN: An object of the theta neuron network class, now with weights and
%    delays updated by gradient-based results in TrainingResults.
%o TrainingResults: A structure containing the error gradients with
%    respect to the weights (DEDW) and delays (DEDTau). The structure also
%    includes the SSE, NonFireFlag to indicate if any neurons are not
%    firing, and NonFireCount, an array that counts how many times each
%    neuron has not fired over multiple input patterns. This structure is
%    generated by calculate_gradient. If it is passed in as an input,
%    certain results are appended to it.
%o TestingResults: A structure containing results generated on the testing
%    data in the same format as TrainingResults.
%
%Examples:
%>> % Train a simple inverter (early input spike -> late output spike) to the default target RMSE
%>> ThNN=theta_neuron_network;
%>> TrainingParams=get_training_params;
%>> [ThNN, TrainingResults]=train(ThNN,TrainingParams,'TrainingData',...
%     {{[2 3],[2 6]},{[3 25],[3 20]}});
%>> TrainingResults.RMSE(end)
%>> figure;
%>> plot(TrainingResults.RMSE);
%>> xlabel('Epochs'); ylabel('RMSE');
%
%See also theta_neuron_network, verbose

%Copyright (C) 2008 Sam McKennoch <Samuel.McKennoch@loria.fr>


%Go through inputs and make sure properties are {'TrainingData', 'TestingData'}
NumPatterns=0; NumTests=0; Verbose=0; TestingResults=[];
tiOrder=[]; %Defined up front so a GUI stop before the first epoch can still save it
for j=1:length(varargin)
    if j==length(varargin) && isscalar(varargin{j})
        Verbose=varargin{j};
        break;
    end
    if mod(j,2)==0
        continue;
    end
    switch (varargin{j})
        case 'TrainingData'
            if length(varargin{j+1})==2
                tiAll=varargin{j+1}{1};
                tdAll=varargin{j+1}{2};
                %Initialize
                TrainingResults.DEDW=0*ThNN.Weights;
                TrainingResults.DEDTau=TrainingResults.DEDW;
                TrainingResults.DEDWOld=TrainingResults.DEDW;
                TrainingResults.DEDTauOld=TrainingResults.DEDW;
                TrainingResults.NonFireFlag=0;
                TrainingResults.NonFireCount=[];
                DT=datestr(clock);
                DT(DT=='-' | DT==' ' | DT==':')='_';
                TrainingResults.DT=DT;
                %Get Initial RMSE Using Updated get_error function
                [TrainingResults.SSE, TrainingResults.RMSE, TrainingResults.ClassErr]=get_error(ThNN,tiAll,tdAll,2);
            elseif length(varargin{j+1})==3
                tiAll=varargin{j+1}{1};
                tdAll=varargin{j+1}{2};
                TrainingResults=varargin{j+1}{3};
            else
                %Malformed input
                TrainingResults=-1; TestingResults=-1;
                disp('Error in inputs to train function');
                return;
            end
            NumPatterns=length(tiAll);
        case 'TestingData'
            if length(varargin{j+1})==2
                tiAllTest=varargin{j+1}{1};
                tdAllTest=varargin{j+1}{2};
                %Initialize from TestingResults itself so this does not
                %depend on TrainingData having been parsed first
                TestingResults.DEDW=0*ThNN.Weights;
                TestingResults.DEDTau=TestingResults.DEDW;
                TestingResults.DEDWOld=TestingResults.DEDW;
                TestingResults.DEDTauOld=TestingResults.DEDW;
                TestingResults.NonFireFlag=0;
                TestingResults.NonFireCount=[];
                %Get Initial RMSE Using Updated get_error function
                [TestingResults.SSE, TestingResults.RMSE, TestingResults.ClassErr]=get_error(ThNN,tiAllTest,tdAllTest,2);
            elseif length(varargin{j+1})==3
                tiAllTest=varargin{j+1}{1};
                tdAllTest=varargin{j+1}{2};
                TestingResults=varargin{j+1}{3};
            else
                %Malformed input
                TrainingResults=-1; TestingResults=-1;
                disp('Error in inputs to train function');
                return;
            end
            NumTests=length(tiAllTest);
        otherwise
            %Unknown property name
            TrainingResults=-1; TestingResults=-1;
            disp('Error in inputs to train function');
            return;
    end
end

%Get Classes If Needed, as well as some Size Verification
if strcmp(TrainingParams.Type,'Classification')
    Classes=[];
    for k=1:length(tdAll)
        for m=1:size(tdAll{k},1)
            Classes=push_unique(Classes,tdAll{k}(:,2)');
        end
    end
    Classes=sort(Classes);
    OutputNeurons=get_output_neurons(ThNN);
    if length(OutputNeurons)>1
        disp('Warning: classification expects a single output neuron; only the first one is used.');
    end
end

%Create Save Information
C=filesep;
SaveDir=[pwd,C,'Results',C,TrainingParams.DataName,C,TrainingParams.DataName,'_',num2str(TrainingResults.DT),C];
if exist(SaveDir,'dir')~=7
    mkdir(SaveDir);
end

%GUI Hook: Flags to allow GUI to control the function flow
global Flags;

%Hack for Old Method Hooks
NumInputs=length(ThNN.InputNeurons);
NumOutputs=length(ThNN.OutputNeurons);
NumNeurons=length(ThNN.Neurons);
%End HACK
InitialEpoch=length(TrainingResults.SSE);

%GUI Hook: Set the status display
if TrainingParams.FromGUI
    set(TrainingParams.Handles.Current_Epoch,'String',num2str(InitialEpoch));
    set(TrainingParams.Handles.MSE,'String',num2str(TrainingResults.RMSE(end)));
    set(TrainingParams.Handles.Current_Learning_Rate,'String',num2str(TrainingParams.WeightLearningRate));
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Begin Training
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
for j=InitialEpoch:TrainingParams.NumEpochs

    %Pause briefly so that the GUI flags can interrupt this process
    pause(0.01);

    %GUI Hook: Process a stop flag
    if TrainingParams.FromGUI && Flags.StopFlag
        Flags.StopFlag=0;
        Flags.TrainFlag=0;
        set(TrainingParams.Handles.Simulation_Status,'String',['Stopped at epoch ' num2str(j)]);
        pause(2);
        set(TrainingParams.Handles.Simulation_Status,'String','Waiting...');
        SaveFile=[SaveDir, TrainingParams.DataName,'_',num2str(TrainingResults.DT),'_',num2str(j)];
        if NumTests>0
            save(SaveFile,'ThNN','TrainingParams','TrainingResults','TestingResults','tiOrder','tiAll','tiAllTest','tdAll','tdAllTest');
        else
            save(SaveFile,'ThNN','TrainingParams','TrainingResults','tiOrder','tiAll','tdAll');
        end
        set(TrainingParams.Handles.Continue,'UserData',SaveFile);
        enable_network_parameters(TrainingParams.Handles,'on');
        return;
    end

    %GUI Hook: Perturb weights if desired (each weight is scaled by a random
    %factor between 1 and 1+WeightPerturbPercent/100)
    if TrainingParams.FromGUI && Flags.PerturbFlag
        Flags.PerturbFlag=0;
        SimStatus=get(TrainingParams.Handles.Simulation_Status,'String');
        set(TrainingParams.Handles.Simulation_Status,'String',['Perturbing Weights by ' num2str(TrainingParams.WeightPerturbPercent) '%']);
        ThNN.Weights=ThNN.Weights.*(1+0.01*TrainingParams.WeightPerturbPercent*rand(size(ThNN.Weights)));
        pause(2);
        set(TrainingParams.Handles.Simulation_Status,'String',SimStatus);
    end

    %Train a single epoch
    [ThNN, TrainingResults, tiOrder]=train_epoch(ThNN,TrainingParams,tiAll,tdAll,TrainingResults,Verbose);

    %Break if output neurons are not firing
    %(train_epoch returns DEDW=-1 in that case; isscalar avoids applying
    %&& to a weight matrix with a single row)
    if isscalar(TrainingResults.DEDW) && TrainingResults.DEDW==-1
        disp(['Breaking at epoch ' num2str(j)]);
        if TrainingParams.FromGUI
            set(TrainingParams.Handles.Simulation_Status,'String','Output Neuron(s) Not Firing!');
            pause(4);
            set(TrainingParams.Handles.Simulation_Status,'String','Waiting...');
            enable_network_parameters(TrainingParams.Handles,'on');
        end
        return;
    end

    %GUI Hook: Numerical Display of Current Results
    if TrainingParams.FromGUI && (j==1 || mod(j,TrainingParams.DisplayFrequency)==0)
        set(TrainingParams.Handles.Current_Epoch,'String',num2str(j));
        set(TrainingParams.Handles.MSE,'String',num2str(TrainingResults.RMSE(end)));
        set(TrainingParams.Handles.Current_Learning_Rate,'String',num2str(TrainingParams.WeightLearningRate));
        drawnow;
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %Show results on the first epoch and at some frequency thereafter and
    %save results
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    if (j==1 || (NumTests==0 && mod(j,TrainingParams.DisplayFrequency)==0) || (NumTests~=0 && mod(j,TrainingParams.TestingFrequency)==0))
        disp(' ');
        disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%');
        disp(['EPOCH ', num2str(j)]);
        disp('Training Results:');

        TrainingIndex=max(1,j-TrainingParams.DisplayFrequency+1);
        disp(['RMSE: ',num2str(TrainingResults.RMSE(end)),...
            ' with a change of ', num2str(TrainingResults.RMSE(end)-TrainingResults.RMSE(TrainingIndex)),...
            ' over ', num2str(min(j,(TrainingParams.DisplayFrequency))), ' epochs']);
        disp(['SSE: ',num2str(TrainingResults.SSE(end)),...
            ' with a change of ', num2str(TrainingResults.SSE(end)-TrainingResults.SSE(TrainingIndex)),...
            ' over ', num2str(min(j,(TrainingParams.DisplayFrequency))), ' epochs']);

        %Hack: Old Display Method
        for k=1:min(10,length(tdAll))
            if isscalar(tiOrder)
                Index=1;
            else
                Index=find(tiOrder==k);%tiOrder(k);
            end

            for p=1:NumOutputs
                tsDisplay{p}{Index}=TrainingResults.tsCurrent{Index}{NumNeurons-NumOutputs+p};
                tdDisplay{p}{Index}=tdAll{k}(find(tdAll{k}(:,1)==(NumNeurons-NumOutputs+p)),2)';
                %If not Classification, display all pattern results
                %If Classification, display only results which
                %deviate by at least 2 ms from desired spike times
                if ~strcmp(TrainingParams.Type,'Classification') || abs(tdDisplay{p}{Index}-tsDisplay{p}{Index})>=2
                    disp(['Neuron: ', num2str(NumNeurons-NumOutputs+p), ' Desired: ', num2str(tdDisplay{p}{Index}), ' Actual: ', num2str(tsDisplay{p}{Index}),' Diff: ', num2str(tdDisplay{p}{Index}-tsDisplay{p}{Index})]);
                end
                if length(tdDisplay{p}{Index})~=length(tsDisplay{p}{Index})
                    disp('Error: The number of desired and actual output firing times does not match!');
                    disp('Either an output neuron is not firing or there is an error!');
                    return;
                end
            end
        end
        %End Hack

        if strcmp(TrainingParams.Type,'Classification')
            %Assign each pattern to the class whose spike time is closest
            ClassErr=0;
            for k=1:NumPatterns
                if isscalar(tiOrder)
                    Index=1;
                else
                    Index=find(tiOrder==k);
                end

                if length(TrainingResults.tsCurrent{Index}{OutputNeurons(1)})>1
                    %Maybe could do MIMO pattern matching for
                    %classification though?
                    disp('Warning: the output neuron fired more than once; only its first spike is used for classification.');
                end
                for m=1:length(Classes)
                    ClassDistance=abs(TrainingResults.tsCurrent{Index}{OutputNeurons(1)}(1)-Classes(m));
                    if m>1 && ClassDistance>ClassDistanceOld
                        CurrentClass=Classes(m-1);
                        break;
                    end
                    CurrentClass=Classes(m); %Match the final class
                    ClassDistanceOld=ClassDistance;
                end
                %Compare with the desired class
                if CurrentClass~=tdAll{k}(1,2)
                    ClassErr=ClassErr+1;
                end
            end
            ClassErr=ClassErr/NumPatterns;
            TrainingResults.ClassErr(end+1)=ClassErr;
            disp(['Classification Error: ', num2str(100*ClassErr), '%']);
        end


        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %Show Test Results
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        if (NumTests>0)
            disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%');
            disp('Testing Results:');
            [TestingResults.SSE(end+1),TestingResults.RMSE(end+1),TestingResults.ClassErr(end+1),TestingResults.tsCurrent]=get_error(ThNN,tiAllTest,tdAllTest,2);

            %Hack: Old Display Method
            try
                TestingIndex=max(1,floor((j/TrainingParams.TestingFrequency)+1));
                disp(['RMSE: ',num2str(TestingResults.RMSE(end)),...
                    ' with a change of ', num2str(TestingResults.RMSE(end)-TestingResults.RMSE(TestingIndex)),...
                    ' over ', num2str(min(j,(TrainingParams.TestingFrequency))), ' epochs']);
                disp(['SSE: ',num2str(TestingResults.SSE(end)),...
                    ' with a change of ', num2str(TestingResults.SSE(end)-TestingResults.SSE(TestingIndex)),...
                    ' over ', num2str(min(j,(TrainingParams.TestingFrequency))), ' epochs']);
                %----------
                for k=1:min(10,length(tdAllTest))
                    for p=1:NumOutputs
                        tsDisplayTest{p}{k}=TestingResults.tsCurrent{k}{NumNeurons-NumOutputs+p};
                        tdDisplayTest{p}{k}=tdAllTest{k}(find(tdAllTest{k}(:,1)==(NumNeurons-NumOutputs+p)),2)';
                        %If not Classification, display all pattern results
                        %If Classification, display only results which
                        %deviate by at least 2 ms from desired spike times
                        if ~strcmp(TrainingParams.Type,'Classification') || abs(tdDisplayTest{p}{k}-tsDisplayTest{p}{k})>=2
                            disp(['Neuron: ', num2str(NumNeurons-NumOutputs+p), ' Desired: ', num2str(tdDisplayTest{p}{k}), ' Actual: ', num2str(tsDisplayTest{p}{k}),' Diff: ', num2str(tdDisplayTest{p}{k}-tsDisplayTest{p}{k})]);
                        end
                        if length(tdDisplayTest{p}{k})~=length(tsDisplayTest{p}{k})
                            disp('Warning: Not all tests produced output spikes!');
                            disp('Testing Error is not valid during this epoch!');
                            %return;
                        end
                    end
                end
                %----------
                disp(['Classification Error: ', num2str(100*TestingResults.ClassErr(end)), '%']);
            catch
                disp('Not all tests produced output spikes! Testing Error is not valid during this epoch!');
            end
            %End Hack


        end

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %GUI Hook: Plot Training, Testing and Status Results & Make Movie
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        if TrainingParams.FromGUI
            if NumTests>0
                plot_rmse(TrainingParams,TrainingResults,TestingResults);
                plot_extra(TrainingParams,tiOrder,tiAll,TrainingResults,tiAllTest,TestingResults);
            else
                plot_rmse(TrainingParams,TrainingResults);
                plot_extra(TrainingParams,tiOrder,tiAll,TrainingResults);
            end
            if ~isempty(strfind(TrainingParams.Type,'Classification'))
                set(TrainingParams.Handles.CrispCE,'String',num2str(TrainingResults.ClassErr(end)));
            else
                set(TrainingParams.Handles.CrispCE,'String','N/A');
            end
            drawnow;


            if TrainingParams.MakeMovie==1
                %Movie frames are stored in TrainingResults.M, so check that
                %structure (not TrainingParams) before appending
                if isfield(TrainingResults,'M')
                    TrainingResults.M{1}(end+1) = getframe(TrainingParams.Handles.axes4); %rmse window
                    TrainingResults.M{2}(end+1) = getframe(TrainingParams.Handles.axes1); %extra plot window
                else
                    TrainingResults.M{1}(1) = getframe(TrainingParams.Handles.axes4);
                    TrainingResults.M{2}(1) = getframe(TrainingParams.Handles.axes1);
                end
            end
        end

        %Warn user that some hidden neurons are not firing
        if TrainingResults.NonFireFlag
            disp('Warning! One or more neurons are not firing!');
            disp('The count of non-firing instances by neuron index for an "epoch":');
            TrainingResults.NonFireCount
            disp(' ');
        end
        disp(' ');

        %Save everything
        disp('Saving File...');
        SaveFile=[SaveDir, TrainingParams.DataName,'_',num2str(TrainingResults.DT),'_',num2str(j)];
        if NumTests>0
            save(SaveFile,'ThNN','TrainingParams','TrainingResults','TestingResults','tiOrder','tiAll','tiAllTest','tdAll','tdAllTest');
        else
            save(SaveFile,'ThNN','TrainingParams','TrainingResults','tiOrder','tiAll','tdAll');
        end
        disp('Done Saving File!');
    end

    %Check Stopping Criteria
    if (TrainingResults.RMSE(end)<TrainingParams.MaxError)
        disp(['Desired error achieved!  Training has completed in ', num2str(j), ' epochs']);
        break;
    end

end


%GUI Hook: End of Simulation Results Display
if TrainingParams.FromGUI
    if NumTests>0
        plot_rmse(TrainingParams,TrainingResults,TestingResults);
        plot_extra(TrainingParams,tiOrder,tiAll,TrainingResults,tiAllTest,TestingResults);
    else
        plot_rmse(TrainingParams,TrainingResults);
        plot_extra(TrainingParams,tiOrder,tiAll,TrainingResults);
    end
end

%End of Simulation Save
SaveFile=[SaveDir, TrainingParams.DataName,'_',num2str(TrainingResults.DT),'_',num2str(j)];
if NumTests>0
    save(SaveFile,'ThNN','TrainingParams','TrainingResults','TestingResults','tiOrder','tiAll','tiAllTest','tdAll','tdAllTest');
else
    save(SaveFile,'ThNN','TrainingParams','TrainingResults','tiOrder','tiAll','tdAll');
end


%GUI Hook: End of Simulation GUI Clean-up
if TrainingParams.FromGUI
    set(TrainingParams.Handles.Continue,'UserData',SaveFile);
    %j holds the last epoch run, whether the loop finished or broke early
    set(TrainingParams.Handles.Simulation_Status,'String',['Training completed at epoch ' num2str(j)]);
    pause(5);
    Flags.TrainFlag=0;
    set(TrainingParams.Handles.Simulation_Status,'String','Waiting...');
    enable_network_parameters(TrainingParams.Handles,'on');
end

return;
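The crisp classification error in the source above works as follows. Each pattern's first output spike is assigned to the nearest class spike time, and mismatches are counted. The sketch below is a compact vectorized version of the same idea, using hypothetical spike times. It differs from the source in one way: on an exact tie, the loop in train keeps the later class in the sorted list, whereas min returns the first (earlier) one. The sketch uses bsxfun instead of implicit expansion so that it also runs on older MATLAB releases and in Octave.

```matlab
% Sketch of crisp classification error: assign each pattern's first output
% spike time to the nearest class and count mismatches. Values are hypothetical.
Desired = [25 20 20];            % desired output spike time (class) per pattern
Actual  = [24.1 21.7 23.0];      % first actual output spike time per pattern
Classes = unique(Desired);       % sorted distinct classes, here [20 25]

Dist = abs(bsxfun(@minus, Actual(:), Classes(:)'));   % patterns x classes
[~, Nearest] = min(Dist, [], 2);                     % first minimum wins ties
Assigned = Classes(Nearest);
ClassErr = mean(Assigned(:) ~= Desired(:));
fprintf('Classification Error: %g%%\n', 100*ClassErr);
```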

Generated on Wed 02-Apr-2008 15:16:32 by m2html © 2003