
apply_gradient

PURPOSE

APPLY_GRADIENT applies gradient-based changes to network weights and delays

SYNOPSIS

function ThNN=apply_gradient(ThNN,TrainingParams,TrainingResults,Verbose)

DESCRIPTION

APPLY_GRADIENT applies gradient-based changes to network weights and delays

Description:
Function to update the weights and delays of a theta neuron network by
taking a gradient step, using the error gradients (DEDW and DEDTau)
computed by calculate_gradient and the learning method and learning rates
stored in TrainingParams.
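For the four plain gradient-descent learning methods, the source code below updates every weight and delay element by element as follows. The symbols are shorthand introduced here, not names used in the toolbox: η_w is TrainingParams.WeightLearningRate, η_τ is TrainingParams.DelayLearningRate, τ_min is TrainingParams.MinDelay, and ∂E/∂W and ∂E/∂τ are TrainingResults.DEDW and TrainingResults.DEDTau. The delay update is applied only when TrainingParams.DelayEnable is nonzero.

```latex
\begin{aligned}
W    &\leftarrow W - \eta_w \,\frac{\partial E}{\partial W},\\
\tau &\leftarrow \max\!\left(\tau - \eta_\tau \,\frac{\partial E}{\partial \tau},\; \tau_{\min}\right).
\end{aligned}
```

The max with τ_min keeps a gradient step from driving any delay below the minimum delay.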

Syntax:
ThNN=APPLY_GRADIENT(ThNN,TrainingParams,TrainingResults);
ThNN=APPLY_GRADIENT(..., Verbose);

Input Parameters:
o ThNN: An object of the theta neuron network class
o TrainingParams: A structure that contains information for training a
    theta neuron network, such as the learning method. This structure is
    generated by get_training_params.
o TrainingResults: A structure containing the error gradients with respect
    to the weights (DEDW) and delays (DEDTau).  The structure also includes
    the SSE, NonFireFlag to indicate if any neurons are not firing, and
    NonFireCount, an array used to keep track of how many times each
    neuron has not fired over multiple input patterns.  This structure is
    generated by calculate_gradient.
o Verbose: Optional flag to indicate if extra information will be
    displayed on the screen. A value of 0 (the default) displays no
    additional information, a value of 1 displays all information, and
    values greater than 1 display partial information. In this function
    only a value of 1 produces output (the modified weights and delays).
    See verbose for more details.
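For the plain gradient-descent methods, the only fields apply_gradient reads are TrainingParams.LearningMethod, DelayEnable, WeightLearningRate, DelayLearningRate and MinDelay, plus TrainingResults.DEDW and DEDTau. The sketch below builds small stand-in structures by hand and repeats that update on plain matrices, so it runs without the toolbox. The 2x2 sizes and all numeric values are made up for illustration; in real use the structures come from get_training_params and calculate_gradient, and the matrices are ThNN.Weights and ThNN.Delays.

```matlab
% Stand-in structures with hypothetical values (normally produced by
% get_training_params and calculate_gradient)
TrainingParams = struct('LearningMethod', 'Batch Gradient Descent', ...
                        'DelayEnable', 1, ...
                        'WeightLearningRate', 0.5, ...
                        'DelayLearningRate', 0.25, ...
                        'MinDelay', 0.5);
TrainingResults = struct('DEDW',   [0.1 0.0; -0.2 0.4], ...
                         'DEDTau', [2.0 -1.0; 3.0 0.0]);

% Current weights and delays (in real use: ThNN.Weights, ThNN.Delays)
Weights = [0.5 -0.2; 1.0 0.3];
Delays  = [1.0  2.0; 0.5 1.5];

% Same update as the gradient-descent branch of apply_gradient
Weights = Weights - TrainingParams.WeightLearningRate*TrainingResults.DEDW;
if TrainingParams.DelayEnable ~= 0
    % Gradient step on the delays, clipped from below at MinDelay
    Delays = max(Delays - TrainingParams.DelayLearningRate*TrainingResults.DEDTau, ...
                 TrainingParams.MinDelay);
end
disp(Weights)
disp(Delays)
```

With these values the delay in position (2,1) would have become -0.25, so it is held at MinDelay (0.5) instead.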

Output Parameters:
o ThNN: An object of the theta neuron network class, now with weights and
    delays updated by gradient-based results in TrainingResults.

Examples:
>> ThNN=theta_neuron_network;
>> TrainingParams=get_training_params;
>> [TrainingResults, ts]=calculate_gradient(ThNN,TrainingParams,...
     [2 3],[3 25],1);
>> SSE1=get_error(ThNN,ts,[3 25])
>> ThNN=apply_gradient(ThNN,TrainingParams,TrainingResults,1);
>> ts = run_network(ThNN, [2 3]);
>> SSE2=get_error(ThNN,ts,[3 25])

See also theta_neuron_network, verbose

SOURCE CODE

function ThNN=apply_gradient(ThNN,TrainingParams,TrainingResults,Verbose)
%APPLY_GRADIENT applies gradient-based changes to network weights and delays
%
%Description:
%Function to update the weights and delays of a theta neuron network by
%taking a gradient step, using the error gradients (DEDW and DEDTau)
%computed by calculate_gradient and the learning method and learning rates
%stored in TrainingParams.
%
%Syntax:
%ThNN=APPLY_GRADIENT(ThNN,TrainingParams,TrainingResults);
%ThNN=APPLY_GRADIENT(..., Verbose);
%
%Input Parameters:
%o ThNN: An object of the theta neuron network class
%o TrainingParams: A structure that contains information for training a
%    theta neuron network, such as the learning method. This structure is
%    generated by get_training_params.
%o TrainingResults: A structure containing the error gradients with respect
%    to the weights (DEDW) and delays (DEDTau).  The structure also includes
%    the SSE, NonFireFlag to indicate if any neurons are not firing, and
%    NonFireCount, an array used to keep track of how many times each
%    neuron has not fired over multiple input patterns.  This structure is
%    generated by calculate_gradient.
%o Verbose: Optional flag to indicate if extra information will be
%    displayed on the screen. A value of 0 (the default) displays no
%    additional information, a value of 1 displays all information, and
%    values greater than 1 display partial information. In this function
%    only a value of 1 produces output (the modified weights and delays).
%    See verbose for more details.
%
%Output Parameters:
%o ThNN: An object of the theta neuron network class, now with weights and
%    delays updated by gradient-based results in TrainingResults.
%
%Examples:
%>> ThNN=theta_neuron_network;
%>> TrainingParams=get_training_params;
%>> [TrainingResults, ts]=calculate_gradient(ThNN,TrainingParams,...
%     [2 3],[3 25],1);
%>> SSE1=get_error(ThNN,ts,[3 25])
%>> ThNN=apply_gradient(ThNN,TrainingParams,TrainingResults,1);
%>> ts = run_network(ThNN, [2 3]);
%>> SSE2=get_error(ThNN,ts,[3 25])
%
%See also theta_neuron_network, verbose

%Copyright (C) 2008 Sam McKennoch <Samuel.McKennoch@loria.fr>


if nargin<3
    disp('Error in apply_gradient: Not enough input arguments');
    disp(['Needed at least 3 inputs but only got ' num2str(nargin)]);
    return;
end
if nargin<4
    Verbose=0;
end

Weights=ThNN.Weights;
Delays=ThNN.Delays;

switch (TrainingParams.LearningMethod)
    case {'Batch Gradient Descent','Online Gradient Descent','Online Batch Gradient Descent','Online Annealed Gradient Descent'}
        % Plain gradient-descent step on the weights
        Weights=Weights-(TrainingParams.WeightLearningRate)*TrainingResults.DEDW;
        % Optionally update the delays too, never letting one fall below MinDelay
        if TrainingParams.DelayEnable~=0
            Delays=max(Delays-(TrainingParams.DelayLearningRate)*TrainingResults.DEDTau,TrainingParams.MinDelay);
        end

    case {'Batch Gradient Descent with Momentum','Online Gradient Descent with Momentum','Quickprop','RPROP'}
        % Not implemented yet. A momentum update, for example, also needs the
        % previous weight change, which is not passed to this function.
        disp(['Error in apply_gradient: Training Method not implemented: ', TrainingParams.LearningMethod]);
        ThNN=-1;
        return;

    otherwise
        disp(['Error in apply_gradient: Invalid Training Method: ', TrainingParams.LearningMethod]);
        ThNN=-1;
        return;
end
if (Verbose==1)
    disp('Modified Weights:');
    Weights
    disp('Modified Delays:');
    Delays
end
ThNN.Weights=Weights;
ThNN.Delays=Delays;
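The momentum learning methods named in the source are not functional as written, because an update with momentum also needs the previous weight change, and apply_gradient receives no such value. For reference, classical (heavy-ball) momentum computes each step as ΔW(t) = -η ∂E/∂W + α ΔW(t-1). The sketch below applies that rule on plain matrices; LearningRate, Momentum, DeltaWOld and the gradient values are hypothetical and are not fields of the toolbox's TrainingParams or theta neuron network class, so a real implementation would need somewhere to keep the previous step between calls.

```matlab
% Hypothetical hyperparameters (not toolbox fields)
LearningRate = 0.5;   % eta
Momentum     = 0.9;   % alpha

Weights   = [0.5 -0.2; 1.0 0.3];
DeltaWOld = zeros(size(Weights));   % no previous step before the first update

% Two hypothetical gradients, e.g. from successive calls to calculate_gradient
Gradients = {[0.1 0.0; -0.2 0.4], [0.1 0.2; 0.0 -0.4]};

for k = 1:numel(Gradients)
    % Classical momentum: new step = -eta*gradient + alpha*previous step
    DeltaW    = -LearningRate*Gradients{k} + Momentum*DeltaWOld;
    Weights   = Weights + DeltaW;
    DeltaWOld = DeltaW;   % remember this step for the next update
end
disp(Weights)
```

Because DeltaWOld starts at zero, the first iteration reduces to a plain gradient-descent step; from the second iteration on, a fraction α of the previous step (not of the previous raw gradient) is carried forward.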

Generated on Wed 02-Apr-2008 15:16:32 by m2html © 2003