|
| 1 | +# src/neuromorphic_analytics/model.py |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import logging |
| 5 | + |
| 6 | +# Set up logging for the Spiking Neural Network Model |
| 7 | +logger = logging.getLogger(__name__) |
| 8 | + |
| 9 | +class SpikingNeuralNetworkModel: |
| 10 | + def __init__(self, num_neurons, threshold=1.0, decay=0.9): |
| 11 | + """ |
| 12 | + Initialize the Spiking Neural Network Model. |
| 13 | +
|
| 14 | + Parameters: |
| 15 | + - num_neurons (int): Number of neurons in the network. |
| 16 | + - threshold (float): The threshold for neuron firing. |
| 17 | + - decay (float): The decay factor for the neuron's membrane potential. |
| 18 | + """ |
| 19 | + self.num_neurons = num_neurons |
| 20 | + self.threshold = threshold |
| 21 | + self.decay = decay |
| 22 | + self.membrane_potential = np.zeros(num_neurons) # Initialize membrane potential |
| 23 | + logger.info("Spiking Neural Network Model initialized with %d neurons.", num_neurons) |
| 24 | + |
| 25 | + def predict(self, input_data): |
| 26 | + """ |
| 27 | + Make a prediction based on the input data. |
| 28 | +
|
| 29 | + Parameters: |
| 30 | + - input_data (list or np.ndarray): The input data for prediction. |
| 31 | +
|
| 32 | + Returns: |
| 33 | + - list: The output spikes from the network. |
| 34 | + """ |
| 35 | + logger.info("Making prediction with input data: %s", input_data) |
| 36 | + input_data = np.array(input_data) |
| 37 | + self.membrane_potential += input_data # Update membrane potential with input |
| 38 | + |
| 39 | + # Check for spikes |
| 40 | + spikes = self.membrane_potential >= self.threshold |
| 41 | + self.membrane_potential[spikes] = 0 # Reset potential for neurons that spiked |
| 42 | + |
| 43 | + # Apply decay to the membrane potential |
| 44 | + self.membrane_potential *= self.decay |
| 45 | + |
| 46 | + logger.debug("Membrane potential after prediction: %s", self.membrane_potential) |
| 47 | + return spikes.astype(int).tolist() # Return spikes as a list of 0s and 1s |
| 48 | + |
| 49 | + def evaluate(self, test_data, true_labels): |
| 50 | + """ |
| 51 | + Evaluate the model's performance on test data. |
| 52 | +
|
| 53 | + Parameters: |
| 54 | + - test_data (list of lists or np.ndarray): The data to test the model on. |
| 55 | + - true_labels (list): The true labels for the test data. |
| 56 | +
|
| 57 | + Returns: |
| 58 | + - float: The accuracy of the model on the test data. |
| 59 | + """ |
| 60 | + correct_predictions = 0 |
| 61 | + total_predictions = len(test_data) |
| 62 | + |
| 63 | + for data, true_label in zip(test_data, true_labels): |
| 64 | + prediction = self.predict(data) |
| 65 | + if prediction == true_label: |
| 66 | + correct_predictions += 1 |
| 67 | + |
| 68 | + accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0 |
| 69 | + logger.info("Model evaluation completed. Accuracy: %.2f", accuracy) |
| 70 | + return accuracy |
0 commit comments