TRAIN_DATASET trains a theta neuron network given dataset information

Description:
Function to train a theta neuron network over multiple epochs. The dataset
is selected by the Type and DataName fields of TrainingParams, which
reference a pre-created data file (Datasets/<Type>/<DataName>, relative to
the current folder). This function also supports optional interfacing with
the training GUI.

Syntax:
TRAIN_DATASET(ThNN,TrainingParams);
TRAIN_DATASET(...,[TrainingResults],[TestingResults],[Verbose]);
[ThNN, TrainingResults, TestingResults] = TRAIN_DATASET(...);

Input Parameters:
o ThNN: An object of the theta neuron network class.
o TrainingParams: A structure that contains information for training a
  theta neuron network, such as the learning method. This structure is
  generated by get_training_params. It also includes the fields Type and
  DataName, which indicate the dataset to train on.
o TrainingResults: A structure containing the error gradients with respect
  to the weights (DEDW) and delays (DEDTau). The structure also includes
  the SSE, a NonFireFlag that indicates whether any neurons are not firing,
  and NonFireCount, an array that tracks how many times each neuron has not
  fired over multiple input patterns. This structure is generated by
  calculate_gradient. Pass it as an input to append new results to it.
o TestingResults: A structure containing results generated on the testing
  data, in the same format as TrainingResults.
o Verbose: Optional flag indicating whether extra information is displayed
  on the screen. A value of 0 (the default) displays no additional
  information, a value of 1 displays all information, and values greater
  than 1 display partial information. See verbose for more details.

Output Parameters:
o ThNN: An object of the theta neuron network class, now with weights and
  delays updated by the gradient-based results in TrainingResults.
o TrainingResults: A structure containing the error gradients with respect
  to the weights (DEDW) and delays (DEDTau). The structure also includes
  the SSE, a NonFireFlag that indicates whether any neurons are not firing,
  and NonFireCount, an array that tracks how many times each neuron has not
  fired over multiple input patterns. This structure is generated by
  calculate_gradient. If it was passed as an input, certain results are
  appended to it.
o TestingResults: A structure containing results generated on the testing
  data, in the same format as TrainingResults.

Examples:
>> %This example trains a simple inverter to the default RMSE
>> ThNN=theta_neuron_network;
>> TrainingParams=get_training_params;
>> TrainingParams.Type='SpikeTimes';
>> TrainingParams.DataName='Inverter';
>> [ThNN, TrainingResults]=train_dataset(ThNN,TrainingParams);
>> TrainingResults.RMSE(end)
>> figure;
>> plot(TrainingResults.RMSE);
>> xlabel('Epochs'); ylabel('RMSE');

See also theta_neuron_network, verbose
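The optional inputs in the Syntax line are resolved by count and type. A struct in an optional slot is taken as results data, and a non-struct in the last slot is taken as Verbose. The script below is a minimal sketch of that dispatch, mirroring the nargin checks at the top of train_dataset. It assumes a hypothetical cell array Args standing in for varargin (everything after ThNN and TrainingParams), and the example values are illustrative only.

```matlab
% Sketch (not part of the toolbox): resolve train_dataset's optional inputs.
% Args is a hypothetical stand-in for varargin after ThNN and TrainingParams.
Args = {struct('SSE', 0.5), 2};   % example call: (TrainingResults, Verbose)

% Defaults used when an optional input is omitted
TrainingResults = [];
TestingResults = [];
Verbose = 0;

switch numel(Args)
    case 3
        % (TrainingResults, TestingResults, Verbose)
        [TrainingResults, TestingResults, Verbose] = Args{:};
    case 2
        TrainingResults = Args{1};
        if isstruct(Args{2})
            TestingResults = Args{2};   % (TrainingResults, TestingResults)
        else
            Verbose = Args{2};          % (TrainingResults, Verbose)
        end
    case 1
        if isstruct(Args{1})
            TrainingResults = Args{1};  % (TrainingResults)
        else
            Verbose = Args{1};          % (Verbose)
        end
    case 0
        % Only ThNN and TrainingParams were given; keep the defaults
    otherwise
        error('Too many optional inputs: %d', numel(Args));
end

fprintf('Verbose = %d, has TrainingResults = %d, has TestingResults = %d\n', ...
    Verbose, isstruct(TrainingResults), isstruct(TestingResults));
```

Because only the type of the value is inspected, passing an empty [] in place of TrainingResults is read as a Verbose value, so callers who want to skip a results slot should pass a struct or omit it.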
function [ThNN, TrainingResults, TestingResults] = train_dataset(ThNN, TrainingParams, varargin)
%TRAIN_DATASET trains a theta neuron network given dataset information
%
%Description:
%Function to train a theta neuron network over multiple epochs. The
%dataset is selected by the Type and DataName fields of TrainingParams,
%which reference a pre-created data file (Datasets/<Type>/<DataName>,
%relative to the current folder). This function also supports optional
%interfacing with the training GUI.
%
%Syntax:
%TRAIN_DATASET(ThNN,TrainingParams);
%TRAIN_DATASET(...,[TrainingResults],[TestingResults],[Verbose]);
%[ThNN, TrainingResults, TestingResults] = TRAIN_DATASET(...);
%
%Input Parameters:
%o ThNN: An object of the theta neuron network class.
%o TrainingParams: A structure that contains information for training a
%  theta neuron network, such as the learning method. This structure is
%  generated by get_training_params. It also includes the fields Type and
%  DataName, which indicate the dataset to train on.
%o TrainingResults: A structure containing the error gradients with
%  respect to the weights (DEDW) and delays (DEDTau). The structure also
%  includes the SSE, a NonFireFlag that indicates whether any neurons are
%  not firing, and NonFireCount, an array that tracks how many times each
%  neuron has not fired over multiple input patterns. This structure is
%  generated by calculate_gradient. Pass it as an input to append new
%  results to it.
%o TestingResults: A structure containing results generated on the
%  testing data, in the same format as TrainingResults.
%o Verbose: Optional flag indicating whether extra information is
%  displayed on the screen. A value of 0 (the default) displays no
%  additional information, a value of 1 displays all information, and
%  values greater than 1 display partial information. See verbose for
%  more details.
%
%Output Parameters:
%o ThNN: An object of the theta neuron network class, now with weights and
%  delays updated by the gradient-based results in TrainingResults.
%o TrainingResults: A structure containing the error gradients with
%  respect to the weights (DEDW) and delays (DEDTau). The structure also
%  includes the SSE, a NonFireFlag that indicates whether any neurons are
%  not firing, and NonFireCount, an array that tracks how many times each
%  neuron has not fired over multiple input patterns. This structure is
%  generated by calculate_gradient. If it was passed as an input, certain
%  results are appended to it.
%o TestingResults: A structure containing results generated on the
%  testing data, in the same format as TrainingResults.
%
%Examples:
%>> %This example trains a simple inverter to the default RMSE
%>> ThNN=theta_neuron_network;
%>> TrainingParams=get_training_params;
%>> TrainingParams.Type='SpikeTimes';
%>> TrainingParams.DataName='Inverter';
%>> [ThNN, TrainingResults]=train_dataset(ThNN,TrainingParams);
%>> TrainingResults.RMSE(end)
%>> figure;
%>> plot(TrainingResults.RMSE);
%>> xlabel('Epochs'); ylabel('RMSE');
%
%See also theta_neuron_network, verbose

%Copyright (C) 2008 Sam McKennoch <Samuel.McKennoch@loria.fr>


%Parse the optional inputs: [TrainingResults], [TestingResults], [Verbose]
if nargin == 5
    TrainingResults = varargin{1};
    TestingResults = varargin{2};
    Verbose = varargin{3};
elseif nargin == 4
    TrainingResults = varargin{1};
    if ~isstruct(varargin{2})
        %Second optional input is the Verbose flag
        TestingResults = [];
        Verbose = varargin{2};
    else
        TestingResults = varargin{2};
        Verbose = 0;
    end
elseif nargin == 3
    TestingResults = [];
    if ~isstruct(varargin{1})
        TrainingResults = [];
        Verbose = varargin{1};
    else
        TrainingResults = varargin{1};
        Verbose = 0;
    end
elseif nargin == 2
    TrainingResults = [];
    TestingResults = [];
    Verbose = 0;
else
    disp('Error in train_dataset: Inputs are not correct');
    TrainingResults = [];
    TestingResults = [];
    return;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Load Input Data
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%The data file defines its variables (TestFlag, InputSpikeTimes, ...)
%directly in this workspace
DataFile = fullfile('Datasets', TrainingParams.Type, TrainingParams.DataName);
try
    disp(['Loading ', DataFile]);
    load(DataFile);
catch
    disp(['Error: ', TrainingParams.DataName, ' not found!']);
    return;
end

%Treat a data file that does not define TestFlag as having no testing data
if ~exist('TestFlag', 'var')
    TestFlag = 0;
end

%Account for Reference Time: when a reference time is used (ReferenceTime
%is not -1), input neuron 1 is excluded from the inputs that receive data
CodeDataInputs = ThNN.InputNeurons;
if ThNN.ReferenceTime ~= -1
    CodeDataInputs = setdiff(CodeDataInputs, 1);
end

%Create Coding Structure Based on Dataset
%Format input and output data into spike times
switch (TrainingParams.Type)
    case 'SpikeTimes'
        Coding.EncodeMethod = 'Spikes';
        Coding.DecodeMethod = 'Spikes';
        tiAll = code_data(CodeDataInputs, 'EncodeInputs', InputSpikeTimes, Coding);
        tdAll = code_data(ThNN.OutputNeurons, 'DecodeOutputs', DesiredSpikeTimes, Coding);
        if TestFlag
            tiAllTest = code_data(CodeDataInputs, 'EncodeInputs', TestingInputSpikeTimes, Coding);
            tdAllTest = code_data(ThNN.OutputNeurons, 'DecodeOutputs', TestingDesiredSpikeTimes, Coding);
        end
        ErrorMode = 0;
    case 'Classification'
        Coding.EncodeMethod = EncodeMethod;
        Coding.DecodeMethod = DecodeMethod;
        %These exist checks allow Classification to work with any type of
        %coding, although this is something of a hack
        if exist('InputSpikeRange', 'var')
            Coding.InputSpikeRange = InputSpikeRange;
        end
        if exist('OutputSpikeRange', 'var')
            Coding.OutputSpikeRange = OutputSpikeRange;
        end
        if exist('InputRange', 'var')
            Coding.InputRange = InputRange;
        end
        if exist('OutputRange', 'var')
            Coding.OutputRange = OutputRange;
        end
        if exist('InputSpikeTimes', 'var')
            Inputs = InputSpikeTimes;
        end
        if exist('DesiredSpikeTimes', 'var')
            Outputs = DesiredSpikeTimes;
        end
        %Testing spike times go into the testing variables, so they do not
        %overwrite the training Inputs and Outputs
        if exist('TestingInputSpikeTimes', 'var')
            TestingInputs = TestingInputSpikeTimes;
        end
        if exist('TestingDesiredSpikeTimes', 'var')
            TestingOutputs = TestingDesiredSpikeTimes;
        end
        tiAll = code_data(CodeDataInputs, 'EncodeInputs', Inputs, Coding);
        tdAll = code_data(ThNN.OutputNeurons, 'DecodeOutputs', Outputs, Coding);
        if TestFlag
            tiAllTest = code_data(CodeDataInputs, 'EncodeInputs', TestingInputs, Coding);
            tdAllTest = code_data(ThNN.OutputNeurons, 'DecodeOutputs', TestingOutputs, Coding);
        end
        ErrorMode = 2;
    case 'Regression'
        Coding.EncodeMethod = EncodeMethod;
        Coding.DecodeMethod = DecodeMethod;
        Coding.InputSpikeRange = InputSpikeRange;
        Coding.OutputSpikeRange = OutputSpikeRange;
        Coding.InputRange = InputRange;
        Coding.OutputRange = OutputRange;
        TrainingParams.FunctionHandle = FH;
        tiAll = code_data(CodeDataInputs, 'EncodeInputs', Inputs, Coding);
        tdAll = code_data(ThNN.OutputNeurons, 'DecodeOutputs', Outputs, Coding);
        if TestFlag
            tiAllTest = code_data(CodeDataInputs, 'EncodeInputs', TestingInputs, Coding);
            tdAllTest = code_data(ThNN.OutputNeurons, 'DecodeOutputs', TestingOutputs, Coding);
        end
        ErrorMode = 0;
    otherwise
        %Without a known type, Coding and the coded data are never defined
        error('train_dataset:UnknownType', 'Unknown dataset type: %s', TrainingParams.Type);
end
TrainingParams.Coding = Coding;

%Initialize the training results structure if it was not passed as an input
if isempty(TrainingResults)
    TrainingResults.DEDW = 0*ThNN.Weights;
    TrainingResults.DEDTau = TrainingResults.DEDW;
    TrainingResults.DEDWOld = TrainingResults.DEDW;
    TrainingResults.DEDTauOld = TrainingResults.DEDW;
    TrainingResults.NonFireFlag = 0;
    TrainingResults.NonFireCount = [];
    %Timestamp of this run, with '-', ' ' and ':' replaced by underscores
    DT = datestr(clock);
    DT(DT=='-' | DT==' ' | DT==':') = '_';
    TrainingResults.DT = DT;

    %Get Initial RMSE Using Updated get_error function
    [TrainingResults.SSE, TrainingResults.RMSE, TrainingResults.ClassErr] = get_error(ThNN, tiAll, tdAll, ErrorMode);
end

%Initialize the testing results structure independently, so that it is
%also created when only TrainingResults was passed as an input
if TestFlag && isempty(TestingResults)
    TestingResults.DEDW = 0*ThNN.Weights;
    TestingResults.DEDTau = TestingResults.DEDW;
    TestingResults.DEDWOld = TestingResults.DEDW;
    TestingResults.DEDTauOld = TestingResults.DEDW;
    TestingResults.NonFireFlag = 0;
    TestingResults.NonFireCount = [];
    %Get Initial RMSE Using Updated get_error function
    [TestingResults.SSE, TestingResults.RMSE, TestingResults.ClassErr] = get_error(ThNN, tiAllTest, tdAllTest, ErrorMode);
end

%Train, also evaluating on the testing data when the dataset provides it
if TestFlag
    [ThNN, TrainingResults, TestingResults] = train(ThNN, TrainingParams, 'TrainingData', {tiAll, tdAll, TrainingResults}, ...
        'TestingData', {tiAllTest, tdAllTest, TestingResults}, Verbose);
else
    [ThNN, TrainingResults] = train(ThNN, TrainingParams, 'TrainingData', {tiAll, tdAll, TrainingResults}, Verbose);
end

return;
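Because load creates the dataset variables directly in the function workspace, a variable missing from the data file only surfaces as an undefined-variable error partway through train_dataset. The script below is a hypothetical pre-flight check, not part of the toolbox. It lists the variables that the SpikeTimes and Regression branches above read, taken from the switch statement. The Classification branch is left out because it accepts several alternative variable sets. Testing variables are only checked when the file sets TestFlag. S stands in for the struct returned by S = load(DataFile), and its placeholder fields are illustrative assumptions.

```matlab
% Hypothetical pre-flight check: which variables does a dataset file lack?
% S stands in for S = load(DataFile); the field values are placeholders.
S = struct('InputSpikeTimes', [], 'DesiredSpikeTimes', [], 'TestFlag', 1);
Type = 'SpikeTimes';

% Variables read by the matching branch of train_dataset's switch statement
switch Type
    case 'SpikeTimes'
        Required = {'InputSpikeTimes', 'DesiredSpikeTimes'};
        TestingVars = {'TestingInputSpikeTimes', 'TestingDesiredSpikeTimes'};
    case 'Regression'
        Required = {'EncodeMethod', 'DecodeMethod', 'InputSpikeRange', ...
            'OutputSpikeRange', 'InputRange', 'OutputRange', 'FH', ...
            'Inputs', 'Outputs'};
        TestingVars = {'TestingInputs', 'TestingOutputs'};
    otherwise
        error('No variable list for dataset type %s', Type);
end

% Testing variables are only read when the file sets TestFlag
if isfield(S, 'TestFlag') && S.TestFlag
    Required = [Required, TestingVars];
end

% isfield accepts a cell array of names and returns one logical per name
Missing = Required(~isfield(S, Required));
if isempty(Missing)
    disp('All variables read by train_dataset are present.');
else
    fprintf('Missing variable: %s\n', Missing{:});
end
```

Loading with an output argument (S = load(...)) keeps the file's variables as struct fields instead of creating them in the workspace, which is what makes this check possible before any training code runs.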