GET_ERROR calculates the sum of squared error in output spike times

Description:
Function to calculate the sum of squared error given the actual and
desired output spike times. ts is in the cell format produced by
run_network. td is in the standard array format [NeuronIndex1
DesiredFiringTime1; ...]. The function attempts to make sure there is the
correct number of spikes to compare, but if Io <= 0 and there are fewer
actual output spikes than desired ones, a function error may be produced.

Syntax:
[SSE,RMSE]=GET_ERROR(ThNN,ts,td);
[SSE,RMSE]=GET_ERROR(ThNN,tiAll,tdAll);
[SSE,RMSE]=GET_ERROR(ThNN,ts,td,[Mode],[Classes]);
[SSE,RMSE]=GET_ERROR(ThNN,tiAll,tdAll,[Mode],[Classes]);
[SSE,RMSE,ClassErr,tsAll]=GET_ERROR(ThNN,tiAll,tdAll,[Mode],[Classes]);

Input Parameters:
o ThNN: An object of the theta neuron network class.
o ts: A cell array of length NumNeurons containing in each cell an array
  of length r_j of neuron j's output spike times. This is the format
  produced by the run_network function.
o td: An rx2 array that contains the neuron indices (globally indexed to
  the network) for each spike time (column 1) and the desired output
  spike times (column 2).
o tiAll: A cell array of length a, the number of input patterns, in which
  each cell is a qx2 array that contains the neuron indices for each
  spike time and the input spike times. q may vary from cell to cell.
o tdAll: A cell array of length a, the number of input patterns, in which
  each cell is a (qo x 2) array that contains the output neuron indices
  for each desired output spike time and the desired output spike times.
  qo may vary from cell to cell.
o Mode: Scalar that selects the type of error computation and display.
  If Mode = 0 (the default value), SSE and RMSE are calculated and
  returned. If Mode = 1, SSE and RMSE are calculated, and the actual and
  desired spike times are displayed. If Mode = 2, the Classification
  Error is also calculated, using the optional input Classes (or
  calculating Classes if needed). Mode = 3 combines modes 1 and 2.
o Classes: Optional vector of class labels, expressed as desired output
  spike times, used when Mode is 2 or 3. When the function is called with
  tiAll and tdAll and Classes is omitted, it is built from all desired
  spike times in tdAll. When it is called with ts and td, Classes must be
  given; otherwise Mode falls back to 0 or 1 and the classification step
  is skipped.
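When Mode is 2 or 3, get_error labels each pattern with the class time nearest to the first spike of the first output neuron (it walks the sorted Classes and stops once the distance starts to grow), then counts a miss when that label differs from the pattern's first desired spike time. The sketch below restates that rule on plain vectors so it can be run without the toolbox; FirstSpikes, DesiredTimes and the sample values are hypothetical, and run_network is not called.

```matlab
% Hypothetical inputs: first output spike per pattern, desired class time per
% pattern, and the class list (unique returns its values sorted ascending).
FirstSpikes  = [21.0 24.0 22.5];
DesiredTimes = [20 25 20];
Classes      = unique(DesiredTimes);          % [20 25]

Predicted = zeros(1, numel(FirstSpikes));
for j = 1:numel(FirstSpikes)
    d = abs(FirstSpikes(j) - Classes);        % distance to every class time
    % Nearest class; on an exact tie take the later class, which is where
    % get_error's early-exit scan over the sorted Classes ends up.
    Predicted(j) = Classes(find(d == min(d), 1, 'last'));
end

ClassErr = mean(Predicted ~= DesiredTimes);   % fraction of misclassified patterns
fprintf('Predicted: %s, ClassErr = %.4f\n', mat2str(Predicted), ClassErr);
```

Here the third pattern sits exactly halfway between 20 and 25 and is assigned 25, so one of the three patterns counts as misclassified.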
Output Parameters:
o SSE: A scalar indicating the sum of squared errors (each squared timing
  error is weighted by 0.5) between ts and td if the function is called
  with ts and td. Otherwise, SSE is the total over all patterns, comparing
  the outputs produced from tiAll with tdAll.
o RMSE: A scalar indicating the root mean squared error per output spike
  between ts and td if the function is called with ts and td. Otherwise,
  RMSE is computed over all spikes in the patterns produced from tiAll
  and tdAll.
o ClassErr: Optional output that indicates the Classification Error (the
  fraction of misclassified patterns) if it is calculated, as determined
  by the input Mode.
o tsAll: A cell array of length a containing, for each input pattern, the
  output spike times in the ts format (one cell per neuron), adjusted to
  match the number of desired spikes. It is only populated when the
  function is called with tiAll and tdAll.

Examples:
>> %Simple Error
>> ThNN = theta_neuron_network;
>> ts = run_network(ThNN, [2 3])
>> [SSE0, RMSE0] = get_error(ThNN,ts,[3 ts{3}])
>> [SSE1, RMSE1] = get_error(ThNN,ts,[3 25])

>> %Error over multiple spikes
>> ThNN = theta_neuron_network;
>> ThNN.Neurons(3).Io = 0.05;
>> ts = run_network(ThNN, [2 6])
>> [SSE2, RMSE2] = get_error(ThNN,ts,[3 20; 3 40])

>> %Error over multiple input patterns
>> ThNN = theta_neuron_network;
>> [SSE3, RMSE3] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20]})

>> %Error over multiple spikes and input patterns (with Display)
>> ThNN = theta_neuron_network;
>> ThNN.Neurons(3).Io = 0.05;
>> [SSE4, RMSE4] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20; 3 40]},1)

>> %Classification Error (with Display), classes are calculated
>> ThNN = theta_neuron_network;
>> [SSE5, RMSE5] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20]},3)

See also theta_neuron_network
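Because get_error accumulates 0.5*(actual - desired)^2 for every matched spike, RMSE = sqrt(2*SSE/N), with N the total number of desired spikes, is the ordinary root mean squared timing error. A small check with hypothetical per-pattern timing differences (actual minus desired), assuming the spike counts have already been matched:

```matlab
% Hypothetical timing differences (actual - desired) after spike matching:
% pattern 1 has one output spike, pattern 2 has two.
Diffs = {2, [1 -3]};

SSE = 0; NumSpikes = 0;
for j = 1:numel(Diffs)
    SSE = SSE + sum(0.5 * Diffs{j}.^2);       % half squared error, as in get_error
    NumSpikes = NumSpikes + numel(Diffs{j});  % one term per desired spike
end
RMSE = sqrt(2 * SSE / NumSpikes);             % remove the 0.5 weight, then average

% Same value computed directly from all differences
AllDiffs = [Diffs{:}];
fprintf('SSE = %g, RMSE = %.4f (direct: %.4f)\n', SSE, RMSE, sqrt(mean(AllDiffs.^2)));
```

SSE comes out as 0.5*(4 + 1 + 9) = 7 and RMSE as sqrt(14/3), the same as the plain root mean square of the three differences.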
function [SSE, RMSE, ClassErr, tsAll]=get_error(ThNN,varargin)
%GET_ERROR calculates the sum of squared error in output spike times
%
%Description:
%Function to calculate the sum of squared error given the actual and
%desired output spike times. ts is in the cell format produced by
%run_network. td is in the standard array format [NeuronIndex1
%DesiredFiringTime1; ...]. The function attempts to make sure there is the
%correct number of spikes to compare, but if Io <= 0 and there are fewer
%actual output spikes than desired ones, a function error may be produced.
%
%Syntax:
%[SSE,RMSE]=GET_ERROR(ThNN,ts,td);
%[SSE,RMSE]=GET_ERROR(ThNN,tiAll,tdAll);
%[SSE,RMSE]=GET_ERROR(ThNN,ts,td,[Mode],[Classes]);
%[SSE,RMSE]=GET_ERROR(ThNN,tiAll,tdAll,[Mode],[Classes]);
%[SSE,RMSE,ClassErr,tsAll]=GET_ERROR(ThNN,tiAll,tdAll,[Mode],[Classes]);
%
%Input Parameters:
%o ThNN: An object of the theta neuron network class.
%o ts: A cell array of length NumNeurons containing in each cell an array
%  of length r_j of neuron j's output spike times. This is the format
%  produced by the run_network function.
%o td: An rx2 array that contains the neuron indices (globally indexed to
%  the network) for each spike time (column 1) and the desired output
%  spike times (column 2).
%o tiAll: A cell array of length a, the number of input patterns, in which
%  each cell is a qx2 array that contains the neuron indices for each
%  spike time and the input spike times. q may vary from cell to cell.
%o tdAll: A cell array of length a, the number of input patterns, in which
%  each cell is a (qo x 2) array that contains the output neuron indices
%  for each desired output spike time and the desired output spike times.
%  qo may vary from cell to cell.
%o Mode: Scalar that selects the type of error computation and display.
%  If Mode = 0 (the default value), SSE and RMSE are calculated and
%  returned. If Mode = 1, SSE and RMSE are calculated, and the actual and
%  desired spike times are displayed. If Mode = 2, the Classification
%  Error is also calculated, using the optional input Classes (or
%  calculating Classes if needed). Mode = 3 combines modes 1 and 2.
%o Classes: Optional vector of class labels, expressed as desired output
%  spike times, used when Mode is 2 or 3. With tiAll and tdAll it is built
%  from tdAll if omitted; with ts and td it must be given, otherwise Mode
%  falls back to 0 or 1 and the classification step is skipped.
%
%Output Parameters:
%o SSE: A scalar indicating the sum of squared errors (each squared timing
%  error is weighted by 0.5) between ts and td if the function is called
%  with ts and td. Otherwise, SSE is the total over all patterns,
%  comparing the outputs produced from tiAll with tdAll.
%o RMSE: A scalar indicating the root mean squared error per output spike
%  between ts and td if the function is called with ts and td. Otherwise,
%  RMSE is computed over all spikes in the patterns produced from tiAll
%  and tdAll.
%o ClassErr: Optional output that indicates the Classification Error (the
%  fraction of misclassified patterns) if it is calculated, as determined
%  by the input Mode.
%o tsAll: A cell array of length a containing, for each input pattern, the
%  output spike times in the ts format (one cell per neuron), adjusted to
%  match the number of desired spikes. It is only populated when the
%  function is called with tiAll and tdAll.
%
%Examples:
%>> %Simple Error
%>> ThNN = theta_neuron_network;
%>> ts = run_network(ThNN, [2 3])
%>> [SSE0, RMSE0] = get_error(ThNN,ts,[3 ts{3}])
%>> [SSE1, RMSE1] = get_error(ThNN,ts,[3 25])
%
%>> %Error over multiple spikes
%>> ThNN = theta_neuron_network;
%>> ThNN.Neurons(3).Io = 0.05;
%>> ts = run_network(ThNN, [2 6])
%>> [SSE2, RMSE2] = get_error(ThNN,ts,[3 20; 3 40])
%
%>> %Error over multiple input patterns
%>> ThNN = theta_neuron_network;
%>> [SSE3, RMSE3] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20]})
%
%>> %Error over multiple spikes and input patterns (with Display)
%>> ThNN = theta_neuron_network;
%>> ThNN.Neurons(3).Io = 0.05;
%>> [SSE4, RMSE4] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20; 3 40]},1)
%
%>> %Classification Error (with Display), classes are calculated
%>> ThNN = theta_neuron_network;
%>> [SSE5, RMSE5] = get_error(ThNN,{[2 3],[2 6]},{[3 25],[3 20]},3)
%
%See also theta_neuron_network

%Copyright (C) 2008 Sam McKennoch <Samuel.McKennoch@loria.fr>


%Default return values (-1 means "not calculated")
SSE=-1; RMSE=-1; ClassErr=-1; tsAll={};

if nargin<3
    disp('Error in get_error: Wrong number of input arguments');
    disp(['Needed at least 3 inputs but got ' num2str(nargin)]);
    return;
end

%Mode defaults to 0 (SSE and RMSE only)
if nargin==3
    Mode=0;
else
    Mode=varargin{3};
end

if (Mode == 2 || Mode == 3)
    OutputNeurons=get_output_neurons(ThNN);
    ClassErr=0;
end

if iscell(varargin{2})
    %Multiple input patterns: tiAll and tdAll
    tiAll=varargin{1};
    tdAll=varargin{2};
    if (Mode == 2 || Mode == 3) && nargin==5
        Classes=varargin{4};
    else
        %Build the class list from every desired output spike time
        Classes=[];
        for k=1:length(tdAll)
            for m=1:size(tdAll{k},1)
                Classes=push_unique(Classes,tdAll{k}(:,2)');
            end
        end
        Classes=sort(Classes);
    end
    SSE = 0; RMSE = 0; NumSpikes=0;

    tsAll=cell(length(tiAll),1);
    for j=1:length(tiAll)
        %Simulate pattern j and compare it with its desired outputs
        tsAll{j} = run_network(ThNN,tiAll{j});
        [ErrorTemp, RMSETemp, tsAll{j}] = get_error_helper(ThNN,tsAll{j},tdAll{j});
        SSE = SSE + ErrorTemp;
        NumSpikes = NumSpikes + size(tdAll{j},1);
        if (Mode == 2 || Mode == 3)
            if length(tsAll{j}{OutputNeurons(1)})>1
                %Maybe could do MIMO pattern matching for
                %classification though?
                disp('Problem B1!');
            end
            %Nearest class to the first spike of the first output neuron
            %(Classes is sorted, so stop once the distance starts to grow)
            for m=1:length(Classes)
                ClassDistance=abs(tsAll{j}{OutputNeurons(1)}(1)-Classes(m));
                if m>1 && ClassDistance>ClassDistanceOld
                    CurrentClass=Classes(m-1);
                    break;
                end
                CurrentClass=Classes(m); %Match the final class
                ClassDistanceOld=ClassDistance;
            end
            %Compare with the desired class
            if CurrentClass~=tdAll{j}(1,2)
                ClassErr=ClassErr+1;
            end
        end
    end
    if (Mode == 2 || Mode == 3)
        ClassErr=ClassErr/length(tiAll); %Fraction of misclassified patterns
    end
    if Mode==1 || Mode==3
        display_results(ThNN,'Outputs',tsAll,tdAll);
    end
    RMSE = sqrt(2*SSE/NumSpikes); %SSE holds 0.5*(error)^2 terms
else
    %Single pattern: ts and td
    ts=varargin{1};
    td=varargin{2};
    if (Mode == 2 || Mode == 3) && nargin==5
        Classes=varargin{4};
        if length(ts{OutputNeurons(1)})>1
            %Maybe could do MIMO pattern matching for
            %classification though?
            disp('Problem B2!');
        end
        for m=1:length(Classes)
            ClassDistance=abs(ts{OutputNeurons(1)}(1)-Classes(m));
            if m>1 && ClassDistance>ClassDistanceOld
                CurrentClass=Classes(m-1);
                break;
            end
            CurrentClass=Classes(m); %Match the final class
            ClassDistanceOld=ClassDistance;
        end
        %Compare with the desired class
        if CurrentClass~=td(1,2)
            ClassErr=ClassErr+1;
        end
    else
        %Classes cannot be calculated from a single pattern, so
        %reset the Mode if needed
        %disp('Problem C!');
        if Mode>1
            Mode=Mode-2;
        end
    end
    [SSE, RMSE, ts] = get_error_helper(ThNN,ts,td);
    if Mode==1 || Mode==3
        display_results(ThNN,'Outputs',{ts},{td});
    end
end
return;


%Helper Function to Calculate Error For a Single Pattern
function [SSE, RMSE, ts] = get_error_helper(ThNN,ts,td)

%Reformat td into a cell array shaped like ts: tdC{NeuronIndex} holds that
%neuron's desired spike times, sorted
tdC=cell(size(ts));
for j=1:size(td,1)
    CurrentIndex=td(j,1);
    tdC{CurrentIndex}=[tdC{CurrentIndex}, td(j,2)];
end
for j=1:length(tdC)
    tdC{j}=sort(tdC{j});
end

%Go through each output neuron and compare
NeuronList=unique(td(:,1));
SSE=0; RMSE=-1;
for j=1:length(NeuronList)
    CurrentNeuron=NeuronList(j);
    if length(ts{CurrentNeuron})<length(tdC{CurrentNeuron}) && ThNN.Neurons(CurrentNeuron).Io>0
        %Too few actual spikes: extrapolate the missing ones, pi/Beta apart
        for k=(length(ts{CurrentNeuron})+1):length(tdC{CurrentNeuron})
            ts{CurrentNeuron}(1,end+1)=ts{CurrentNeuron}(1,end)+(pi/ThNN.Neurons(CurrentNeuron).Beta);
        end
    elseif length(ts{CurrentNeuron})>length(tdC{CurrentNeuron})
        %Too many actual spikes: keep only the first ones
        ts{CurrentNeuron}=ts{CurrentNeuron}(1:length(tdC{CurrentNeuron}));
    elseif length(ts{CurrentNeuron})~=length(tdC{CurrentNeuron})
        %Error is undefined as the numbers of desired and actual spikes are unreconcilable
        disp('Error in get_error: The numbers of actual and desired spikes are unreconcilable');
        SSE=-1;
        return;
    end
    %Equal Numbers of Output Spikes Now, Process Normally
    SSE=SSE+sum(0.5*((ts{CurrentNeuron}-tdC{CurrentNeuron}).^2));
end
RMSE=sqrt(2*SSE/size(td,1)); %Error per spike
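get_error_helper reconciles spike counts for each output neuron before comparing: extra actual spikes are dropped, and missing ones are extrapolated pi/Beta after the last actual spike when the neuron's Io is positive; with Io <= 0 the comparison is treated as undefined. The stand-alone sketch below applies that rule to one neuron with hypothetical values for Beta, Io and the spike times. The isempty guard is an addition not present in the listing, which indexes ts{CurrentNeuron}(1,end) and therefore assumes at least one actual spike.

```matlab
% Hypothetical neuron parameters and spike times for one output neuron
Beta = 0.5;  Io = 0.05;
Actual  = 12.0;                 % one actual output spike
Desired = sort([60 20 40]);     % desired spikes, sorted as get_error_helper does

if numel(Actual) < numel(Desired) && Io > 0 && ~isempty(Actual)
    % Too few spikes: append extrapolated spikes spaced pi/Beta apart
    while numel(Actual) < numel(Desired)
        Actual(end+1) = Actual(end) + pi / Beta;
    end
elseif numel(Actual) > numel(Desired)
    % Too many spikes: keep only the first numel(Desired) of them
    Actual = Actual(1:numel(Desired));
elseif numel(Actual) ~= numel(Desired)
    error('Spike counts cannot be reconciled (Io <= 0 or no actual spikes)');
end

SSE = sum(0.5 * (Actual - Desired).^2);   % vectors now have equal length
fprintf('Matched spikes: %s, SSE = %.4f\n', mat2str(Actual, 4), SSE);
```

With one actual spike at 12 and three desired spikes, the sketch fills in spikes at 12 + 2*pi and 12 + 4*pi before computing the half squared error.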