.. battelle.nn

.. currentmodule:: battelle.nn

Neural Network subpackage (battelle.nn)
========================================

.. autosummary::
   :toctree: generated/

   Neuron
   Layer
   NeuralNetwork
   ActivationFunction

.. automethod:: NeuralNetwork.train

.. automethod:: NeuralNetwork.predict

.. automethod:: NeuralNetwork.get_metrics

This subpackage lets you create, train and visualize simple neural networks.
The following example trains a small network on a binary classification
problem with XOR-style labels:

.. code-block:: python

   import numpy as np
   import matplotlib.pyplot as plt
   from battelle.nn import Layer, NeuralNetwork

   # List of features: 300 points drawn uniformly from [-0.5, 0.5) x [-0.5, 0.5)
   x = np.random.random(size=(300, 2)) - 0.5
   x = [list(xi) for xi in x]

   # Corresponding XOR-style labels: +1 where both coordinates share a sign, -1 otherwise
   y = np.array([np.sign(np.prod(k)) for k in x])

   # Define the neural network: 2 inputs -> 5 -> 4 -> 1 neurons, all tanh
   nn = NeuralNetwork()
   nn.add_layer(Layer(5, "tanh", input_dim=2))
   nn.add_layer(Layer(4, "tanh"))
   nn.add_layer(Layer(1, "tanh"))

   # Train the neural network for a fixed number of epochs,
   # printing the mean squared error after each one
   NUM_EPOCHS = 100
   for i in range(NUM_EPOCHS):
       print(f"Epoch {i+1}/{NUM_EPOCHS}")
       nn.train(x, y, learning_rate=0.03)
       print(f"Loss: {np.mean([(nn.predict(x[k]) - y[k])**2 for k in range(len(x))])}")

   # Plot the training data: red crosses for +1, blue crosses for -1
   plt.subplot(1, 2, 1)
   for i, xi in enumerate(x):
       if y[i] == 1:
           plt.plot(xi[0], xi[1], "xr")
       else:
           plt.plot(xi[0], xi[1], "xb")
   plt.title("Training data")

   # Plot the network's prediction on a 20 x 20 grid over the same square
   # (separate loop variables, so x and y still hold the training data)
   plt.subplot(1, 2, 2)
   xrange = np.linspace(-0.5, 0.5, 20)
   yrange = np.linspace(-0.5, 0.5, 20)
   for gx in xrange:
       for gy in yrange:
           if nn.predict([gx, gy]) > 0:
               plt.plot(gx, gy, "xr")
           else:
               plt.plot(gx, gy, "xb")
   plt.title("Neural net result")

   plt.show()

This code produces the following figure:

.. figure:: images/xor_nn.png
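
This page does not describe what ``NeuralNetwork.train`` does internally. As a
rough mental model only, the NumPy sketch below trains a network with the same
layer sizes as the XOR example (2 inputs, tanh layers of 5, 4 and 1 neurons) on
the same kind of data. It assumes fully connected layers with biases, a squared
error loss and one gradient-descent update per training point in each epoch;
none of this is confirmed for battelle.nn, which may batch updates or use a
different loss or initialization. The helpers ``init_params``, ``forward``,
``backprop``, ``sgd_epoch`` and ``mse`` are hypothetical and not part of the
package.

.. code-block:: python

   import numpy as np


   def init_params(layer_sizes, seed=0):
       """Hypothetical helper: random (weights, biases) for each fully connected layer."""
       rng = np.random.default_rng(seed)
       return [
           (rng.normal(scale=1.0 / np.sqrt(n_in), size=(n_in, n_out)),
            rng.normal(scale=0.1, size=n_out))
           for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
       ]


   def forward(params, x):
       """Activations of every layer, input first: a_next = tanh(a @ W + b)."""
       acts = [np.asarray(x, dtype=float)]
       for W, b in params:
           acts.append(np.tanh(acts[-1] @ W + b))
       return acts


   def backprop(params, x, y):
       """Gradients of the squared error (a_out - y)**2 for a single sample."""
       acts = forward(params, x)
       grad_a = 2.0 * (acts[-1] - y)                  # dL/da at the output layer
       grads = [None] * len(params)
       for i in reversed(range(len(params))):
           W, _ = params[i]
           delta = grad_a * (1.0 - acts[i + 1] ** 2)  # chain rule through tanh
           grads[i] = (np.outer(acts[i], delta), delta)
           grad_a = W @ delta                         # dL/da for the previous layer
       return grads


   def sgd_epoch(params, X, Y, learning_rate):
       """One pass over the data (assumed), updating after every sample."""
       for x, y in zip(X, Y):
           grads = backprop(params, x, y)
           params = [(W - learning_rate * dW, b - learning_rate * db)
                     for (W, b), (dW, db) in zip(params, grads)]
       return params


   def mse(params, X, Y):
       """Mean squared error over the data set, the quantity printed in the example."""
       return float(np.mean([(forward(params, x)[-1][0] - y) ** 2 for x, y in zip(X, Y)]))


   # Same data and layer sizes as the example: 2 inputs -> 5 -> 4 -> 1, all tanh
   rng = np.random.default_rng(1)
   X = rng.random((300, 2)) - 0.5
   Y = np.sign(np.prod(X, axis=1))  # +1 where both coordinates share a sign, else -1

   params = init_params([2, 5, 4, 1])
   initial_loss = mse(params, X, Y)
   for epoch in range(100):
       params = sgd_epoch(params, X, Y, learning_rate=0.03)
   final_loss = mse(params, X, Y)
   print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}")

Updating after every sample mirrors calling ``nn.train(x, y, learning_rate=0.03)``
once per epoch; a batched variant would instead average the ``backprop``
gradients over the whole data set before each update.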