
train_dataset

PURPOSE

TRAIN_DATASET trains a theta neuron network on a stored dataset

SYNOPSIS

function [ThNN, TrainingResults, TestingResults] = train_dataset(ThNN, TrainingParams, varargin)

DESCRIPTION

TRAIN_DATASET trains a theta neuron network on a stored dataset

Description:
Function to train a theta neuron network over multiple epochs. The
dataset is selected by the Type and DataName fields of TrainingParams,
which identify a pre-created data file stored under Datasets/<Type>/.
This function also includes optional interfacing with the training GUI.
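The source listing locates the data file from TrainingParams.Type and TrainingParams.DataName as Datasets/<Type>/<DataName>. A minimal sketch of building that path with fullfile, which inserts the platform separator, plus a check that Type is one of the three values the function's switch handles. The parameter values are placeholders; in practice they come from get_training_params.

```matlab
% Placeholder parameters (normally produced by get_training_params)
TrainingParams.Type = 'SpikeTimes';
TrainingParams.DataName = 'Inverter';

% fullfile joins the parts using the platform's separator (filesep)
DataFile = fullfile('Datasets', TrainingParams.Type, TrainingParams.DataName);

% The three dataset types handled by train_dataset's switch statement
IsKnownType = any(strcmp(TrainingParams.Type, ...
    {'SpikeTimes', 'Classification', 'Regression'}));
disp(DataFile);
```

Validating Type up front gives a clear message before load is attempted.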

Syntax:
TRAIN_DATASET(ThNN,TrainingParams);
TRAIN_DATASET(...,[TrainingResults],[TestingResults],[Verbose]);
[ThNN, TrainingResults, TestingResults] = TRAIN_DATASET(...);

Input Parameters:
o ThNN: An object of the theta neuron network class
o TrainingParams: A structure that contains information for training a
    theta neuron network, such as the learning method. This structure is
    generated by get_training_params. Its Type field ('SpikeTimes',
    'Classification', or 'Regression') and DataName field identify the
    dataset to train on.
o TrainingResults: A structure containing the error gradients with
    respect to the weights (DEDW) and delays (DEDTau). The structure also
    includes the SSE, RMSE, and ClassErr error measures; NonFireFlag,
    which indicates whether any neurons are not firing; and NonFireCount,
    an array that tracks how many times each neuron has failed to fire
    over multiple input patterns. This structure is generated by
    calculate_gradient. Pass it as an input argument to append new
    results to it.
o TestingResults: A structure containing results generated on the testing
    data, in the same format as TrainingResults.
o Verbose: Optional flag that controls how much extra information is
    displayed on the screen. A value of 0 (the default) displays no
    additional information, a value of 1 displays all information, and
    values greater than 1 display partial information. See verbose for
    more details.
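The optional arguments are told apart by type rather than by name. A minimal sketch of one way to resolve them, assuming (as the source listing does) that Verbose is a non-struct scalar. Args is a hypothetical stand-in for varargin; unlike a pure isstruct test, treating only a non-empty numeric scalar as Verbose lets an empty [] act as a placeholder for a results structure.

```matlab
% Hypothetical stand-in for varargin: TrainingResults followed by Verbose
Args = {struct('SSE', 1), 2};

% Defaults used when an argument is omitted
TrainingResults = [];
TestingResults = [];
Verbose = 0;

% A trailing non-empty numeric scalar is taken as the Verbose level
if ~isempty(Args) && isnumeric(Args{end}) && isscalar(Args{end})
    Verbose = Args{end};
    Args(end) = [];
end

% Remaining arguments are positional: TrainingResults, then TestingResults
if numel(Args) >= 1
    TrainingResults = Args{1};
end
if numel(Args) >= 2
    TestingResults = Args{2};
end
```

For the argument counts the function accepts, this yields the same assignments as its nargin-based branches.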

Output Parameters:
o ThNN: An object of the theta neuron network class, now with its weights
    and delays updated using the gradient-based results in TrainingResults.
o TrainingResults: A structure containing the error gradients with
    respect to the weights (DEDW) and delays (DEDTau). The structure also
    includes the SSE, RMSE, and ClassErr error measures; NonFireFlag,
    which indicates whether any neurons are not firing; and NonFireCount,
    an array that tracks how many times each neuron has failed to fire
    over multiple input patterns. This structure is generated by
    calculate_gradient. If it was passed as an input, certain results are
    appended to it.
o TestingResults: A structure containing results generated on the testing
    data in the same format as TrainingResults.
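When no TrainingResults is passed in, the function starts from a zeroed structure before computing the initial errors. The sketch below reproduces that initialization on a placeholder matrix: Weights stands in for ThNN.Weights, and DEDTau is assumed to have the same shape as the weights, as in the source listing. The DT timestamp has dashes, spaces, and colons replaced by underscores, presumably so it can be used in file names.

```matlab
% Placeholder for ThNN.Weights
Weights = rand(3, 3);

Results.DEDW = zeros(size(Weights));   % dE/dW, same shape as the weights
Results.DEDTau = Results.DEDW;         % dE/dTau (assumed same shape)
Results.DEDWOld = Results.DEDW;        % gradients from the previous epoch
Results.DEDTauOld = Results.DEDW;
Results.NonFireFlag = 0;               % set when some neuron fails to fire
Results.NonFireCount = [];             % per-neuron non-firing tally

% Timestamp with '-', ' ' and ':' replaced by '_'
DT = datestr(clock);
DT(DT == '-' | DT == ' ' | DT == ':') = '_';
Results.DT = DT;
```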

Examples:
>> %This example trains a simple inverter to the default RMSE
>> ThNN=theta_neuron_network;
>> TrainingParams=get_training_params;
>> TrainingParams.Type='SpikeTimes';
>> TrainingParams.DataName='Inverter';
>> [ThNN, TrainingResults]=train_dataset(ThNN,TrainingParams);
>> TrainingResults.RMSE(end)
>> figure;
>> plot(TrainingResults.RMSE);
>> xlabel('Epochs'); ylabel('RMSE');

See also theta_neuron_network, verbose

SOURCE CODE

function [ThNN, TrainingResults, TestingResults] = train_dataset(ThNN, TrainingParams, varargin)
%TRAIN_DATASET trains a theta neuron network on a stored dataset
%
%Description:
%Function to train a theta neuron network over multiple epochs. The
%dataset is selected by the Type and DataName fields of TrainingParams,
%which identify a pre-created data file stored under Datasets/<Type>/.
%This function also includes optional interfacing with the training GUI.
%
%Syntax:
%TRAIN_DATASET(ThNN,TrainingParams);
%TRAIN_DATASET(...,[TrainingResults],[TestingResults],[Verbose]);
%[ThNN, TrainingResults, TestingResults] = TRAIN_DATASET(...);
%
%Input Parameters:
%o ThNN: An object of the theta neuron network class
%o TrainingParams: A structure that contains information for training a
%    theta neuron network, such as the learning method. This structure is
%    generated by get_training_params. Its Type field ('SpikeTimes',
%    'Classification', or 'Regression') and DataName field identify the
%    dataset to train on.
%o TrainingResults: A structure containing the error gradients with
%    respect to the weights (DEDW) and delays (DEDTau). The structure also
%    includes the SSE, RMSE, and ClassErr error measures; NonFireFlag,
%    which indicates whether any neurons are not firing; and NonFireCount,
%    an array that tracks how many times each neuron has failed to fire
%    over multiple input patterns. This structure is generated by
%    calculate_gradient. Pass it as an input argument to append new
%    results to it.
%o TestingResults: A structure containing results generated on the testing
%    data, in the same format as TrainingResults.
%o Verbose: Optional flag that controls how much extra information is
%    displayed on the screen. A value of 0 (the default) displays no
%    additional information, a value of 1 displays all information, and
%    values greater than 1 display partial information. See verbose for
%    more details.
%
%Output Parameters:
%o ThNN: An object of the theta neuron network class, now with its weights
%    and delays updated using the gradient-based results in TrainingResults.
%o TrainingResults: A structure containing the error gradients with
%    respect to the weights (DEDW) and delays (DEDTau). The structure also
%    includes the SSE, RMSE, and ClassErr error measures; NonFireFlag,
%    which indicates whether any neurons are not firing; and NonFireCount,
%    an array that tracks how many times each neuron has failed to fire
%    over multiple input patterns. This structure is generated by
%    calculate_gradient. If it was passed as an input, certain results are
%    appended to it.
%o TestingResults: A structure containing results generated on the testing
%    data, in the same format as TrainingResults.
%
%Examples:
%>> %This example trains a simple inverter to the default RMSE
%>> ThNN=theta_neuron_network;
%>> TrainingParams=get_training_params;
%>> TrainingParams.Type='SpikeTimes';
%>> TrainingParams.DataName='Inverter';
%>> [ThNN, TrainingResults]=train_dataset(ThNN,TrainingParams);
%>> TrainingResults.RMSE(end)
%>> figure;
%>> plot(TrainingResults.RMSE);
%>> xlabel('Epochs'); ylabel('RMSE');
%
%See also theta_neuron_network, verbose

%Copyright (C) 2008 Sam McKennoch <Samuel.McKennoch@loria.fr>

if nargin == 5
    TrainingResults = varargin{1};
    TestingResults = varargin{2};
    Verbose = varargin{3};
elseif nargin == 4
    %A non-empty, non-struct second argument is Verbose; a struct or an
    %empty [] placeholder is TestingResults
    if ~isstruct(varargin{2}) && ~isempty(varargin{2})
        TrainingResults = varargin{1};
        TestingResults = [];
        Verbose = varargin{2};
    else
        TrainingResults = varargin{1};
        TestingResults = varargin{2};
        Verbose = 0;
    end
elseif nargin == 3
    TestingResults = [];
    if ~isstruct(varargin{1}) && ~isempty(varargin{1})
        TrainingResults = [];
        Verbose = varargin{1};
    else
        TrainingResults = varargin{1};
        Verbose = 0;
    end
elseif nargin == 2
    TrainingResults = [];
    TestingResults = [];
    Verbose = 0;
else
    disp('Error in train_dataset: Inputs are not correct');
    TrainingResults = [];
    TestingResults = [];
    return;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Load Input Data
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Default to no testing data; the data file overrides TestFlag if it
%defines it
TestFlag = 0;
DataFile = fullfile('Datasets', TrainingParams.Type, TrainingParams.DataName);
try
    disp(['Loading ', DataFile]);
    load(DataFile);
catch
    disp(['Error: ', DataFile, ' could not be loaded!']);
    return;
end

%Account for reference time: when a reference time is used
%(ReferenceTime ~= -1), input neuron 1 is excluded from the neurons that
%receive coded dataset inputs
CodeDataInputs=ThNN.InputNeurons;
if ThNN.ReferenceTime~=-1
    CodeDataInputs=setdiff(CodeDataInputs,1);
end

%Create Coding Structure Based on Dataset
%Format input and output data into spike times
switch (TrainingParams.Type)
    case 'SpikeTimes'
        Coding.EncodeMethod='Spikes';
        Coding.DecodeMethod='Spikes';
        tiAll=code_data(CodeDataInputs,'EncodeInputs',InputSpikeTimes,Coding);
        tdAll=code_data(ThNN.OutputNeurons,'DecodeOutputs',DesiredSpikeTimes,Coding);
        if TestFlag
            tiAllTest=code_data(CodeDataInputs,'EncodeInputs',TestingInputSpikeTimes,Coding);
            tdAllTest=code_data(ThNN.OutputNeurons,'DecodeOutputs',TestingDesiredSpikeTimes,Coding);
            %tiAllTest=TestingInputSpikeTimes;
            %tdAllTest=TestingDesiredSpikeTimes;
        end
        ErrorMode=0;
    case 'Classification'
        Coding.EncodeMethod=EncodeMethod;
        Coding.DecodeMethod=DecodeMethod;
        %These exist checks let Classification work with any type of
        %coding by copying only the variables that the data file defines
        if exist('InputSpikeRange','var')
            Coding.InputSpikeRange=InputSpikeRange;
        end
        if exist('OutputSpikeRange','var')
            Coding.OutputSpikeRange=OutputSpikeRange;
        end
        if exist('InputRange','var')
            Coding.InputRange=InputRange;
        end
        if exist('OutputRange','var')
            Coding.OutputRange=OutputRange;
        end
        if exist('InputSpikeTimes','var')
            Inputs=InputSpikeTimes;
        end
        if exist('DesiredSpikeTimes','var')
            Outputs=DesiredSpikeTimes;
        end
        %Testing spike times go to the testing variables so that they do
        %not overwrite the training data
        if exist('TestingInputSpikeTimes','var')
            TestingInputs=TestingInputSpikeTimes;
        end
        if exist('TestingDesiredSpikeTimes','var')
            TestingOutputs=TestingDesiredSpikeTimes;
        end
        tiAll=code_data(CodeDataInputs,'EncodeInputs',Inputs,Coding);
        tdAll=code_data(ThNN.OutputNeurons,'DecodeOutputs',Outputs,Coding);
        if TestFlag
            tiAllTest=code_data(CodeDataInputs,'EncodeInputs',TestingInputs,Coding);
            tdAllTest=code_data(ThNN.OutputNeurons,'DecodeOutputs',TestingOutputs,Coding);
        end
        ErrorMode=2;
    case 'Regression'
        Coding.EncodeMethod=EncodeMethod;
        Coding.DecodeMethod=DecodeMethod;
        Coding.InputSpikeRange=InputSpikeRange;
        Coding.OutputSpikeRange=OutputSpikeRange;
        Coding.InputRange=InputRange;
        Coding.OutputRange=OutputRange;
        TrainingParams.FunctionHandle=FH;
        tiAll=code_data(CodeDataInputs,'EncodeInputs',Inputs,Coding);
        tdAll=code_data(ThNN.OutputNeurons,'DecodeOutputs',Outputs,Coding);
        if TestFlag
            tiAllTest=code_data(CodeDataInputs,'EncodeInputs',TestingInputs,Coding);
            tdAllTest=code_data(ThNN.OutputNeurons,'DecodeOutputs',TestingOutputs,Coding);
        end
        ErrorMode=0;
    otherwise
        disp(['Error in train_dataset: Unknown dataset type ', TrainingParams.Type]);
        return;
end
TrainingParams.Coding=Coding;

%Initialize the training results structure if it was not passed as an input
if isempty(TrainingResults)
    TrainingResults.DEDW=0*ThNN.Weights;
    TrainingResults.DEDTau=TrainingResults.DEDW;
    TrainingResults.DEDWOld=TrainingResults.DEDW;
    TrainingResults.DEDTauOld=TrainingResults.DEDW;
    TrainingResults.NonFireFlag=0;
    TrainingResults.NonFireCount=[];
    DT=datestr(clock);
    DT(DT=='-' | DT==' ' | DT==':')='_';
    TrainingResults.DT=DT;

    %Get initial errors using the updated get_error function
    [TrainingResults.SSE, TrainingResults.RMSE, TrainingResults.ClassErr]=get_error(ThNN,tiAll,tdAll,ErrorMode);
end

%Initialize the testing results whenever testing data exists and none were
%passed as an input
if TestFlag && isempty(TestingResults)
    TestingResults.DEDW=0*ThNN.Weights;
    TestingResults.DEDTau=TestingResults.DEDW;
    TestingResults.DEDWOld=TestingResults.DEDW;
    TestingResults.DEDTauOld=TestingResults.DEDW;
    TestingResults.NonFireFlag=0;
    TestingResults.NonFireCount=[];
    %Get initial errors using the updated get_error function
    [TestingResults.SSE, TestingResults.RMSE, TestingResults.ClassErr]=get_error(ThNN,tiAllTest,tdAllTest,ErrorMode);
end

if TestFlag
    [ThNN, TrainingResults, TestingResults] = train(ThNN, TrainingParams, 'TrainingData', {tiAll, tdAll, TrainingResults},...
        'TestingData', {tiAllTest, tdAllTest, TestingResults},Verbose);
else
    [ThNN, TrainingResults] = train(ThNN, TrainingParams, 'TrainingData', {tiAll, tdAll, TrainingResults}, Verbose);
end

return;
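The Classification branch relies on exist checks against variables that load creates in the function workspace. An alternative sketch, assuming the data file is instead loaded into a struct with S = load(DataFile), copies only the coding fields the file provides and keeps training and testing data in separate variables. The contents of S are placeholders rather than a real dataset, and HasTestData is a hypothetical flag derived here from the presence of testing fields (the listing reads TestFlag from the data file instead).

```matlab
% Placeholder contents standing in for S = load(DataFile)
S = struct('EncodeMethod', 'Spikes', 'DecodeMethod', 'Spikes', ...
           'InputRange', [0 1], ...
           'InputSpikeTimes', [1 2; 3 4], 'DesiredSpikeTimes', [5; 6], ...
           'TestingInputSpikeTimes', [7 8], 'TestingDesiredSpikeTimes', 9);

Coding = struct('EncodeMethod', S.EncodeMethod, 'DecodeMethod', S.DecodeMethod);

% Copy only the optional coding fields that the dataset defines
OptionalFields = {'InputSpikeRange', 'OutputSpikeRange', 'InputRange', 'OutputRange'};
for k = 1:numel(OptionalFields)
    Name = OptionalFields{k};
    if isfield(S, Name)
        Coding.(Name) = S.(Name);
    end
end

% Training and testing data stay in separate variables
Inputs = S.InputSpikeTimes;
Outputs = S.DesiredSpikeTimes;
HasTestData = isfield(S, 'TestingInputSpikeTimes') && ...
              isfield(S, 'TestingDesiredSpikeTimes');
if HasTestData
    TestingInputs = S.TestingInputSpikeTimes;
    TestingOutputs = S.TestingDesiredSpikeTimes;
end
```

Loading into a struct keeps the dataset's variable names from silently shadowing anything in the function workspace, and isfield makes each optional field explicit.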

Generated on Wed 02-Apr-2008 15:16:32 by m2html © 2003