function [ThNN, TrainingResults, tiOrder]=train_epoch(ThNN,TrainingParams,tiAll,tdAll,TrainingResults,Verbose)
%TRAIN_EPOCH trains a theta neuron network for a single epoch
%
%Description:
%Function to train a theta neuron network for a single epoch. An epoch is
%the presentation of all input patterns once. The flow of the training
%depends on the learning method: online learning applies the gradient
%after each input pattern, while batch learning accumulates the gradient
%over all patterns and applies it once at the end of the epoch.
%
%Syntax:
%ThNN=TRAIN_EPOCH(ThNN,TrainingParams,tiAll,tdAll)
%ThNN=TRAIN_EPOCH(...,TrainingResults,[Verbose])
%[ThNN, TrainingResults, tiOrder]=TRAIN_EPOCH(...)
%
%Input Parameters:
%o ThNN: An object of the theta neuron network class
%o TrainingParams: A structure that contains information for training a
%  theta neuron network, such as the learning method. This structure is
%  generated by get_training_params.
%o tiAll: A cell array of length a, the number of input patterns, in which
%  each cell is a (q x 2) array whose rows contain a neuron index and the
%  corresponding input spike time. q may vary from cell to cell.
%o tdAll: A cell array of length a, the number of input patterns, in which
%  each cell is a (qo x 2) array whose rows contain an output neuron index
%  and the corresponding desired output spike time. qo may vary from cell
%  to cell.
%o TrainingResults: A structure containing the error gradients with
%  respect to the weights (DEDW) and delays (DEDTau). The structure also
%  includes the SSE, NonFireFlag to indicate if any neurons are not
%  firing, and NonFireCount, an array used to keep track of how many
%  times a neuron has not fired over multiple input patterns. This
%  structure is generated by calculate_gradient.
%o Verbose: Optional flag to indicate if extra information will be
%  displayed on the screen. A value of 0 (the default) displays no
%  additional information, while a value of 1 displays all information.
%  Values greater than 1 display partial information. See Verbose for
%  more details.
%
%Output Parameters:
%o ThNN: An object of the theta neuron network class, now with weights and
%  delays updated by gradient-based results in TrainingResults.
%o TrainingResults: A structure containing the error gradients with
%  respect to the weights (DEDW) and delays (DEDTau). The structure also
%  includes the SSE, NonFireFlag to indicate if any neurons are not
%  firing, and NonFireCount, an array used to keep track of how many
%  times a neuron has not fired over multiple input patterns. This
%  structure is generated by calculate_gradient. If passed as an input,
%  results are appended.
%o tiOrder: The order in which the input patterns were processed during
%  this epoch. The order is random when using online learning.
%
%Example:
%>> %This example is a simple inverter trained over 500 epochs
%>> ThNN=theta_neuron_network;
%>> TrainingParams=get_training_params;
%>> [ThNN, TrainingResults]=train_epoch(ThNN,TrainingParams,...
%       {[2 3],[2 6]},{[3 25],[3 20]});
%>> TrainingResults.RMSE(end)
%>> for j=1:500
%       [ThNN, TrainingResults]=train_epoch(ThNN,TrainingParams,...
%           {[2 3],[2 6]},{[3 25],[3 20]},TrainingResults);
%   end
%>> ts = run_network(ThNN, [2 3])
%>> ts = run_network(ThNN, [2 6])
%>> TrainingResults.RMSE(end)
%>> figure;
%>> plot(TrainingResults.RMSE);
%>> xlabel('Epochs'); ylabel('RMSE');
%
%See also theta_neuron_network, verbose

%Copyright (C) 2008 Sam McKennoch <Samuel.McKennoch@loria.fr>


if nargin<4
    disp('Error in train_epoch: Not enough input arguments');
    disp(['Needed at least 4 inputs but only got ' num2str(nargin)]);
    tiOrder=-1;
    TrainingResults=-1; %nargin<4 implies TrainingResults was not supplied
    return;
end
if nargin<5
    %Initialize the training results structure
    TrainingResults.DEDW=0*ThNN.Weights;
    TrainingResults.DEDTau=TrainingResults.DEDW;
    TrainingResults.DEDWOld=TrainingResults.DEDW;
    TrainingResults.DEDTauOld=TrainingResults.DEDW;
    TrainingResults.RMSE=[];
    TrainingResults.SSE=[];
    TrainingResults.tsCurrent=[];
    %Timestamp used to tag this training run
    DT=datestr(clock);
    DT(DT=='-' | DT==' ' | DT==':')='_';
    TrainingResults.DT=DT;
end
if nargin<6
    Verbose=0;
end
CurrentSSE=0;
TrainingResults.NonFireFlag=0;
TrainingResults.NonFireCount=zeros(1,size(ThNN.Weights,1));


%Use gradient learning based on the NECO calculated gradient
if (Verbose==1)
    disp('Training Using Updated Training Method!');
    disp('Initial Weights:');
    ThNN.Weights
end


%Determine the order of input pattern presentation
if ~isempty(strfind(lower(TrainingParams.LearningMethod),'online'))
    OnlineFlag=1;
    if size(tiAll,2)==1
        tiOrder=1;
    else
        %Shuffle ti and td so the network doesn't train on the order of inputs
        rand('state',sum(100*clock));
        tiOrder=randperm(size(tiAll,2));
    end
else
    OnlineFlag=0;
    if size(tiAll,2)==1
        tiOrder=1;
    else
        tiOrder=1:size(tiAll,2);
    end
end

NumSpikes=0;
ts=cell(length(tiOrder),1);
for k=1:length(tiOrder)
    ti=tiAll{tiOrder(k)};
    td=tdAll{tiOrder(k)};
    [CurrentTrainingResults, ts{k}]=calculate_gradient(ThNN,TrainingParams,ti,td,Verbose);

    %A scalar DEDW of -1 indicates that calculate_gradient failed
    if size(CurrentTrainingResults.DEDW,1)==1 && CurrentTrainingResults.DEDW==-1
        TrainingResults.DEDW=-1;
        return;
    end

    %Accumulate the number of desired output spikes so that the RMSE can
    %be normalized over all input patterns
    NumSpikes=NumSpikes+size(td,1);
    CurrentSSE=CurrentSSE+CurrentTrainingResults.SSE(end);
    TrainingResults.NonFireFlag=TrainingResults.NonFireFlag || CurrentTrainingResults.NonFireFlag;
    TrainingResults.NonFireCount=TrainingResults.NonFireCount+CurrentTrainingResults.NonFireCount;
    TrainingResults.DEDW=TrainingResults.DEDW+CurrentTrainingResults.DEDW;
    TrainingResults.DEDTau=TrainingResults.DEDTau+CurrentTrainingResults.DEDTau;

    %Online learning: apply the gradient after each input pattern
    if OnlineFlag==1
        ThNN=apply_gradient(ThNN,TrainingParams,CurrentTrainingResults,Verbose);
    end

end
%Record the epoch SSE and RMSE, normalized by the total number of desired
%output spikes
TrainingResults.SSE(end+1)=CurrentSSE;
TrainingResults.RMSE(end+1)=sqrt(2*TrainingResults.SSE(end)/NumSpikes);
TrainingResults.tsCurrent=ts;

%Batch learning: apply the accumulated gradient once per epoch
if OnlineFlag==0
    %Optional gradient clipping (currently disabled):
    %DEDW=min(DEDW,50);
    %DEDW=max(DEDW,-50);
    ThNN=apply_gradient(ThNN,TrainingParams,TrainingResults,Verbose);
end

return;
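%Usage sketch (illustrative addition, not part of the original help):
%online learning is selected when TrainingParams.LearningMethod contains
%the substring 'online' (the field name and the substring test are taken
%from the source above; whether get_training_params can set this option
%directly is assumed unknown here, so the field is assigned explicitly).
%With online learning the gradient is applied after every pattern and
%tiOrder records the shuffled presentation order:
%>> ThNN=theta_neuron_network;
%>> TrainingParams=get_training_params;
%>> TrainingParams.LearningMethod='online';
%>> [ThNN, TrainingResults, tiOrder]=train_epoch(ThNN,TrainingParams,...
%       {[2 3],[2 6]},{[3 25],[3 20]});
%>> tiOrder %e.g. [2 1]; the order varies from epoch to epoch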