APPLY_GRADIENT applies gradient-based changes to network weights and delays

Description:
Function to update the network weights and delays using the error
gradients computed by calculate_gradient.

Syntax:
ThNN=APPLY_GRADIENT(ThNN,TrainingParams,TrainingResults);
ThNN=APPLY_GRADIENT(..., Verbose);

Input Parameters:
o ThNN: An object of the theta neuron network class
o TrainingParams: A structure that contains information for training a
  theta neuron network, such as the learning method. This structure is
  generated by get_training_params.
o TrainingResults: A structure containing the error gradients with respect
  to the weights (DEDW) and delays (DEDTau). The structure also includes
  the SSE, NonFireFlag to indicate if any neurons are not firing, along
  with NonFireCount, an array used to keep track of how many times a
  neuron has not fired over multiple input patterns. This structure is
  generated by calculate_gradient.
o Verbose: Optional flag indicating how much information is displayed on
  the screen. A value of 0 displays no additional information (this is
  the default), a value of 1 displays all information, and values greater
  than 1 display partial information. See Verbose for more details.

Output Parameters:
o ThNN: An object of the theta neuron network class, now with weights and
  delays updated by the gradient-based results in TrainingResults.

Examples:
>> ThNN=theta_neuron_network;
>> TrainingParams=get_training_params;
>> [TrainingResults, ts]=calculate_gradient(ThNN,TrainingParams,...
   [2 3],[3 25],1);
>> SSE1=get_error(ThNN,ts,[3 25])
>> ThNN=apply_gradient(ThNN,TrainingParams,TrainingResults,1);
>> ts = run_network(ThNN, [2 3]);
>> SSE2=get_error(ThNN,ts,[3 25])

See also theta_neuron_network, verbose
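The calls in the Examples section amount to a single training iteration.
Below is a minimal training-loop sketch built only from the calls shown
above; the stopping tolerance Tol and iteration cap MaxIters are
illustrative names chosen here, not toolbox parameters:

% Minimal training loop sketch (Tol and MaxIters are assumed, not part
% of the toolbox API).
ThNN = theta_neuron_network;            % network with default topology
TrainingParams = get_training_params;   % default learning method and rates
InputTimes = [2 3];                     % input spike times
DesiredTimes = [3 25];                  % desired output spike times
Tol = 1e-3;                             % assumed stopping tolerance
MaxIters = 500;                         % assumed iteration cap
for Iter = 1:MaxIters
    % Compute the error gradients and the current output spike times
    [TrainingResults, ts] = calculate_gradient(ThNN, TrainingParams, ...
        InputTimes, DesiredTimes, 1);
    if get_error(ThNN, ts, DesiredTimes) < Tol
        break;                          % training has converged
    end
    % Take one gradient step on the weights (and delays, if enabled)
    ThNN = apply_gradient(ThNN, TrainingParams, TrainingResults);
end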
function ThNN=apply_gradient(ThNN,TrainingParams,TrainingResults,Verbose)
%APPLY_GRADIENT applies gradient-based changes to network weights and delays
%
%Description:
%Function to update the network weights and delays using the error
%gradients computed by calculate_gradient.
%
%Syntax:
%ThNN=APPLY_GRADIENT(ThNN,TrainingParams,TrainingResults);
%ThNN=APPLY_GRADIENT(..., Verbose);
%
%Input Parameters:
%o ThNN: An object of the theta neuron network class
%o TrainingParams: A structure that contains information for training a
%  theta neuron network, such as the learning method. This structure is
%  generated by get_training_params.
%o TrainingResults: A structure containing the error gradients with respect
%  to the weights (DEDW) and delays (DEDTau). The structure also includes
%  the SSE, NonFireFlag to indicate if any neurons are not firing, along
%  with NonFireCount, an array used to keep track of how many times a
%  neuron has not fired over multiple input patterns. This structure is
%  generated by calculate_gradient.
%o Verbose: Optional flag indicating how much information is displayed on
%  the screen. A value of 0 displays no additional information (this is
%  the default), a value of 1 displays all information, and values greater
%  than 1 display partial information. See Verbose for more details.
%
%Output Parameters:
%o ThNN: An object of the theta neuron network class, now with weights and
%  delays updated by the gradient-based results in TrainingResults.
%
%Examples:
%>> ThNN=theta_neuron_network;
%>> TrainingParams=get_training_params;
%>> [TrainingResults, ts]=calculate_gradient(ThNN,TrainingParams,...
%   [2 3],[3 25],1);
%>> SSE1=get_error(ThNN,ts,[3 25])
%>> ThNN=apply_gradient(ThNN,TrainingParams,TrainingResults,1);
%>> ts = run_network(ThNN, [2 3]);
%>> SSE2=get_error(ThNN,ts,[3 25])
%
%See also theta_neuron_network, verbose

%Copyright (C) 2008 Sam McKennoch <Samuel.McKennoch@loria.fr>


if nargin<3
    disp('Error in apply_gradient: Not enough input arguments');
    disp(['Needed at least 3 inputs but only got ' num2str(nargin)]);
    return;
end
if nargin<4
    Verbose=0;
end

Weights=ThNN.Weights;
Delays=ThNN.Delays;

switch (TrainingParams.LearningMethod)
    case {'Batch Gradient Descent','Online Gradient Descent','Online Batch Gradient Descent','Online Annealed Gradient Descent'}
        %Standard gradient descent step on the weights
        Weights=Weights-(TrainingParams.WeightLearningRate)*TrainingResults.DEDW;
        if TrainingParams.DelayEnable~=0
            %Gradient descent step on the delays, clamped from below so
            %that no delay falls under the minimum allowed delay
            Delays=max(Delays-(TrainingParams.DelayLearningRate)*TrainingResults.DEDTau,TrainingParams.MinDelay);
        end
    case {'Batch Gradient Descent with Momentum','Online Gradient Descent with Momentum'}
        %Gradient descent step with a momentum term. The momentum
        %coefficient and the previous weight gradient are assumed to be
        %stored in TrainingParams.Momentum and TrainingResults.DEDWOld;
        %neither field is documented elsewhere in this file. As in the
        %original code, this branch does not update the delays.
        Weights=Weights-(TrainingParams.WeightLearningRate)*TrainingResults.DEDW...
            +(TrainingParams.Momentum)*TrainingResults.DEDWOld;
    case 'Quickprop'
        %Not yet implemented
    case 'RPROP'
        %Not yet implemented
    otherwise
        disp(['Error in apply_gradient: Invalid Training Method: ', TrainingParams.LearningMethod]);
        ThNN=-1;
        return;
end

if (Verbose==1)
    disp('Modified Weights:');
    Weights
    disp('Modified Delays:');
    Delays
end
ThNN.Weights=Weights;
ThNN.Delays=Delays;
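As a concrete check of the gradient-descent branch, here is a toy numeric
step with made-up values; the max(...) clamp mirrors the delay update in
the code above:

% Toy numbers illustrating one gradient-descent step (all values made up).
Weights = [0.5 -0.2];  DEDW   = [0.10 -0.40];
Delays  = [1.0  0.3];  DEDTau = [-0.50  3.00];
WeightLearningRate = 0.1;
DelayLearningRate  = 0.1;
MinDelay = 0.1;

Weights = Weights - WeightLearningRate*DEDW
% Weights = [0.49 -0.16]
Delays = max(Delays - DelayLearningRate*DEDTau, MinDelay)
% Delays = [1.05 0.10]: the second delay (0.3 - 0.3 = 0.0) is clamped
% up to MinDelay, so a large delay gradient cannot push a delay below
% the minimum allowed value.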