Spiking Neuron Network Simulator  1.0
Simulation and training of spiking neuron networks, primarily theta neurons
 All Classes Namespaces Files Functions Variables Enumerations Enumerator Properties Pages
TrainNetwork.cs
Go to the documentation of this file.
1 namespace SpikingNeuronNetwork.Lib.Training
2 {
3  using Interfaces;
4  using System;
5  using System.Collections.Generic;
6  using System.Linq;
7 
/// <summary>
/// Trains a spiking neuron network: runs the current network over every pattern of the
/// training set, computes the error derivatives from the resulting firing histories, and
/// hands them to the configured training algorithm to update the weights.
/// All state and configuration is read from and written to <see cref="TrainingStats"/>.
/// </summary>
11  public class TrainNetwork
12  {
    /// <summary>
    /// Gets or sets the training state and configuration used by this trainer
    /// (current network, training set, error calculator, training algorithm,
    /// epoch counters and per-epoch statistics are all accessed through it).
    /// </summary>
16  public TrainingStats TrainingStats { get; set; }
17 
    /// <summary>
    /// Creates a new instance of the TrainNetwork class around existing training stats.
    /// </summary>
    /// <param name="trainingStats">Training stats to continue from; stored as-is, not reset.</param>
22  public TrainNetwork(TrainingStats trainingStats)
23  {
24  TrainingStats = trainingStats;
25  }
26 
    /// <summary>
    /// Creates a new instance of the TrainNetwork class from individual settings and then
    /// calls <c>TrainingStats.ResetTraining()</c>.
    /// </summary>
    /// <param name="spikingNeuronNetwork">Stored as <c>OriginalNetwork</c>.</param>
    /// <param name="trainingSet">Stored as <c>TrainingSet</c>; each entry supplies InputSpikes and OutputSpikes.</param>
    /// <param name="learningRate">Stored as <c>LearningRate</c>.</param>
    /// <param name="maximumErrorAfterTraining">Stored as <c>MaximumErrorAfterTraining</c>; StartTraining succeeds once an epoch's error is below it.</param>
    /// <param name="maximumNumberOfTrainingEpochs">Stored as <c>MaximumNumberOfTrainingEpochs</c>; upper bound of the epoch loop in StartTraining.</param>
    /// <param name="errorCalculator">Stored as <c>ErrorCalculator</c>.</param>
    /// <param name="trainingAlgorithm">Stored as <c>TrainingAlgorithm</c>.</param>
    /// <param name="trainingMethod">Stored as <c>TrainingMethod</c>.</param>
38  public TrainNetwork(SpikingNeuronNetwork spikingNeuronNetwork, List<SpikeSet> trainingSet, double learningRate, double maximumErrorAfterTraining, double maximumNumberOfTrainingEpochs, ISpikingError errorCalculator, ITrainingAlgorithm trainingAlgorithm, TrainingMethod trainingMethod)
39  {
    // NOTE(review): source line 40 is missing from this listing. The object initializer
    // below presumably belongs to `TrainingStats = new TrainingStats` - confirm against
    // the original file.
41  {
42  OriginalNetwork = spikingNeuronNetwork,
43  TrainingSet = trainingSet,
44  LearningRate = learningRate,
45  MaximumErrorAfterTraining = maximumErrorAfterTraining,
46  MaximumNumberOfTrainingEpochs = maximumNumberOfTrainingEpochs,
47  ErrorCalculator = errorCalculator,
48  TrainingMethod = trainingMethod,
49  TrainingAlgorithm = trainingAlgorithm
50  };
51  TrainingStats.ResetTraining();
52  }
53 
    /// <summary>
    /// Calculate the error derivative, using the network's
    /// <c>CalculateOutputSpikeTimeDerivatives</c> for the per-synapse derivatives.
    /// </summary>
    /// <param name="neuronFiringHistories">Firing history per neuron index for one simulated pattern.</param>
    /// <param name="desiredOutputSpikeTimes">Desired output spikes of the same pattern.</param>
    /// <returns>The error derivative parameters built by CalculateErrorDerivativeHelper.</returns>
    /// <remarks>
    /// Side effect: sets <c>TrainingStats.ErrorCalculator.SimulationMethod</c> to
    /// <c>SimulationMethod.EventDriven</c> and leaves it set.
    /// </remarks>
60  public ErrorDerivativeParameters CalculateErrorDerivative(Dictionary<int, NeuronFiringHistory> neuronFiringHistories, List<Spike> desiredOutputSpikeTimes)
61  {
62  var neuronDerivativeParameters = new Dictionary<Synapse, NeuronDerivativeParameters>();
    // Merge the per-neuron results into one synapse-keyed dictionary. The dictionary is
    // re-materialised on every iteration, and ToDictionary throws ArgumentException if
    // two histories yield the same synapse key.
63  foreach (var value in neuronFiringHistories.Values)
64  neuronDerivativeParameters = neuronDerivativeParameters.Concat(TrainingStats.CurrentNetwork.CalculateOutputSpikeTimeDerivatives(value)).ToDictionary(x => x.Key, x => x.Value);
65  TrainingStats.ErrorCalculator.SimulationMethod = SimulationMethod.EventDriven;
66 
67  return CalculateErrorDerivativeHelper(neuronFiringHistories, desiredOutputSpikeTimes,
68  neuronDerivativeParameters);
69  }
70 
    /// <summary>
    /// Calculate the error derivative numerically, using the network's
    /// <c>CalculateOutputSpikeTimeDerivativesNumerical</c> for the per-synapse derivatives.
    /// </summary>
    /// <param name="neuronFiringHistories">Firing history per neuron index for one simulated pattern.</param>
    /// <param name="desiredOutputSpikeTimes">Desired output spikes of the same pattern.</param>
    /// <returns>The error derivative parameters built by CalculateErrorDerivativeHelper.</returns>
    /// <remarks>
    /// Side effect: sets <c>TrainingStats.ErrorCalculator.SimulationMethod</c> to
    /// <c>SimulationMethod.Numerical</c> and leaves it set. Not called by StartTraining
    /// in this listing.
    /// </remarks>
77  public ErrorDerivativeParameters CalculateErrorDerivativeNumerical(Dictionary<int, NeuronFiringHistory> neuronFiringHistories, List<Spike> desiredOutputSpikeTimes)
78  {
79  var neuronDerivativeParameters = new Dictionary<Synapse, NeuronDerivativeParameters>();
    // Same merge as in CalculateErrorDerivative; a duplicate synapse key makes
    // ToDictionary throw ArgumentException.
80  foreach (var value in neuronFiringHistories.Values)
81  neuronDerivativeParameters = neuronDerivativeParameters.Concat(TrainingStats.CurrentNetwork.CalculateOutputSpikeTimeDerivativesNumerical(value)).ToDictionary(x => x.Key, x => x.Value);
82  TrainingStats.ErrorCalculator.SimulationMethod = SimulationMethod.Numerical;
83 
84  return CalculateErrorDerivativeHelper(neuronFiringHistories, desiredOutputSpikeTimes,
85  neuronDerivativeParameters);
86  }
87 
    /// <summary>
    /// Start training the network. Each epoch resets and runs the current network on every
    /// training pattern, accumulates the output error, updates the weights through
    /// <c>TrainingStats.TrainingAlgorithm</c>, and records a TrainingStatsPerEpoch entry.
    /// </summary>
    /// <returns>
    /// <c>true</c> as soon as an epoch's accumulated error is below
    /// <c>TrainingStats.MaximumErrorAfterTraining</c>; <c>false</c> when the guard near the
    /// top of the method returns early or when the epoch loop runs out.
    /// </returns>
    /// <remarks>
    /// Writes "Epoch: ..." and the per-epoch error to the console. Exceptions thrown by
    /// CalculateErrorDerivative (for example when an output neuron did not fire) are not
    /// caught here.
    /// </remarks>
92  public bool StartTraining()
93  {
94  if (TrainingStats.CurrentNetwork == null)
95  {
    // presumably ResetTraining initialises CurrentNetwork - confirm in TrainingStats
96  TrainingStats.ResetTraining();
97  }
98 
    // NOTE(review): source line 99 (the condition guarding this early return) is missing
    // from this listing - confirm against the original file.
100  {
101  return false;
102  }
103 
104  ErrorDerivativeParametersBatch previousErrorDerivativesBatch = null;
105  int epoch;
    // NOTE(review): the bound is a strict `<` starting at 1, so the body runs for epochs
    // 1 .. Maximum-1, yet after the loop NumEpochs is advanced by the exit value of
    // `epoch`. Confirm that both the bound and the bookkeeping are intended.
106  for (epoch = 1; epoch < TrainingStats.MaximumNumberOfTrainingEpochs; epoch++)
107  {
108  Console.WriteLine("Epoch: " + epoch);
109  var trainingsetNumber = 0;
110  var currentError = 0.0;
111  var nonFiringNeuronIndices = new List<int>();
112  ErrorDerivativeParametersBatch errorDerivativesBatch = null;
113  ErrorDerivativeParameters previousErrorDerivatives = null;
114  foreach (var trainingPattern in TrainingStats.TrainingSet)
115  {
116  TrainingStats.CurrentNetwork.ResetNetwork();
117  var neuronFiringHistories = TrainingStats.CurrentNetwork.RunSpikingNeuronNetwork(trainingPattern.InputSpikes);
118  nonFiringNeuronIndices.AddRange(NeuronFiringHistory.GetNonFiringNeuronIndices(neuronFiringHistories));
119  var errorDerivatives = CalculateErrorDerivative(neuronFiringHistories, trainingPattern.OutputSpikes);
120 
    // Accumulate the error of every output neuron for this pattern into the epoch total.
121  foreach (var outputNeuronIndex in TrainingStats.CurrentNetwork.GetOutputLayerNeuronIndices())
122  {
123  var currentOutputSpikes = trainingPattern.OutputSpikes.Where(x => x.NeuronIndex == outputNeuronIndex).ToList();
124  currentError += TrainingStats.ErrorCalculator.GetError(currentOutputSpikes, neuronFiringHistories[outputNeuronIndex].OutputSpikes);
125  }
126 
    // NOTE(review): source line 127 (the `if` paired with the `else if` below) is missing
    // from this listing. This branch updates the weights once per training pattern; the
    // other branch only accumulates derivatives for a per-epoch update. The condition
    // presumably tests TrainingStats.TrainingMethod - confirm against the original file.
128  {
129  TrainingStats.TrainingAlgorithm.UpdateWeights(TrainingStats, errorDerivatives, previousErrorDerivatives);
130  }
131  else if (errorDerivatives != null)
132  {
133  if (errorDerivativesBatch == null)
134  {
135  errorDerivativesBatch = new ErrorDerivativeParametersBatch(errorDerivatives);
136  }
137  else
138  {
139  foreach (var synapse in errorDerivatives.GetSynapses())
140  {
141  errorDerivativesBatch.AddErrorDerivative(synapse, errorDerivatives.GetErrorDerivative(synapse));
142  }
143  }
144  }
145 
146  previousErrorDerivatives = errorDerivatives;
147  trainingsetNumber++;
148  }
    // NOTE(review): source line 149 (the condition guarding the per-epoch batch update)
    // is missing from this listing - confirm against the original file.
150  {
151  TrainingStats.TrainingAlgorithm.UpdateWeights(TrainingStats, errorDerivativesBatch, previousErrorDerivativesBatch);
152  previousErrorDerivativesBatch = errorDerivativesBatch;
153  }
154  nonFiringNeuronIndices = nonFiringNeuronIndices.Distinct().OrderBy(x => x).ToList();
    // Key is the absolute epoch number: epochs from earlier runs plus this run's counter.
155  TrainingStats.PerEpochStats.Add(TrainingStats.NumEpochs + epoch, new TrainingStatsPerEpoch(currentError, nonFiringNeuronIndices));
    // trainingsetNumber here is the count of patterns processed in this epoch;
    // currentError is the total over all of them.
156  Console.WriteLine("\tError on training set " + trainingsetNumber + ": " + currentError);
157  if (currentError >= TrainingStats.MaximumErrorAfterTraining) continue;
    // Error target reached: record the epochs used by this run and stop.
158  TrainingStats.NumEpochs += epoch;
159  return true;
160  }
161  TrainingStats.NumEpochs += epoch;
162  return false;
163  }
164 
165  #region Helper Methods
166 
    /// <summary>
    /// Builds the error derivative parameters from the firing histories, the desired output
    /// spikes and the precomputed per-synapse derivatives: first for the output layer, then
    /// for the remaining layers in reverse order.
    /// </summary>
    /// <param name="neuronFiringHistories">Firing history per neuron index.</param>
    /// <param name="desiredOutputSpikeTimes">Desired output spikes; at most one per output neuron is supported.</param>
    /// <param name="neuronDerivativeParameters">Per-synapse derivative parameters.</param>
    /// <returns>A new ErrorDerivativeParameters instance filled by this method.</returns>
    /// <exception cref="ArgumentException">
    /// An output neuron has more than one desired spike ("MIMO Training Not Implemented"),
    /// or an output neuron has no actual or no desired output spike.
    /// </exception>
174  private ErrorDerivativeParameters CalculateErrorDerivativeHelper(IReadOnlyDictionary<int, NeuronFiringHistory> neuronFiringHistories, List<Spike> desiredOutputSpikeTimes, IReadOnlyDictionary<Synapse, NeuronDerivativeParameters> neuronDerivativeParameters)
175  {
176  var errorDerivativeParameters = new ErrorDerivativeParameters();
177 
178  #region Output Neurons
179  foreach (var outputNeuronIndex in TrainingStats.CurrentNetwork.GetOutputLayerNeuronIndices())
180  {
181  if (desiredOutputSpikeTimes.Count(x => x.NeuronIndex == outputNeuronIndex) > 1)
182  {
183  throw new ArgumentException("MIMO Training Not Implemented");
184  }
185 
186  if (neuronFiringHistories[outputNeuronIndex].OutputSpikes.Count == 0 || desiredOutputSpikeTimes.Count(x => x.NeuronIndex == outputNeuronIndex) == 0)
187  {
188  throw new ArgumentException("Actual or Desired Output Spike Times are Missing");
189  }
190 
    // Only the first element returned by the error calculator is kept, in line with the
    // single-desired-spike restriction enforced above.
191  errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives.Add(
192  outputNeuronIndex,
193  TrainingStats.ErrorCalculator.GetErrorToOutputSpikeTimeDerivative(
194  desiredOutputSpikeTimes.Where(x => x.NeuronIndex == outputNeuronIndex).ToList(),
195  neuronFiringHistories[outputNeuronIndex].OutputSpikes).First());
196 
197  // Where clause here is because not all input synapses may have spikes
198  foreach (var synapse in TrainingStats.CurrentNetwork.GetNeuronsInputSynapses(outputNeuronIndex).Where(neuronDerivativeParameters.ContainsKey))
199  {
200  errorDerivativeParameters.OutputSpikeTimeToWeightDerivatives.Add(synapse, neuronDerivativeParameters[synapse].OutputSpikeTimeToWeightDerivative);
201  }
202  }
203  #endregion
204 
205  #region Hidden Neurons
    // Walks layer indices Count-2 down to 0, so each neuron's downstream entries in
    // ErrorToOutputSpikeTimeDerivatives already exist when it is processed (assuming
    // synapses only lead to later layers - TODO confirm).
206  foreach (var layerIndex in Enumerable.Range(0, TrainingStats.CurrentNetwork.NumNeuronsPerLayer.Count - 1).Reverse()) // - 1 because output layer is already done
207  {
208  foreach (var hiddenNeuronIndex in TrainingStats.CurrentNetwork.GetNeuronIndicesByLayer(layerIndex))
209  {
210  var outputSynapses = TrainingStats.CurrentNetwork.GetNeuronsOutputSynapses(hiddenNeuronIndex);
211  errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives.Add(hiddenNeuronIndex, 0);
212  foreach (var outputSynapse in outputSynapses)
213  {
214  // ToDo: Confirm that OutputSpikeTimeToInputSpikeTimeDerivative reduces to 0 if
215  // ToDo: the output neuron's output spike happens before the input neuron's output spike
    // Chain rule: sum over outgoing synapses of (downstream error derivative) times
    // (downstream output spike time w.r.t. this neuron's spike time).
    // NOTE(review): unlike the output-neuron pass, the indexer below is not guarded by
    // ContainsKey; a synapse absent from neuronDerivativeParameters would throw here.
216  errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives[hiddenNeuronIndex] +=
217  errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives[outputSynapse.OutputNeuronIndex] *
218  neuronDerivativeParameters[outputSynapse].OutputSpikeTimeToInputSpikeTimeDerivative;
219  }
    // NOTE(review): also unguarded - every input synapse of the neuron is expected to be
    // present in neuronDerivativeParameters.
220  foreach (var synapse in TrainingStats.CurrentNetwork.GetNeuronsInputSynapses(hiddenNeuronIndex))
221  {
222  errorDerivativeParameters.OutputSpikeTimeToWeightDerivatives.Add(synapse, neuronDerivativeParameters[synapse].OutputSpikeTimeToWeightDerivative);
223  }
224  }
225  }
226  #endregion
227 
228  return errorDerivativeParameters;
229  }
230 
231  #endregion
232  }
233 }
SpikingNeuronNetwork CurrentNetwork
Gets or sets the current network after training
TrainingMethod
Training Method Enum
bool IsLayered
Indicates if the network is layered
ErrorDerivativeParameters CalculateErrorDerivativeNumerical(Dictionary< int, NeuronFiringHistory > neuronFiringHistories, List< Spike > desiredOutputSpikeTimes)
Calculate the error derivative numerically
Definition: TrainNetwork.cs:77
TrainNetwork(SpikingNeuronNetwork spikingNeuronNetwork, List< SpikeSet > trainingSet, double learningRate, double maximumErrorAfterTraining, double maximumNumberOfTrainingEpochs, ISpikingError errorCalculator, ITrainingAlgorithm trainingAlgorithm, TrainingMethod trainingMethod)
Creates a new instance of the TrainNetwork class
Definition: TrainNetwork.cs:38
Training Stats Per Epoch Class, inherits from a tuple representing the error and non-firing neuron in...
List< SpikeSet > TrainingSet
Gets or sets the training spike set, including input spikes and desired output spikes ...
TrainingMethod TrainingMethod
Gets or sets the training method
Error Derivative Parameters Batch Class For Tracking Running Totals in Error Derivatives ...
List< int > GetOutputLayerNeuronIndices()
Gets a list of output layer neuron indices
bool StartTraining()
Start training the network
Definition: TrainNetwork.cs:92
double MaximumErrorAfterTraining
Gets or sets the maximum error after training
TrainNetwork(TrainingStats trainingStats)
Creates a new instance of the TrainNetwork class
Definition: TrainNetwork.cs:22
ErrorDerivativeParameters CalculateErrorDerivative(Dictionary< int, NeuronFiringHistory > neuronFiringHistories, List< Spike > desiredOutputSpikeTimes)
Calculate the error derivative
Definition: TrainNetwork.cs:60