
calculate_gradient

PURPOSE

CALCULATE_GRADIENT calculates the error gradient with respect to weights and delays, and the current SSE

SYNOPSIS

function [TrainingResults, ts]=calculate_gradient(ThNN,TrainingParams,tiCurrent,tdCurrent,Verbose)

DESCRIPTION

CALCULATE_GRADIENT calculates the error gradient with respect to weights and delays, and the current SSE

Description:
Function to calculate the gradient of the error with respect to both the
weights and the delays, as well as the current sum of squared errors (SSE).

Syntax:
TrainingResults=CALCULATE_GRADIENT(ThNN,TrainingParams,tiCurrent,tdCurrent);
TrainingResults=CALCULATE_GRADIENT(..., Verbose);
[TrainingResults, ts]=CALCULATE_GRADIENT(...);

Input Parameters:
o ThNN: An object of the theta neuron network class
o TrainingParams: A structure that contains information for training a
    theta neuron network, such as the learning rates. This structure is
    generated by get_training_params.
o tiCurrent: A qx2 array that contains the neuron indices (globally indexed to
    the network) for each spike time (column 1) and the input spike times
    (column 2). tiCurrent may be empty, which in most cases will result in
    a non-firing network.
o tdCurrent: An rx2 array that contains the neuron indices (globally indexed to
    the network) for each spike time (column 1) and the desired output
    spike times (column 2).
o Verbose: Optional flag that indicates whether extra information will be
    displayed on the screen. A value of 0 (the default) displays no
    additional information, a value of 1 displays all information, and
    values greater than 1 display partial information. See Verbose for
    more details.

Output Parameters:
o TrainingResults: A structure containing the error gradient with respect
    to the weights (DEDW) and delays (DEDTau). The structure also includes
    the SSE, a NonFireFlag that indicates whether any neurons are not
    firing, and NonFireCount, an array that tracks how many times each
    neuron has not fired over multiple input patterns.
o ts: A cell array of length NumNeurons containing in each cell an array
    of length r_j of neuron j's output spike times.
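To make the relationship between del, the SSE, and DEDW concrete, here is an illustrative sketch (in Python, not part of the toolbox) of the underlying arithmetic for a single output neuron j; the names ts_j, td_j, and Y_j_row simply mirror the quantities described above, and equal spike-count lengths are assumed.

```python
def gradient_terms(ts_j, td_j, Y_j_row):
    """Spike-time error, SSE contribution, and one DEDW entry for neuron j."""
    # del: actual minus desired output spike times (assumes equal lengths)
    delta = [a - d for a, d in zip(ts_j, td_j)]
    # SSE contribution for this neuron: 0.5 * sum of squared errors
    sse = 0.5 * sum(e * e for e in delta)
    # DEDW entry: inner product of del with the matching row of Y
    dedw = sum(e * y for e, y in zip(delta, Y_j_row))
    return delta, sse, dedw
```

This is exactly what the source below computes with its inlined dot products, summed over all output neurons.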

Examples:
>> ThNN=theta_neuron_network;
>> TrainingParams=get_training_params;
>> [TrainingResults, ts]=calculate_gradient(ThNN,TrainingParams,...
     [2 3],[3 25],1)
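For non-output ("hidden") neurons, the del terms are obtained by propagating the output-layer dels backward through Z, which the source below does with an inlined dot product. A minimal sketch of that step (Python, not toolbox code; del_out, Z_out, and k2q stand in for del{p}, Z{p}, and K2Q):

```python
def hidden_del(del_out, Z_out, k2q):
    """del for a hidden neuron: one entry per output spike k, each the
    inner product of the downstream del with row k2q[k] of Z."""
    return [sum(e * z for e, z in zip(del_out, Z_out[k2q[k]]))
            for k in range(len(k2q))]
```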

See also theta_neuron_network, verbose
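A typical consumer of TrainingResults applies a gradient-descent step to the weights and delays. The following is a minimal sketch under stated assumptions: eta_w and eta_tau are plain learning rates (the actual field names inside TrainingParams are not specified here), and nested lists stand in for the weight and delay matrices.

```python
def apply_step(W, Tau, DEDW, DEDTau, eta_w, eta_tau, delay_enable=True):
    """One gradient-descent update: W <- W - eta_w*DEDW, and likewise
    for Tau when delay training is enabled."""
    for i in range(len(W)):
        for j in range(len(W[i])):
            W[i][j] -= eta_w * DEDW[i][j]
            if delay_enable:
                Tau[i][j] -= eta_tau * DEDTau[i][j]
    return W, Tau
```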

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SOURCE CODE

0001 function [TrainingResults, ts]=calculate_gradient(ThNN,TrainingParams,tiCurrent,tdCurrent,Verbose)
0002 %CALCULATE_GRADIENT calculates the error gradient with respect to weights and delays, and the current SSE
0003 %
0004 %Description:
0005 %Function to calculate the gradient of the error with respect to both the
0006 %weights and the delays, as well as the current sum of squared errors (SSE).
0007 %
0008 %Syntax:
0009 %TrainingResults=CALCULATE_GRADIENT(ThNN,TrainingParams,tiCurrent,tdCurrent);
0010 %TrainingResults=CALCULATE_GRADIENT(..., Verbose);
0011 %[TrainingResults, ts]=CALCULATE_GRADIENT(...);
0012 %
0013 %Input Parameters:
0014 %o ThNN: An object of the theta neuron network class
0015 %o TrainingParams: A structure that contains information for training a
0016 %    theta neuron network, such as the learning rates. This structure is
0017 %    generated by get_training_params.
0018 %o tiCurrent: A qx2 array that contains the neuron indices (globally indexed to
0019 %    the network) for each spike time (column 1) and the input spike times
0020 %    (column 2). tiCurrent may be empty, which in most cases will result in
0021 %    a non-firing network.
0022 %o tdCurrent: An rx2 array that contains the neuron indices (globally indexed to
0023 %    the network) for each spike time (column 1) and the desired output
0024 %    spike times (column 2).
0025 %o Verbose: Optional flag that indicates whether extra information will be
0026 %    displayed on the screen. A value of 0 (the default) displays no
0027 %    additional information, a value of 1 displays all information, and
0028 %    values greater than 1 display partial information. See Verbose for
0029 %    more details.
0030 %
0031 %Output Parameters:
0032 %o TrainingResults: A structure containing the error gradient with respect
0033 %    to the weights (DEDW) and delays (DEDTau). The structure also includes
0034 %    the SSE, a NonFireFlag that indicates whether any neurons are not
0035 %    firing, and NonFireCount, an array that tracks how many times each
0036 %    neuron has not fired over multiple input patterns.
0037 %o ts: A cell array of length NumNeurons containing in each cell an array
0038 %    of length r_j of neuron j's output spike times.
0039 %
0040 %Examples:
0041 %>> ThNN=theta_neuron_network;
0042 %>> TrainingParams=get_training_params;
0043 %>> [TrainingResults, ts]=calculate_gradient(ThNN,TrainingParams,...
0044 %     [2 3],[3 25],1)
0045 %
0046 %See also theta_neuron_network, verbose
0047 
0048 %Copyright (C) 2008 Sam McKennoch <Samuel.McKennoch@loria.fr>
0049 
0050 
0051 TrainingResults.DEDW=-1; TrainingResults.DEDTau=-1; TrainingResults.SSE=-1;
0052 ts=-1;
0053 
0054 if nargin<4
0055     disp('Error in calculate_gradient: Not enough input arguments');
0056     disp(['Needed at least 4 inputs but only got ' num2str(nargin)]);
0057     return;
0058 end
0059 if nargin<5
0060     Verbose=0;
0061 end
0062 
0063 if (Verbose==1)
0064     disp('*** Entering calculate_gradient ***');
0065     disp(' ');
0066 end
0067 NumericalGradient=TrainingParams.NumericalGradient;
0068 GradientStep=TrainingParams.GradientStep;
0069 NoiseLevel=TrainingParams.NoiseLevel;
0070 DEDW=zeros(size(ThNN.Weights));
0071 DEDTau=zeros(size(ThNN.Delays));
0072 tiCurrent(:,2)=tiCurrent(:,2)+NoiseLevel*(1+randn(size(tiCurrent,1),1));
0073 %Inline sortrows
0074 %tdCurrent=sortrows(tdCurrent,1);
0075 [Temp1,Temp2]=sort(tdCurrent(:,2));
0076 tdCurrent=[tdCurrent(Temp2,1) Temp1];
0077 
0078 %Get Static Array of Relative Input and Output Neurons
0079 RelativeInputNeurons=ThNN.RelativeInputNeurons;
0080 RelativeOutputNeurons=ThNN.RelativeOutputNeurons;
0081 
0082 GradFlag=1;
0083 MaxSimTime=1000;
0084 [ts, Y, Z] = run_network(ThNN,tiCurrent,MaxSimTime,GradFlag,Verbose);
0085 K2QAll=relate_q_to_k(ThNN,ts);
0086 
0087 %Generate List of neurons that need to be processed
0088 NeuronQueue=get_output_neurons(ThNN);
0089 
0090 %Get output spike times and check for non-firing
0091 TrainingResults.NonFireFlag=0;
0092 TrainingResults.NonFireCount=zeros(1,max(size(ThNN.Neurons)));
0093 for j=1:max(size(ThNN.Neurons))
0094     if isempty(ts{j})
0095         TrainingResults.NonFireFlag=1;
0096         TrainingResults.NonFireCount(j)=1;
0097     end
0098 end
0099 
0100 if sum(TrainingResults.NonFireCount(NeuronQueue)>0)
0101     TrainingParams
0102     TrainingResults
0103     disp(['Error! Neurons (including output neurons) with the following indices are not firing: ', num2str(find(TrainingResults.NonFireCount>0))]);
0104     TrainingResults.DEDW=-1;
0105     TrainingResults.DEDTau=-1;
0106     return;
0107 end
0108 
0109 %calculate dels for output layer
0110 del=cell(length(ThNN.Neurons),1);
0111 for j=1:length(NeuronQueue)
0112     CurrentNeuron=NeuronQueue(j);
0113     %Get List of desired output spike times for this neuron
0114     td=sort(tdCurrent(find(tdCurrent(:,1)==CurrentNeuron),2));
0115     %Adjust desired and actual output spike time vector lengths if possible
0116     if length(td)==length(ts{CurrentNeuron})
0117         del{CurrentNeuron}=ts{CurrentNeuron}-td';
0118     elseif length(td)<length(ts{CurrentNeuron})
0119         ts{CurrentNeuron}=ts{CurrentNeuron}(1:length(td));
0120         del{CurrentNeuron}=ts{CurrentNeuron}(1:length(td))-td';
0121     elseif length(ts{CurrentNeuron})<length(td) && ThNN.Neurons(CurrentNeuron).Io>0
0122         for k=(length(ts{CurrentNeuron})+1):length(td)
0123             ts{CurrentNeuron}(end+1)=ts{CurrentNeuron}(end)+(pi/ThNN.Neurons(CurrentNeuron).Beta);
0124         end
0125         del{CurrentNeuron}=ts{CurrentNeuron}-td';
0126     else %output layer Io<0 and we are not producing enough output spikes;
0127         %later I could implement a hack to try to increase the number of spikes being produced
0128         disp(['Error! Output Layer Io<0 and there are not enough output spikes being generated',...
0129             ' by output neuron ', num2str(CurrentNeuron)]);
0130         return;
0131     end
0132 end
0133 
0134 %Calculate SSE and
0135 %Fill in extra Z and Y values if there were any cyclic spikes added on
0136 %Or remove extra Z and Y values if there were any cyclic spikes removed
0137 CurrentSSE=0;
0138 for p=1:length(NeuronQueue)
0139     CurrentNeuron=NeuronQueue(p);
0140     CurrentSSE=CurrentSSE+0.5*sum((del{CurrentNeuron}).^2);
0141     if size(Z{CurrentNeuron},2)<size(del{CurrentNeuron},2)
0142         for j=(size(Z{CurrentNeuron},2)+1):size(del{CurrentNeuron},2)
0143             Z{CurrentNeuron}(:,j)=Z{CurrentNeuron}(:,end);
0144             Y{CurrentNeuron}(:,j)=Y{CurrentNeuron}(:,end);
0145         end
0146     end
0147     if size(Z{CurrentNeuron},2)>size(del{CurrentNeuron},2)
0148         Z{CurrentNeuron}=Z{CurrentNeuron}(:,1:size(del{CurrentNeuron},2));
0149         Y{CurrentNeuron}=Y{CurrentNeuron}(:,1:size(del{CurrentNeuron},2));
0150     end
0151 
0152 end
0153 for p=1:length(NeuronQueue)
0154     CurrentNeuron=NeuronQueue(p);
0155 
0156     CurrentInputNeurons=RelativeInputNeurons{CurrentNeuron};
0157     %CurrentInputNeurons=get_input_neurons(ThNN,CurrentNeuron);
0158     for m=1:length(CurrentInputNeurons)
0159         K2Q=K2QAll{CurrentInputNeurons(m),CurrentNeuron};
0160         if NumericalGradient
0161             DEDW(CurrentInputNeurons(m),CurrentNeuron)=numerical_gradient(ThNN,...
0162                 CurrentInputNeurons(m),CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Weights');
0163         else
0164             %Inline dot
0165             %DEDW(CurrentInputNeurons(m),CurrentNeuron)=dot(del{CurrentNeuron},Y{CurrentNeuron}(CurrentInputNeurons(m),:));
0166             DEDW(CurrentInputNeurons(m),CurrentNeuron)=sum(del{CurrentNeuron}.*Y{CurrentNeuron}(CurrentInputNeurons(m),:));
0167         end
0168         if NumericalGradient && TrainingParams.DelayEnable
0169             DEDTau(CurrentInputNeurons(m),CurrentNeuron)=numerical_gradient(ThNN,...
0170                 CurrentInputNeurons(m),CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Delays');
0171         elseif TrainingParams.DelayEnable
0172             %Inline dot
0173             %DEDTau(CurrentInputNeurons(m),CurrentNeuron)=dot(del{CurrentNeuron},sum(Z{CurrentNeuron}(K2Q,:),1));
0174             DEDTau(CurrentInputNeurons(m),CurrentNeuron)=sum(del{CurrentNeuron}.*sum(Z{CurrentNeuron}(K2Q,:),1));
0175         end
0176         
0177         if (Verbose==1)
0178             disp(['Gradient Between Neurons ' num2str(CurrentInputNeurons(m)) ' and ' num2str(CurrentNeuron)]);
0179             disp(['Weight Calculated: ' num2str(dot(del{CurrentNeuron},Y{CurrentNeuron}(CurrentInputNeurons(m),:)))]);
0180             disp(['Weight  Numerical: ' num2str(numerical_gradient(ThNN,CurrentInputNeurons(m),CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Weights'))]);
0181             if TrainingParams.DelayEnable
0182                 disp(['Delay Calculated: ' num2str(dot(del{CurrentNeuron},sum(Z{CurrentNeuron}(K2Q,:),1)))]);
0183                 disp(['Delay  Numerical: ' num2str(numerical_gradient(ThNN,CurrentInputNeurons(m),CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Delays'))]);
0184             end            
0185         end
0186     end
0187 end
0188 
0189 
0190 
0191 %The mission is to calculate the dels for the inputs to the output neurons
0192 %Get list of inputs to output neurons and put them in the Queue
0193 %Update Queue Based on New Results (Add Neurons Connected to the
0194 %output of Current Neuron that are not already in the queue)
0195 for j=1:length(NeuronQueue)
0196     [NeuronQueue, CurrentNeuron]=pop(NeuronQueue);
0197     NeuronQueue=push_unique(NeuronQueue,RelativeInputNeurons{CurrentNeuron});
0198 end
0199 
0200 %Exclude Input Neurons
0201 InputNeurons=get_input_neurons(ThNN);
0202 %Inline setdiff
0203 %NeuronQueue2=setdiff(NeuronQueue, InputNeurons)
0204 for sd=1:length(InputNeurons)
0205     NeuronQueue=NeuronQueue(~(NeuronQueue==InputNeurons(sd)));
0206 end
0207 
0208 
0209 
0210 %For each output spike produced by the "hidden" neuron (k), we need to sum
0211 %up the effects over all the output neurons it is connected to (p)
0212 %Each effect is dot(del{p},Z{p}(q,:))
0213 
0214 while ~isempty(NeuronQueue)
0215     %Determine Current Neuron
0216     CurrentNeuron=NeuronQueue(1); %3
0217     %Determine The list of output neurons it is connected to (p)
0218     CurrentOutputNeurons=RelativeOutputNeurons{CurrentNeuron}; %4
0219     %Determine the number of output spikes that it produces (k)
0220     NumberOfSpikes=length(ts{CurrentNeuron}); %1
0221     %Calculate the dels for non-output neurons
0222     for p=1:length(CurrentOutputNeurons)
0223         %Determine q_k
0224         K2Q=K2QAll{CurrentNeuron,CurrentOutputNeurons(p)};
0225         for k=1:NumberOfSpikes
0226             %Inline dot
0227             %del{CurrentNeuron}(k)=dot(del{CurrentOutputNeurons(p)},Z{CurrentOutputNeurons(p)}(K2Q(k),:));
0228             del{CurrentNeuron}(k)=sum(del{CurrentOutputNeurons(p)}.*Z{CurrentOutputNeurons(p)}(K2Q(k),:));
0229         end
0230     end
0231 
0232     %Cycle again through network to calculate
0233     %For Neuron 3, DEDW between Neurons 1 and 3
0234     %Want Y between 1 and 3
0235     CurrentInputNeurons=RelativeInputNeurons{CurrentNeuron};
0236     for p=1:length(CurrentInputNeurons)
0237         if ~isempty(del{CurrentNeuron}) % As would be the case if the hidden neuron did not fire
0238             K2Q=K2QAll{CurrentInputNeurons(p),CurrentNeuron};
0239             if NumericalGradient
0240                 DEDW(CurrentInputNeurons(p),CurrentNeuron)=numerical_gradient(ThNN,...
0241                     CurrentInputNeurons(p),CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Weights');
0242             else
0243                 %Inline dot
0244                 %DEDW(CurrentInputNeurons(p),CurrentNeuron)=dot(del{CurrentNeuron},Y{CurrentNeuron}(CurrentInputNeurons(p),:));
0245                 DEDW(CurrentInputNeurons(p),CurrentNeuron)=sum(del{CurrentNeuron}.*Y{CurrentNeuron}(CurrentInputNeurons(p),:));
0246             end
0247             if NumericalGradient && TrainingParams.DelayEnable
0248                 DEDTau(CurrentInputNeurons(p),CurrentNeuron)=numerical_gradient(ThNN,...
0249                     CurrentInputNeurons(p),CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Delays');
0250             elseif TrainingParams.DelayEnable
0251                 %Inline dot
0252                 %DEDTau(CurrentInputNeurons(p),CurrentNeuron)=dot(del{CurrentNeuron},sum(Z{CurrentNeuron}(K2Q,:),1));
0253                 DEDTau(CurrentInputNeurons(p),CurrentNeuron)=sum(del{CurrentNeuron}.*sum(Z{CurrentNeuron}(K2Q,:),1));
0254             end
0255             
0256             if (Verbose==1)
0257                 disp(['Gradient Between Neurons ' num2str(CurrentInputNeurons(p)) ' and ' num2str(CurrentNeuron)]);
0258                 disp(['Weight Calculated: ' num2str(dot(del{CurrentNeuron},Y{CurrentNeuron}(CurrentInputNeurons(p),:)))]);
0259                 disp(['Weight  Numerical: ' num2str(numerical_gradient(ThNN,CurrentInputNeurons(p),CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Weights'))]); 
0260                 if TrainingParams.DelayEnable
0261                     disp(['Delay Calculated: ' num2str(dot(del{CurrentNeuron},sum(Z{CurrentNeuron}(K2Q,:),1)))]);
0262                     disp(['Delay  Numerical: ' num2str(numerical_gradient(ThNN,CurrentInputNeurons(p),CurrentNeuron,tiCurrent,tdCurrent,GradientStep,'Delays'))]);
0263                 end
0264             end
0265         end
0266     end
0267 
0268     %Update Queue Based on New Results (Add Neurons Connected to the
0269     %output of Current Neuron that are not already in the queue)
0270     NeuronQueue=pop(NeuronQueue);
0271     %Inline setdiff
0272     %Temp=setdiff(CurrentInputNeurons, InputNeurons)
0273     Temp=CurrentInputNeurons;
0274     for n=1:length(InputNeurons)
0275         Temp=Temp(~(Temp==InputNeurons(n)));
0276     end
0277 
0278     
0279     NeuronQueue=push_unique(NeuronQueue,Temp);
0280 end
0281 
0282 
0283 TrainingResults.DEDW=DEDW;
0284 TrainingResults.DEDTau=DEDTau;
0285 TrainingResults.SSE=CurrentSSE;
0286 
0287 
0288 if (Verbose==1)
0289     disp(' ');
0290     disp('Gradient Calculation Results:');
0291     del
0292     DEDW
0293     DEDTau
0294     disp(' ');
0295 end
0296 
0297 return;

Generated on Wed 02-Apr-2008 15:16:32 by m2html © 2003