1 namespace SpikingNeuronNetwork.Lib.Training
 
    5     using System.Collections.Generic;
 
   42                     OriginalNetwork = spikingNeuronNetwork,
 
   43                     TrainingSet = trainingSet,
 
   44                     LearningRate = learningRate,
 
   45                     MaximumErrorAfterTraining = maximumErrorAfterTraining,
 
   46                     MaximumNumberOfTrainingEpochs = maximumNumberOfTrainingEpochs,
 
   47                     ErrorCalculator = errorCalculator,
 
   49                     TrainingAlgorithm = trainingAlgorithm
 
   51             TrainingStats.ResetTraining();
 
   62             var neuronDerivativeParameters = 
new Dictionary<Synapse, NeuronDerivativeParameters>();
 
   63             foreach (var value 
in neuronFiringHistories.Values)
 
   64                 neuronDerivativeParameters = neuronDerivativeParameters.Concat(TrainingStats.CurrentNetwork.CalculateOutputSpikeTimeDerivatives(value)).ToDictionary(x => x.Key, x => x.Value);
 
   65             TrainingStats.ErrorCalculator.SimulationMethod = SimulationMethod.EventDriven;
 
   67             return CalculateErrorDerivativeHelper(neuronFiringHistories, desiredOutputSpikeTimes,
 
   68                                                   neuronDerivativeParameters);
 
   79             var neuronDerivativeParameters = 
new Dictionary<Synapse, NeuronDerivativeParameters>();
 
   80             foreach (var value 
in neuronFiringHistories.Values)
 
   81                 neuronDerivativeParameters = neuronDerivativeParameters.Concat(TrainingStats.CurrentNetwork.CalculateOutputSpikeTimeDerivativesNumerical(value)).ToDictionary(x => x.Key, x => x.Value);
 
   82             TrainingStats.ErrorCalculator.SimulationMethod = SimulationMethod.Numerical;
 
   84             return CalculateErrorDerivativeHelper(neuronFiringHistories, desiredOutputSpikeTimes,
 
   85                                                   neuronDerivativeParameters);
 
   96                 TrainingStats.ResetTraining();
 
  106             for (epoch = 1; epoch < TrainingStats.MaximumNumberOfTrainingEpochs; epoch++)
 
  108                 Console.WriteLine(
"Epoch: " + epoch);
 
  109                 var trainingsetNumber = 0;
 
  110                 var currentError = 0.0;
 
  111                 var nonFiringNeuronIndices = 
new List<int>();
 
  116                     TrainingStats.CurrentNetwork.ResetNetwork();
 
  117                     var neuronFiringHistories = TrainingStats.CurrentNetwork.RunSpikingNeuronNetwork(trainingPattern.InputSpikes);
 
  118                     nonFiringNeuronIndices.AddRange(NeuronFiringHistory.GetNonFiringNeuronIndices(neuronFiringHistories));
 
  119                     var errorDerivatives = CalculateErrorDerivative(neuronFiringHistories, trainingPattern.OutputSpikes);
 
  123                         var currentOutputSpikes = trainingPattern.OutputSpikes.Where(x => x.NeuronIndex == outputNeuronIndex).ToList();
 
  124                         currentError += TrainingStats.ErrorCalculator.GetError(currentOutputSpikes, neuronFiringHistories[outputNeuronIndex].OutputSpikes);
 
  129                         TrainingStats.TrainingAlgorithm.UpdateWeights(
TrainingStats, errorDerivatives, previousErrorDerivatives);
 
  131                     else if (errorDerivatives != null)
 
  133                         if (errorDerivativesBatch == null)
 
  139                             foreach (var synapse 
in errorDerivatives.GetSynapses())
 
  141                                 errorDerivativesBatch.AddErrorDerivative(synapse, errorDerivatives.GetErrorDerivative(synapse));
 
  146                     previousErrorDerivatives = errorDerivatives;
 
  151                     TrainingStats.TrainingAlgorithm.UpdateWeights(
TrainingStats, errorDerivativesBatch, previousErrorDerivativesBatch);
 
  152                     previousErrorDerivativesBatch = errorDerivativesBatch;
 
  154                 nonFiringNeuronIndices = nonFiringNeuronIndices.Distinct().OrderBy(x => x).ToList();
 
  155                 TrainingStats.PerEpochStats.Add(TrainingStats.NumEpochs + epoch, 
new TrainingStatsPerEpoch(currentError, nonFiringNeuronIndices));
 
  156                 Console.WriteLine(
"\tError on training set " + trainingsetNumber + 
": " + currentError);
 
  158                 TrainingStats.NumEpochs += epoch;
 
  161             TrainingStats.NumEpochs += epoch;
 
  165         #region Helper Methods 
  174         private ErrorDerivativeParameters CalculateErrorDerivativeHelper(IReadOnlyDictionary<int, NeuronFiringHistory> neuronFiringHistories, List<Spike> desiredOutputSpikeTimes, IReadOnlyDictionary<Synapse, NeuronDerivativeParameters> neuronDerivativeParameters)
 
  178             #region Output Neurons 
  181                 if (desiredOutputSpikeTimes.Count(x => x.NeuronIndex == outputNeuronIndex) > 1)
 
  183                     throw new ArgumentException(
"MIMO Training Not Implemented");
 
  186                 if (neuronFiringHistories[outputNeuronIndex].OutputSpikes.Count == 0 || desiredOutputSpikeTimes.Count(x => x.NeuronIndex == outputNeuronIndex) == 0)
 
  188                     throw new ArgumentException(
"Actual or Desired Output Spike Times are Missing");
 
  191                 errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives.Add(
 
  193                     TrainingStats.ErrorCalculator.GetErrorToOutputSpikeTimeDerivative(
 
  194                         desiredOutputSpikeTimes.Where(x => x.NeuronIndex == outputNeuronIndex).ToList(),
 
  195                         neuronFiringHistories[outputNeuronIndex].OutputSpikes).First());
 
  198                 foreach (var synapse 
in TrainingStats.CurrentNetwork.GetNeuronsInputSynapses(outputNeuronIndex).Where(neuronDerivativeParameters.ContainsKey))
 
  200                     errorDerivativeParameters.OutputSpikeTimeToWeightDerivatives.Add(synapse, neuronDerivativeParameters[synapse].OutputSpikeTimeToWeightDerivative);
 
  205             #region Hidden Neurons 
  206             foreach (var layerIndex 
in Enumerable.Range(0, TrainingStats.CurrentNetwork.NumNeuronsPerLayer.Count - 1).Reverse()) 
 
  208                 foreach (var hiddenNeuronIndex 
in TrainingStats.CurrentNetwork.GetNeuronIndicesByLayer(layerIndex))
 
  210                     var outputSynapses = TrainingStats.CurrentNetwork.GetNeuronsOutputSynapses(hiddenNeuronIndex);
 
  211                     errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives.Add(hiddenNeuronIndex, 0);
 
  212                     foreach (var outputSynapse 
in outputSynapses)
 
  216                         errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives[hiddenNeuronIndex] +=
 
  217                             errorDerivativeParameters.ErrorToOutputSpikeTimeDerivatives[outputSynapse.OutputNeuronIndex] *
 
  218                             neuronDerivativeParameters[outputSynapse].OutputSpikeTimeToInputSpikeTimeDerivative;
 
  220                     foreach (var synapse 
in TrainingStats.CurrentNetwork.GetNeuronsInputSynapses(hiddenNeuronIndex))
 
  222                         errorDerivativeParameters.OutputSpikeTimeToWeightDerivatives.Add(synapse, neuronDerivativeParameters[synapse].OutputSpikeTimeToWeightDerivative);
 
  228             return errorDerivativeParameters;
 
SpikingNeuronNetwork CurrentNetwork
Gets or sets the current network after training 
 
Training Algorithm Interface 
 
TrainingMethod
Training Method Enum 
 
bool IsLayered
Indicates if the network is layered 
 
ErrorDerivativeParameters CalculateErrorDerivativeNumerical(Dictionary< int, NeuronFiringHistory > neuronFiringHistories, List< Spike > desiredOutputSpikeTimes)
Calculate the error derivative numerically 
 
TrainNetwork(SpikingNeuronNetwork spikingNeuronNetwork, List< SpikeSet > trainingSet, double learningRate, double maximumErrorAfterTraining, double maximumNumberOfTrainingEpochs, ISpikingError errorCalculator, ITrainingAlgorithm trainingAlgorithm, TrainingMethod trainingMethod)
Creates a new instance of the TrainNetwork class 
 
Training Stats Per Epoch Class, inherits from a tuple representing the error and the non-firing neuron indices
 
List< SpikeSet > TrainingSet
Gets or sets the training spike set, including input spikes and desired output spikes ...
 
TrainingMethod TrainingMethod
Gets or sets the training method 
 
Error Derivative Parameters Batch Class For Tracking Running Totals in Error Derivatives ...
 
List< int > GetOutputLayerNeuronIndices()
Gets a list of output layer neuron indices 
 
bool StartTraining()
Start training the network 
 
Training Statistics Class 
 
double MaximumErrorAfterTraining
Gets or sets the maximum error after training 
 
Error Derivative Parameters Class 
 
Spiking Neuron Network Class 
 
TrainNetwork(TrainingStats trainingStats)
Creates a new instance of the TrainNetwork class 
 
ErrorDerivativeParameters CalculateErrorDerivative(Dictionary< int, NeuronFiringHistory > neuronFiringHistories, List< Spike > desiredOutputSpikeTimes)
Calculate the error derivative