get_error

PURPOSE

GET_ERROR calculates the sum of squared error in output spike times

SYNOPSIS

function [SSE, RMSE, ClassErr, tsAll]=get_error(ThNN,varargin)

DESCRIPTION

GET_ERROR calculates the sum of squared error in output spike times

Description:
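Function to calculate the sum of squared error given the actual and
desired output spike times.  ts is in the cell format produced by
run_network. td is in the standard array format [NeuronIndex1
DesiredFiringTime1; ...]. The function attempts to make sure the numbers
of actual and desired spikes match before comparing them: surplus actual
spikes are discarded, and for a neuron with Io>0 missing spikes are
extrapolated at intervals of pi/Beta. If Io<=0 and there are not enough
output spikes, the error is undefined; a message is displayed and the SSE
for that pattern is set to -1.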
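The spike-count reconciliation described above can be sketched in plain MATLAB as follows. It is a minimal illustration modeled on the get_error_helper subfunction in the source listing, run on hypothetical data; the values of Io and Beta are made up and merely stand in for the ThNN.Neurons(n).Io and ThNN.Neurons(n).Beta fields. Unlike the original, the sketch also refuses to extrapolate when the neuron did not fire at all, since there is then no last spike to extrapolate from.

```matlab
% Hypothetical network output: a 3-neuron network in which output neuron 3
% fired once. ts uses the run_network cell format, td the [index time] format.
ts = {[], [], 12.0};
td = [3 40; 3 20];

% Regroup td into the same per-neuron cell layout as ts (sorted times)
tdC = cell(size(ts));
for r = 1:size(td, 1)
    tdC{td(r,1)} = sort([tdC{td(r,1)}, td(r,2)]);
end

% Reconcile spike counts for neuron 3. Io and Beta are illustrative
% stand-ins for ThNN.Neurons(3).Io and ThNN.Neurons(3).Beta.
n = 3; Io = 0.05; Beta = 0.5;
nAct = numel(ts{n});
nDes = numel(tdC{n});
if nAct < nDes && Io > 0 && nAct > 0
    % Extrapolate each missing spike pi/Beta after the previous one
    for k = nAct+1:nDes
        ts{n}(k) = ts{n}(k-1) + pi/Beta;
    end
elseif nAct > nDes
    ts{n} = ts{n}(1:nDes);      % discard surplus actual spikes
elseif nAct ~= nDes
    error('Spike counts of neuron %d cannot be reconciled', n);
end
```

After this runs, ts{3} holds two spikes (12 and 12 + pi/Beta) that can be compared one-to-one with tdC{3} = [20 40].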

Syntax:
[SSE,RMSE]=GET_ERROR(ThNN,ts,td);
[SSE,RMSE]=GET_ERROR(ThNN,tiAll,tdAll);
[SSE,RMSE]=GET_ERROR(ThNN,ts,td,[Mode],[Classes]);
[SSE,RMSE]=GET_ERROR(ThNN,tiAll,tdAll,[Mode],[Classes]);
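The two calling forms are told apart by the type of the third argument: the source treats the call as the multi-pattern form when that argument (tdAll) is a cell array, and Mode defaults to 0 when only three arguments are given. The sketch below mimics that dispatch on hypothetical argument lists (everything passed after ThNN); the variable names are illustrative only.

```matlab
% Hypothetical argument lists, i.e. everything passed after ThNN
argsSingle = {{[], [], 12.0}, [3 25]};               % ts, td
argsMulti  = {{[2 3], [2 6]}, {[3 25], [3 20]}, 1};  % tiAll, tdAll, Mode

forms    = {argsSingle, argsMulti};
isMulti  = false(1, numel(forms));
modeUsed = zeros(1, numel(forms));
for f = 1:numel(forms)
    a = forms{f};
    % A cell array in the td position selects the tiAll/tdAll form
    isMulti(f) = iscell(a{2});
    % Mode is optional and defaults to 0
    if numel(a) >= 3
        modeUsed(f) = a{3};
    end
end
```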

Input Parameters:
o ThNN: An object of the theta neuron network class
o ts: A cell array of length NumNeurons containing in each cell an array 
    of length r_j of neuron j's output spike times. This format is the one
    produced from the run_network function.
o td: An rx2 array that contains the neuron indices (globally indexed to
    the network) for each spike time (column 1) and the desired output
    spike times (column 2).
o tiAll: A cell array of length a, the number of input patterns, of which
    each cell is a qx2 array that contains the neuron indices for each 
    spike time and the input spike times.  q may vary from cell to cell.
o tdAll: A cell array of length a, the number of input patterns, of which 
    each cell is a (qo x 2) array that contains the output neuron indices 
    for each desired output spike time and the desired output spike times.
    qo may vary from cell to cell.
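o Mode: Scalar selecting the type of error computation and
    display. If Mode = 0 (the default value), SSE and RMSE are calculated
    and returned.  If Mode = 1, SSE and RMSE are calculated, and the
    actual and desired spike times are displayed. If Mode = 2, the
    Classification Error is also calculated, using the optional input
    Classes (or, in the tiAll/tdAll form only, deriving Classes from the
    desired spike times in tdAll if it is not given). Mode = 3 combines
    modes 1 and 2.
o Classes: Optional vector of class labels expressed as output spike
    times, in ascending order. Each pattern is assigned to the class
    nearest to the first spike of the first output neuron, and counts as
    misclassified if that class differs from its first desired spike
    time. In the ts/td form, the Classification Error is calculated only
    when Classes is supplied.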
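The classification error selected by Mode = 2 or 3 amounts to nearest-class assignment. The sketch below computes it for two hypothetical patterns using min, a swapped-in simplification of the function's own linear scan over the sorted Classes (the two choose the same class except possibly when a spike lies exactly halfway between two classes). For brevity, Classes is built here from the first desired spike of each pattern, whereas the function collects all desired spike times in tdAll.

```matlab
% Hypothetical first desired and first actual output spike time per pattern
tdFirst = [25 20];
tsFirst = [23.9 24.0];
Classes = unique(tdFirst);      % sorted class labels: [20 25]

numWrong = 0;
for j = 1:numel(tsFirst)
    % Assign the pattern to the class nearest its first output spike
    [~, idx] = min(abs(tsFirst(j) - Classes));
    if Classes(idx) ~= tdFirst(j)
        numWrong = numWrong + 1;   % nearest class is not the desired one
    end
end
ClassErr = numWrong / numel(tsFirst);   % fraction of misclassified patterns
```

Here the second pattern fires at 24.0, which is closer to class 25 than to its desired class 20, so ClassErr is 0.5.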

Output Parameters:
o SSE: A scalar equal to half the sum of squared differences between the
    actual and desired spike times, 0.5*sum((t_actual-t_desired).^2), if
    the function is called with ts and td. Otherwise, SSE is this quantity
    summed over all input patterns in tiAll, comparing each pattern's
    output with the corresponding entry of tdAll.
o RMSE: A scalar indicating the root mean squared error per desired
    output spike, sqrt(2*SSE/N), where N is the number of desired spikes
    in td (or in all of tdAll).
o ClassErr: Optional output that indicates the Classification Error if it
    is calculated (as determined by the input Mode): the fraction of
    misclassified patterns in the tiAll/tdAll form, or 0 or 1 in the
    ts/td form. It is -1 if Mode is 0 or 1.
o tsAll: Assigned only in the tiAll/tdAll form. A cell array of length a
    whose jth cell holds the run_network output (in the ts format) for
    input pattern j, with the output spike trains padded or trimmed to
    match tdAll{j}.
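The SSE and RMSE definitions above can be checked by hand. The sketch below pools two hypothetical patterns whose actual spikes have already been reconciled with the desired ones, mirroring how the multi-pattern form sums SSE over patterns and divides by the total number of desired spikes. The 0.5 factor in SSE is undone by the 2 inside the square root, so RMSE is the ordinary root mean squared spike-time error.

```matlab
% Hypothetical per-pattern spike times (actual and desired, equal counts)
tsPat = {24.0, [21.5 38.0]};
tdPat = {25,   [20   40  ]};

SSE = 0;
N = 0;
for j = 1:numel(tsPat)
    SSE = SSE + sum(0.5 * (tsPat{j} - tdPat{j}).^2);  % half sum of squares
    N = N + numel(tdPat{j});                          % desired spikes so far
end
RMSE = sqrt(2 * SSE / N);   % root mean squared error per desired spike
```

For these values SSE = 3.625 and RMSE = sqrt(7.25/3), about 1.55.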

Examples:
>> %Simple Error
>> ThNN = theta_neuron_network;
>> ts = run_network(ThNN, [2 3])
>> [SSE0, RMSE0] = get_error(ThNN,ts,[3 ts{3}])
>> [SSE1, RMSE1] = get_error(ThNN,ts,[3 25])

>> %Error over multiple spikes
>> ThNN = theta_neuron_network;
>> ThNN.Neurons(3).Io = 0.05;
>> ts = run_network(ThNN, [2 6])
>> [SSE2, RMSE2] = get_error(ThNN,ts,[3 20; 3 40])

>> %Error over multiple input patterns
>> ThNN = theta_neuron_network;
>> [SSE3, RMSE3] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20]})

>> %Error over multiple spikes and input patterns (with Display)
>> ThNN = theta_neuron_network;
>> ThNN.Neurons(3).Io = 0.05;
>> [SSE4, RMSE4] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20; 3 40]},1)

>> %Classification Error (with Display), classes are calculated
>> ThNN = theta_neuron_network;
>> [SSE5, RMSE5, ClassErr5] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20]},3)

See also theta_neuron_network

SUBFUNCTIONS

function [SSE, RMSE, ts] = get_error_helper(ThNN,ts,td)

SOURCE CODE

function [SSE, RMSE, ClassErr, tsAll]=get_error(ThNN,varargin)
%GET_ERROR calculates the sum of squared error in output spike times
%
%Description:
%Function to calculate the sum of squared error given the actual and
%desired output spike times.  ts is in the cell format produced by
%run_network. td is in the standard array format [NeuronIndex1
%DesiredFiringTime1; ...]. The function attempts to make sure the numbers
%of actual and desired spikes match before comparing them: surplus actual
%spikes are discarded, and for a neuron with Io>0 missing spikes are
%extrapolated at intervals of pi/Beta. If Io<=0 and there are not enough
%output spikes, the error is undefined; a message is displayed and the SSE
%for that pattern is set to -1.
%
%Syntax:
%[SSE,RMSE]=GET_ERROR(ThNN,ts,td);
%[SSE,RMSE]=GET_ERROR(ThNN,tiAll,tdAll);
%[SSE,RMSE]=GET_ERROR(ThNN,ts,td,[Mode],[Classes]);
%[SSE,RMSE]=GET_ERROR(ThNN,tiAll,tdAll,[Mode],[Classes]);
%
%Input Parameters:
%o ThNN: An object of the theta neuron network class
%o ts: A cell array of length NumNeurons containing in each cell an array
%    of length r_j of neuron j's output spike times. This format is the one
%    produced from the run_network function.
%o td: An rx2 array that contains the neuron indices (globally indexed to
%    the network) for each spike time (column 1) and the desired output
%    spike times (column 2).
%o tiAll: A cell array of length a, the number of input patterns, of which
%    each cell is a qx2 array that contains the neuron indices for each
%    spike time and the input spike times.  q may vary from cell to cell.
%o tdAll: A cell array of length a, the number of input patterns, of which
%    each cell is a (qo x 2) array that contains the output neuron indices
%    for each desired output spike time and the desired output spike times.
%    qo may vary from cell to cell.
%o Mode: Scalar selecting the type of error computation and
%    display. If Mode = 0 (the default value), SSE and RMSE are calculated
%    and returned.  If Mode = 1, SSE and RMSE are calculated, and the
%    actual and desired spike times are displayed. If Mode = 2, the
%    Classification Error is also calculated, using the optional input
%    Classes (or, in the tiAll/tdAll form only, deriving Classes from the
%    desired spike times in tdAll if it is not given). Mode = 3 combines
%    modes 1 and 2.
%o Classes: Optional vector of class labels expressed as output spike
%    times, in ascending order. Each pattern is assigned to the class
%    nearest to the first spike of the first output neuron, and counts as
%    misclassified if that class differs from its first desired spike
%    time. In the ts/td form, the Classification Error is calculated only
%    when Classes is supplied.
%
%Output Parameters:
%o SSE: A scalar equal to half the sum of squared differences between the
%    actual and desired spike times, 0.5*sum((t_actual-t_desired).^2), if
%    the function is called with ts and td. Otherwise, SSE is this quantity
%    summed over all input patterns in tiAll, comparing each pattern's
%    output with the corresponding entry of tdAll.
%o RMSE: A scalar indicating the root mean squared error per desired
%    output spike, sqrt(2*SSE/N), where N is the number of desired spikes
%    in td (or in all of tdAll).
%o ClassErr: Optional output that indicates the Classification Error if it
%    is calculated (as determined by the input Mode): the fraction of
%    misclassified patterns in the tiAll/tdAll form, or 0 or 1 in the
%    ts/td form. It is -1 if Mode is 0 or 1.
%o tsAll: Assigned only in the tiAll/tdAll form. A cell array of length a
%    whose jth cell holds the run_network output (in the ts format) for
%    input pattern j, with the output spike trains padded or trimmed to
%    match tdAll{j}.
%
%Examples:
%>> %Simple Error
%>> ThNN = theta_neuron_network;
%>> ts = run_network(ThNN, [2 3])
%>> [SSE0, RMSE0] = get_error(ThNN,ts,[3 ts{3}])
%>> [SSE1, RMSE1] = get_error(ThNN,ts,[3 25])
%
%>> %Error over multiple spikes
%>> ThNN = theta_neuron_network;
%>> ThNN.Neurons(3).Io = 0.05;
%>> ts = run_network(ThNN, [2 6])
%>> [SSE2, RMSE2] = get_error(ThNN,ts,[3 20; 3 40])
%
%>> %Error over multiple input patterns
%>> ThNN = theta_neuron_network;
%>> [SSE3, RMSE3] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20]})
%
%>> %Error over multiple spikes and input patterns (with Display)
%>> ThNN = theta_neuron_network;
%>> ThNN.Neurons(3).Io = 0.05;
%>> [SSE4, RMSE4] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20; 3 40]},1)
%
%>> %Classification Error (with Display), classes are calculated
%>> ThNN = theta_neuron_network;
%>> [SSE5, RMSE5, ClassErr5] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20]},3)
%
%See also theta_neuron_network

%Copyright (C) 2008 Sam McKennoch <Samuel.McKennoch@loria.fr>


SSE=-1; RMSE=-1; ClassErr=-1;

if nargin<3
    disp('Error in get_error: Wrong number of input arguments');
    disp(['Needed at least 3 inputs but got ' num2str(nargin)]);
    return;
end

if nargin==3
    Mode=0;
else
    Mode=varargin{3};
end

if (Mode == 2 || Mode == 3)
    OutputNeurons=get_output_neurons(ThNN);
    ClassErr=0;
end
0103 
0104 if iscell(varargin{2})
0105     tiAll=varargin{1};
0106     tdAll=varargin{2};
0107     if (Mode == 2 || Mode == 3) && nargin==5
0108         Classes=varargin{4};
0109     else
0110         Classes=[];
0111         for k=1:length(tdAll)
0112             for m=1:size(tdAll{k},1)
0113                 Classes=push_unique(Classes,tdAll{k}(:,2)');
0114             end
0115         end
0116         Classes=sort(Classes);
0117     end
0118     SSE = 0; RMSE = 0; NumSpikes=0;
0119     
0120     tsAll=cell(length(tiAll),1);
0121     for j=1:length(tiAll)
0122         tsAll{j} = run_network(ThNN,tiAll{j});
0123         [ErrorTemp, RMSETemp, tsAll{j}] = get_error_helper(ThNN,tsAll{j},tdAll{j});
0124         SSE = SSE + ErrorTemp;
0125         NumSpikes = NumSpikes + size(tdAll{j},1);
0126         if (Mode == 2 || Mode == 3)
0127             if length(tsAll{j}{OutputNeurons(1)})>1
0128                 %Maybe could do MIMO pattern matching for
0129                 %classification though?
0130                 disp('Problem B1!');
0131             end
0132             for m=1:length(Classes)
0133                 ClassDistance=abs(tsAll{j}{OutputNeurons(1)}(1)-Classes(m));
0134                 if m>1 && ClassDistance>ClassDistanceOld
0135                     CurrentClass=Classes(m-1);
0136                     break;
0137                 end
0138                 CurrentClass=Classes(m); %Match the final class
0139                 ClassDistanceOld=ClassDistance;
0140             end
0141             %Desired
0142             if CurrentClass~=tdAll{j}(1,2)
0143                 ClassErr=ClassErr+1;
0144             end
0145         end
0146     end
0147     if (Mode == 2 || Mode == 3)
0148         ClassErr=ClassErr/length(tiAll);
0149     end
0150     if Mode==1 || Mode==3
0151         display_results(ThNN,'Outputs',tsAll,tdAll);
0152     end
0153     RMSE = sqrt(2*SSE/NumSpikes);
else
    ts=varargin{1};
    td=varargin{2};
    if (Mode == 2 || Mode == 3) && nargin==5
        Classes=varargin{4};
        if length(ts{OutputNeurons(1)})>1
            %Maybe could do MIMO pattern matching for
            %classification though?
            disp('Warning in get_error: output neuron fired more than once; only its first spike is used for classification');
        end
        for m=1:length(Classes)
            ClassDistance=abs(ts{OutputNeurons(1)}(1)-Classes(m));
            if m>1 && ClassDistance>ClassDistanceOld
                CurrentClass=Classes(m-1);
                break;
            end
            CurrentClass=Classes(m); %Match the final class
            ClassDistanceOld=ClassDistance;
        end
        %Desired
        if CurrentClass~=td(1,2)
            ClassErr=ClassErr+1;
        end
    else
        %Can't calculate Classes here, since we don't have all the info, so
        %resetting the Mode if needed
        %disp('Problem C!');
        if Mode>1
            Mode=Mode-2;
        end
    end
    [SSE, RMSE, ts] = get_error_helper(ThNN,ts,td);
    if Mode==1 || Mode==3
        display_results(ThNN,'Outputs',{ts},{td});
    end
end
return;


%Helper Function to Calculate Error For a Single Pattern
function [SSE, RMSE, ts] = get_error_helper(ThNN,ts,td)

%Reformat td into the same layout as ts: tdC{NeuronIndex} holds that
%neuron's desired spike times in ascending order
tdC=cell(size(ts));
for j=1:size(td,1)
    CurrentIndex=td(j,1);
    tdC{CurrentIndex}=[tdC{CurrentIndex}, td(j,2)];
end
for j=1:length(tdC)
    tdC{j}=sort(tdC{j});
end

%Go through each output neuron and compare
NeuronList=unique(td(:,1));
SSE=0;
for j=1:length(NeuronList)
    CurrentNeuron=NeuronList(j);
    if length(ts{CurrentNeuron})<length(tdC{CurrentNeuron}) && ThNN.Neurons(CurrentNeuron).Io>0
        %Too few actual spikes: extrapolate new spikes spaced pi/Beta apart
        for k=(length(ts{CurrentNeuron})+1):length(tdC{CurrentNeuron})
            ts{CurrentNeuron}(1,end+1)=ts{CurrentNeuron}(1,end)+(pi/ThNN.Neurons(CurrentNeuron).Beta);
        end
    elseif length(ts{CurrentNeuron})>length(tdC{CurrentNeuron})
        %Too many actual spikes: keep only the first ones
        ts{CurrentNeuron}=ts{CurrentNeuron}(1:length(tdC{CurrentNeuron}));
    elseif length(ts{CurrentNeuron})~=length(tdC{CurrentNeuron}) %Error is undefined as the numbers of desired and actual spikes cannot be reconciled
        disp('Error in get_error: The numbers of actual and desired spikes cannot be reconciled');
        SSE=-1; RMSE=-1; %Assign both outputs so callers requesting RMSE do not fail
        return;
    end
    %Equal Numbers of Output Spikes Now, Process Normally
    SSE=SSE+sum(0.5*((ts{CurrentNeuron}-tdC{CurrentNeuron}).^2));
end
RMSE=sqrt(2*SSE/size(td,1)); %Error per spike

Generated on Wed 02-Apr-2008 15:16:32 by m2html © 2003