1 namespace SpikingNeuronNetwork.Lib.Training
5 using System.Collections.Generic;
42 OriginalNetwork = spikingNeuronNetwork,
43 TrainingSet = trainingSet,
44 LearningRate = learningRate,
45 MaximumErrorAfterTraining = maximumErrorAfterTraining,
46 MaximumNumberOfTrainingEpochs = maximumNumberOfTrainingEpochs,
47 ErrorCalculator = errorCalculator,
49 TrainingAlgorithm = trainingAlgorithm
51 TrainingStats.ResetTraining();
62 var neuronDerivativeParameters =
new Dictionary<Synapse, NeuronDerivativeParameters>();
63 foreach (var value
in neuronFiringHistories.Values)
64 neuronDerivativeParameters = neuronDerivativeParameters.Concat(TrainingStats.CurrentNetwork.CalculateOutputSpikeTimeDerivatives(value)).ToDictionary(x => x.Key, x => x.Value);
65 TrainingStats.ErrorCalculator.SimulationMethod = SimulationMethod.EventDriven;
67 return CalculateErrorDerivativeHelper(neuronFiringHistories, desiredOutputSpikeTimes,
68 neuronDerivativeParameters);
79 var neuronDerivativeParameters =
new Dictionary<Synapse, NeuronDerivativeParameters>();
80 foreach (var value
in neuronFiringHistories.Values)
81 neuronDerivativeParameters = neuronDerivativeParameters.Concat(TrainingStats.CurrentNetwork.CalculateOutputSpikeTimeDerivativesNumerical(value)).ToDictionary(x => x.Key, x => x.Value);
82 TrainingStats.ErrorCalculator.SimulationMethod = SimulationMethod.Numerical;
84 return CalculateErrorDerivativeHelper(neuronFiringHistories, desiredOutputSpikeTimes,
85 neuronDerivativeParameters);
96 TrainingStats.ResetTraining();
106 for (epoch = 1; epoch < TrainingStats.MaximumNumberOfTrainingEpochs; epoch++)
108 Console.WriteLine(
"Epoch: " + epoch);
109 var trainingsetNumber = 0;
110 var currentError = 0.0;
111 var nonFiringNeuronIndices =
new List<int>();
116 TrainingStats.CurrentNetwork.ResetNetwork();
117 var neuronFiringHistories = TrainingStats.CurrentNetwork.RunSpikingNeuronNetwork(trainingPattern.InputSpikes);
118 nonFiringNeuronIndices.AddRange(NeuronFiringHistory.GetNonFiringNeuronIndices(neuronFiringHistories));
119 var errorDerivatives = CalculateErrorDerivative(neuronFiringHistories, trainingPattern.OutputSpikes);
123 var currentOutputSpikes = trainingPattern.OutputSpikes.Where(x => x.NeuronIndex == outputNeuronIndex).ToList();
124 currentError += TrainingStats.ErrorCalculator.GetError(currentOutputSpikes, neuronFiringHistories[outputNeuronIndex].OutputSpikes);
129 TrainingStats.TrainingAlgorithm.UpdateWeights(
TrainingStats, errorDerivatives, previousErrorDerivatives);
131 else if (errorDerivatives != null)
133 if (errorDerivativesBatch == null)
139 foreach (var synapse
in errorDerivatives.GetSynapses())
141 errorDerivativesBatch.AddErrorDerivative(synapse, errorDerivatives.GetErrorDerivative(synapse));
146 previousErrorDerivatives = errorDerivatives;
151 TrainingStats.TrainingAlgorithm.UpdateWeights(
TrainingStats, errorDerivativesBatch, previousErrorDerivativesBatch);
152 previousErrorDerivativesBatch = errorDerivativesBatch;
154 nonFiringNeuronIndices = nonFiringNeuronIndices.Distinct().OrderBy(x => x).ToList();
155 TrainingStats.PerEpochStats.Add(TrainingStats.NumEpochs + epoch,
new TrainingStatsPerEpoch(currentError, nonFiringNeuronIndices));
156 Console.WriteLine(
"\tError on training set " + trainingsetNumber +
": " + currentError);
158 TrainingStats.NumEpochs += epoch;
161 TrainingStats.NumEpochs += epoch;
165 #region Helper Methods
174 private ErrorDerivativeParameters CalculateErrorDerivativeHelper(IReadOnlyDictionary<int, NeuronFiringHistory> neuronFiringHistories, List<Spike> desiredOutputSpikeTimes, IReadOnlyDictionary<Synapse, NeuronDerivativeParameters> neuronDerivativeParameters)
178 #region Output Neurons
181 if (desiredOutputSpikeTimes.Count(x => x.NeuronIndex == outputNeuronIndex) > 1)
183 throw new ArgumentException(
"MIMO Training Not Implemented");
186 if (neuronFiringHistories[outputNeuronIndex].OutputSpikes.Count == 0 || desiredOutputSpikeTimes.Count(x => x.NeuronIndex == outputNeuronIndex) == 0)
188 throw new ArgumentException(
"Actual or Desired Output Spike Times are Missing");
191 errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives.Add(
193 TrainingStats.ErrorCalculator.GetErrorToOutputSpikeTimeDerivative(
194 desiredOutputSpikeTimes.Where(x => x.NeuronIndex == outputNeuronIndex).ToList(),
195 neuronFiringHistories[outputNeuronIndex].OutputSpikes).First());
198 foreach (var synapse
in TrainingStats.CurrentNetwork.GetNeuronsInputSynapses(outputNeuronIndex).Where(neuronDerivativeParameters.ContainsKey))
200 errorDerivativeParameters.OutputSpikeTimeToWeightDerivatives.Add(synapse, neuronDerivativeParameters[synapse].OutputSpikeTimeToWeightDerivative);
205 #region Hidden Neurons
206 foreach (var layerIndex
in Enumerable.Range(0, TrainingStats.CurrentNetwork.NumNeuronsPerLayer.Count - 1).Reverse())
208 foreach (var hiddenNeuronIndex
in TrainingStats.CurrentNetwork.GetNeuronIndicesByLayer(layerIndex))
210 var outputSynapses = TrainingStats.CurrentNetwork.GetNeuronsOutputSynapses(hiddenNeuronIndex);
211 errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives.Add(hiddenNeuronIndex, 0);
212 foreach (var outputSynapse
in outputSynapses)
216 errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives[hiddenNeuronIndex] +=
217 errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives[outputSynapse.OutputNeuronIndex] *
218 neuronDerivativeParameters[outputSynapse].OutputSpikeTimeToInputSpikeTimeDerivative;
220 foreach (var synapse
in TrainingStats.CurrentNetwork.GetNeuronsInputSynapses(hiddenNeuronIndex))
222 errorDerivativeParameters.OutputSpikeTimeToWeightDerivatives.Add(synapse, neuronDerivativeParameters[synapse].OutputSpikeTimeToWeightDerivative);
228 return errorDerivativeParameters;
SpikingNeuronNetwork CurrentNetwork
Gets or sets the current network after training
Training Algorithm Interface
TrainingMethod
Training Method Enum
bool IsLayered
Indicates if the network is layered
ErrorDerivativeParameters CalculateErrorDerivativeNumerical(Dictionary< int, NeuronFiringHistory > neuronFiringHistories, List< Spike > desiredOutputSpikeTimes)
Calculate the error derivative numerically
TrainNetwork(SpikingNeuronNetwork spikingNeuronNetwork, List< SpikeSet > trainingSet, double learningRate, double maximumErrorAfterTraining, double maximumNumberOfTrainingEpochs, ISpikingError errorCalculator, ITrainingAlgorithm trainingAlgorithm, TrainingMethod trainingMethod)
Creates a new instance of the TrainNetwork class
Training Stats Per Epoch Class; inherits from a tuple representing the epoch's error and its non-firing neuron indices
List< SpikeSet > TrainingSet
Gets or sets the training spike set, including input spikes and desired output spikes ...
TrainingMethod TrainingMethod
Gets or sets the training method
Error Derivative Parameters Batch Class, used for tracking running totals of error derivatives ...
List< int > GetOutputLayerNeuronIndices()
Gets a list of output layer neuron indices
bool StartTraining()
Start training the network
Training Statistics Class
double MaximumErrorAfterTraining
Gets or sets the maximum error after training
Error Derivative Parameters Class
Spiking Neuron Network Class
TrainNetwork(TrainingStats trainingStats)
Creates a new instance of the TrainNetwork class
ErrorDerivativeParameters CalculateErrorDerivative(Dictionary< int, NeuronFiringHistory > neuronFiringHistories, List< Spike > desiredOutputSpikeTimes)
Calculate the error derivative