train_epoch

PURPOSE

TRAIN_EPOCH trains a theta neuron network for a single epoch

SYNOPSIS

function [ThNN, TrainingResults, tiOrder]=train_epoch(ThNN,TrainingParams,tiAll,tdAll,TrainingResults,Verbose)

DESCRIPTION

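TRAIN_EPOCH trains a theta neuron network for a single epoch

Description:
Trains a theta neuron network for a single epoch.  An epoch is one
presentation of every input pattern.  The flow of training depends on the
learning method.  If TrainingParams.LearningMethod contains 'online'
(case-insensitive), the patterns are presented in random order and the
gradient is applied after each pattern.  Otherwise the patterns are
presented in the given order, the gradients are accumulated, and the
accumulated gradient is applied once at the end of the epoch.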
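
The two flows can be compared on a toy problem.  The sketch below is not the
toolbox code: it swaps the theta neuron network for a one-weight linear model,
and the learning rate eta is an assumed value chosen only for illustration.

```matlab
% Toy illustration of the two epoch flows (not the toolbox code).
% Model: y = w*x, per-pattern error E = 0.5*(y - t)^2, so dE/dw = (y - t)*x.
x = [1 2 3];          % three input patterns
t = [2 4 6];          % desired outputs (the exact solution is w = 2)
eta = 0.05;           % hypothetical learning rate

% Batch flow: accumulate the gradient over all patterns, then update once.
wBatch = 0;
dEdw = 0;
for k = 1:numel(x)
    dEdw = dEdw + (wBatch*x(k) - t(k))*x(k);
end
wBatch = wBatch - eta*dEdw;

% Online flow: shuffle the patterns, then update after every pattern.
wOnline = 0;
order = randperm(numel(x));
for k = order
    wOnline = wOnline - eta*(wOnline*x(k) - t(k))*x(k);
end

fprintf('batch w = %.4f, online w = %.4f\n', wBatch, wOnline);
```

After one epoch the batch weight has moved once, while the online weight has
moved three times and its value depends on the shuffled order.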

Syntax:
ThNN=TRAIN_EPOCH(ThNN,TrainingParams,tiAll,tdAll)
ThNN=TRAIN_EPOCH(...,TrainingResults,[Verbose])
[ThNN, TrainingResults, tiOrder]=TRAIN_EPOCH(...)

Input Parameters:
o ThNN: An object of the theta neuron network class
o TrainingParams: A structure that contains information for training a
    theta neuron network, such as the learning method. This structure is
    generated by get_training_params.
o tiAll: A 1 x a cell array, where a is the number of input patterns.  Each
    cell is a (q x 2) array whose rows hold the neuron index and the spike
    time of one input spike.  q may vary from cell to cell.  The number of
    patterns is read from size(tiAll,2), so tiAll must be a row cell array.
o tdAll: A 1 x a cell array with one cell per input pattern.  Each cell is
    a (qo x 2) array whose rows hold the output neuron index and the
    desired spike time of one output spike.  qo may vary from cell to cell.
o TrainingResults: Optional structure carried over from earlier epochs.  It
    contains the error gradient with respect to the weights (DEDW) and
    delays (DEDTau).  The structure also includes the SSE and RMSE
    histories, NonFireFlag to indicate if any neurons are not firing, and
    NonFireCount, an array used to keep track of how many times a neuron
    has not fired over multiple input patterns.  If omitted, the structure
    is initialized by train_epoch; the per-pattern values come from
    calculate_gradient.
o Verbose: Optional flag to indicate if extra information will be
    displayed on the screen. A value of 0 (the default) displays no
    additional information, while a value of 1 displays all information.
    Values greater than 1 display partial information.  See Verbose for
    more details.
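
The pattern arguments can be checked before training.  The sketch below
reuses the values from the example in this help text; the assertions are
hypothetical pre-flight checks and are not performed by train_epoch itself.

```matlab
% Values taken from the example in this help text.
% Each row of a cell holds a neuron index followed by a spike time.
tiAll = {[2 3], [2 6]};      % inputs: neuron 2 spikes at time 3, then at time 6
tdAll = {[3 25], [3 20]};    % targets: neuron 3 should spike at time 25, then 20

% Hypothetical pre-flight checks (not part of the toolbox)
assert(iscell(tiAll) && iscell(tdAll), 'tiAll and tdAll must be cell arrays');
assert(size(tiAll,2) == size(tdAll,2), 'need one target cell per input pattern');
for k = 1:size(tiAll,2)
    assert(size(tiAll{k},2) == 2, 'input pattern %d is not q x 2', k);
    assert(size(tdAll{k},2) == 2, 'target pattern %d is not qo x 2', k);
end

% train_epoch counts patterns with size(tiAll,2), so a column cell array
% would be seen as a single pattern.
numPatterns = size(tiAll,2);
fprintf('%d patterns ready\n', numPatterns);
```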

Output Parameters:
o ThNN: An object of the theta neuron network class, now with weights and
    delays updated by gradient-based results in TrainingResults.
o TrainingResults: A structure containing error gradient with respect to
    the weights (DEDW) and delays (DEDTau).  The structure also includes
    the SSE, NonFireFlag to indicate if any neurons are not firing along
    with NonFireCount, an array used to keep track of how many times a 
    neuron has not fired over multiple input patterns.  If passed as an
    input, the SSE and RMSE of this epoch are appended to the existing
    histories.  If calculate_gradient returns DEDW equal to -1,
    TrainingResults.DEDW is set to -1 and the function returns early.  If
    fewer than four inputs are given, TrainingResults itself is -1.
o tiOrder: The order in which the input patterns were processed during this
    epoch.  The order is random when using online learning and 1:a
    otherwise.  tiOrder is -1 if fewer than four inputs are given.
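
The ordering rule can be sketched on its own.  The test below mirrors the
check used in the source of this function; the method strings 'Online' and
'Batch' are assumed example values, not a list of the values accepted by
get_training_params.

```matlab
% Sketch of how the presentation order is chosen.
% The method names below are hypothetical examples.
numPatterns = 4;
methodNames = {'Online', 'Batch'};
orders = cell(size(methodNames));

for m = 1:numel(methodNames)
    if ~isempty(strfind(lower(methodNames{m}), 'online'))
        orders{m} = randperm(numPatterns);   % shuffled presentation order
    else
        orders{m} = 1:numPatterns;           % patterns in the given order
    end
end

disp(orders{1});   % some permutation of 1:4
disp(orders{2});   % always 1 2 3 4
```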

Example:
>> %This example is a simple inverter trained over 500 epochs
>> ThNN=theta_neuron_network;
>> TrainingParams=get_training_params;
>> [ThNN, TrainingResults]=train_epoch(ThNN,TrainingParams,...
     {[2 3],[2 6]},{[3 25],[3 20]});
>> TrainingResults.RMSE(end)
>> for j=1:500
     [ThNN, TrainingResults]=train_epoch(ThNN,TrainingParams,...
       {[2 3],[2 6]},{[3 25],[3 20]},TrainingResults);
   end
>> ts = run_network(ThNN, [2 3])
>> ts = run_network(ThNN, [2 6])
>> TrainingResults.RMSE(end)
>> figure;
>> plot(TrainingResults.RMSE);
>> xlabel('Epochs'); ylabel('RMSE');
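
The example above assumes that every call succeeds.  Going by the source of
this function, failure is marked with -1 values rather than an error, so a
training loop may want to test for them.  The helper isFailure below is
hypothetical (it is not part of the toolbox), and the structures are mock
results standing in for real train_epoch outputs.

```matlab
% Hypothetical helper: true if a result carries one of the -1 failure markers.
isFailure = @(TR) isequal(TR, -1) || ...
    (isstruct(TR) && isfield(TR, 'DEDW') && isequal(TR.DEDW, -1));

% Mock results standing in for real train_epoch outputs
okResult.DEDW = zeros(3);      % normal case: a gradient matrix
okResult.RMSE = 1.2;
badGradient.DEDW = -1;         % calculate_gradient reported a failure
tooFewInputs = -1;             % train_epoch was called with fewer than 4 inputs

disp([isFailure(okResult), isFailure(badGradient), isFailure(tooFewInputs)]);
```

Inside a training loop, a check such as `if isFailure(TrainingResults), break; end`
after each call would stop training at the first failed epoch.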

See also theta_neuron_network, verbose

SOURCE CODE

function [ThNN, TrainingResults, tiOrder]=train_epoch(ThNN,TrainingParams,tiAll,tdAll,TrainingResults,Verbose)
%TRAIN_EPOCH trains a theta neuron network for a single epoch
%
%Description:
%Function to train a theta neuron network for a single epoch.  An epoch is
%the presentation of all input patterns once.  Depending on the learning
%method, the flow of the training will differ.
%
%Syntax:
%ThNN=TRAIN_EPOCH(ThNN,TrainingParams,tiAll,tdAll)
%ThNN=TRAIN_EPOCH(...,TrainingResults,[Verbose])
%[ThNN, TrainingResults, tiOrder]=TRAIN_EPOCH(...)
%
%Input Parameters:
%o ThNN: An object of the theta neuron network class
%o TrainingParams: A structure that contains information for training a
%    theta neuron network, such as the learning method. This structure is
%    generated by get_training_params.
%o tiAll: A cell array of length a, the number of input patterns, of which
%    each cell is a qx2 array that contains the neuron indices for each
%    spike time and the input spike times.  q may vary from cell to cell.
%o tdAll: A cell array of length a, the number of input patterns, of which
%    each cell is a (qo x 2) array that contains the output neuron indices
%    for each desired output spike time and the desired output spike times.
%    qo may vary from cell to cell.
%o TrainingResults: A structure containing error gradient with respect to
%    the weights (DEDW) and delays (DEDTau).  The structure also includes
%    the SSE, NonFireFlag to indicate if any neurons are not firing along
%    with NonFireCount, an array used to keep track of how many times a
%    neuron has not fired over multiple input patterns.  This structure is
%    generated by calculate_gradient.
%o Verbose: Optional flag to indicate if extra information will be
%    displayed on the screen. A value of 0 (the default) displays no
%    additional information, while a value of 1 displays all information.
%    Values greater than 1 display partial information.  See Verbose for
%    more details.
%
%Output Parameters:
%o ThNN: An object of the theta neuron network class, now with weights and
%    delays updated by gradient-based results in TrainingResults.
%o TrainingResults: A structure containing error gradient with respect to
%    the weights (DEDW) and delays (DEDTau).  The structure also includes
%    the SSE, NonFireFlag to indicate if any neurons are not firing along
%    with NonFireCount, an array used to keep track of how many times a
%    neuron has not fired over multiple input patterns.  This structure is
%    generated by calculate_gradient. If passed as an input, results are
%    appended.
%o tiOrder: The order in which the input patterns were processed during this
%    epoch.  The order is random when using online learning.
%
%Example:
%>> %This example is a simple inverter trained over 500 epochs
%>> ThNN=theta_neuron_network;
%>> TrainingParams=get_training_params;
%>> [ThNN, TrainingResults]=train_epoch(ThNN,TrainingParams,...
%     {[2 3],[2 6]},{[3 25],[3 20]});
%>> TrainingResults.RMSE(end)
%>> for j=1:500
%     [ThNN, TrainingResults]=train_epoch(ThNN,TrainingParams,...
%       {[2 3],[2 6]},{[3 25],[3 20]},TrainingResults);
%   end
%>> ts = run_network(ThNN, [2 3])
%>> ts = run_network(ThNN, [2 6])
%>> TrainingResults.RMSE(end)
%>> figure;
%>> plot(TrainingResults.RMSE);
%>> xlabel('Epochs'); ylabel('RMSE');
%
%See also theta_neuron_network, verbose

%Copyright (C) 2008 Sam McKennoch <Samuel.McKennoch@loria.fr>


%ThNN, TrainingParams, tiAll and tdAll are required
if nargin<4
    disp('Error in train_epoch: Not enough input arguments');
    disp(['Needed at least 4 inputs but only got ' num2str(nargin)]);
    %Mark the remaining outputs as invalid
    tiOrder=-1;
    TrainingResults=-1;
    return;
end
if nargin<5
    %Initialize
    TrainingResults.DEDW=0*ThNN.Weights;
    TrainingResults.DEDTau=TrainingResults.DEDW;
    TrainingResults.DEDWOld=TrainingResults.DEDW;
    TrainingResults.DEDTauOld=TrainingResults.DEDW;
    TrainingResults.RMSE=[];
    TrainingResults.SSE=[];
    TrainingResults.tsCurrent=[];
    %Timestamp of the first epoch, with separators replaced by underscores
    DT=datestr(clock);
    DT(DT=='-' | DT==' ' | DT==':')='_';
    TrainingResults.DT=DT;
end
if nargin<6
    Verbose=0;
end
%Per-epoch accumulators are reset on every call
CurrentSSE=0;
TrainingResults.NonFireFlag=0;
TrainingResults.NonFireCount=zeros(1,size(ThNN.Weights,1));


%Use gradient learning based on NECO calculated gradient
if (Verbose==1)
    disp('Training Using Updated Training Method!');
    disp('Initial Weights:');
    ThNN.Weights
end


%Determining Order of Input Pattern Presentation
if ~isempty(strfind(lower(TrainingParams.LearningMethod),'online'))
    OnlineFlag=1;
    if size(tiAll,2)==1
        tiOrder=1;
    else
        %shuffle ti and td so network doesn't train on the order of inputs
        rand('state',sum(100*clock));
        tiOrder=randperm(size(tiAll,2));
    end
else
    OnlineFlag=0;
    if size(tiAll,2)==1
        tiOrder=1;
    else
        tiOrder=1:size(tiAll,2);
    end
end

NumSpikes=0;
ts=cell(length(tiOrder),1);
for k=1:length(tiOrder)
    ti=tiAll{tiOrder(k)};
    td=tdAll{tiOrder(k)};
    [CurrentTrainingResults, ts{k}]=calculate_gradient(ThNN,TrainingParams,ti,td,Verbose);

    %A scalar DEDW of -1 from calculate_gradient ends the epoch early
    if isequal(CurrentTrainingResults.DEDW,-1)
        TrainingResults.DEDW=-1;
        return;
    end

    %For RMSE, count the desired output spikes over all input patterns
    NumSpikes=NumSpikes+size(td,1);
    CurrentSSE=CurrentSSE+CurrentTrainingResults.SSE(end);
    TrainingResults.NonFireFlag=TrainingResults.NonFireFlag || CurrentTrainingResults.NonFireFlag;
    TrainingResults.NonFireCount=TrainingResults.NonFireCount+CurrentTrainingResults.NonFireCount;
    TrainingResults.DEDW=TrainingResults.DEDW+CurrentTrainingResults.DEDW;
    TrainingResults.DEDTau=TrainingResults.DEDTau+CurrentTrainingResults.DEDTau;

    %Online learning: update the network after every pattern
    if OnlineFlag==1
        ThNN=apply_gradient(ThNN,TrainingParams,CurrentTrainingResults,Verbose);
    end

end
%Append this epoch's error measures to the histories
TrainingResults.SSE(end+1)=CurrentSSE;
TrainingResults.RMSE(end+1)=sqrt(2*TrainingResults.SSE(end)/NumSpikes);
TrainingResults.tsCurrent=ts;

%Otherwise: update the network once with the accumulated gradient
if OnlineFlag==0
    %Clipping, do I still want to be able to do this?
    %DEDW=min(DEDW,50);
    %DEDW=max(DEDW,-50);
    ThNN=apply_gradient(ThNN,TrainingParams,TrainingResults,Verbose);
end

return;
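
The RMSE bookkeeping at the end of the listing can be checked by hand.  The
sketch below assumes that each pattern's SSE carries a factor of 1/2, which
would explain the factor of 2 in the RMSE line; that convention is an
inference, not something the listing states.  The spike times are made up
for illustration.

```matlab
% Hypothetical desired and actual output spike times for two patterns,
% one output spike each.
tdTimes = [25 20];
tsTimes = [24 22];

% Assumed convention: SSE = 0.5 * sum of squared timing errors.
CurrentSSE = 0.5*sum((tsTimes - tdTimes).^2);
NumSpikes = numel(tdTimes);

% Same formula as in the listing: the 2 cancels the assumed 0.5.
RMSE = sqrt(2*CurrentSSE/NumSpikes);
fprintf('SSE = %.2f, RMSE = %.4f\n', CurrentSSE, RMSE);
```

With timing errors of 1 and 2, the SSE is 2.5 and the RMSE is sqrt(2.5),
which is the root of the mean squared timing error under this assumption.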

Generated on Wed 02-Apr-2008 15:16:32 by m2html © 2003