Seitenanfang

Neuronales Netz in Go

Im ersten Teil haben wir uns mit der Zielsetzung und der Zerlegung der Eingangsdaten beschäftigt, der zweite Teil hat das neuronale Netz theoretisch beschrieben. Der dritte Teil besteht aus Quellcode: In rund 300 Zeilen wird das neuronale Netz erstellt, trainiert und getestet.

neunet.jpg

main.go

package main

import (
"bufio"
"fmt"
"log"
"os"
)

var Neurons [32]NeuronType

func main() {
Init()

// Create and init the neurons
for i, _ := range Neurons {
Neurons[i] = NewNeuron()
}

// Analyse a sample word
data := NewData("Autobahn")

// Show the results of the untrained neurons
for i, n := range Neurons {
log.Printf("Neuron %d: %v", i, n.Rate(data))
}

// Read garbage created using "pwgen" and a dictionary as training
trainFile("garbage.txt", false)
trainFile("de.txt", true)
trainFile("garbage.txt", false)
trainFile("de.txt", true)
trainFile("garbage.txt", false)
trainFile("de.txt", true)
trainFile("garbage.txt", false)
trainFile("de.txt", true)
trainFile("garbage.txt", false)
trainFile("de.txt", true)

// Rate the sample word again and show the results of the trained neurons
for i, n := range Neurons {
log.Printf("Neuron %d: %v", i, n.Rate(data))
}

// Rate some additional samples and show the overall (voting) result
for _, w := range []string{"Auto", "garbage", "as9Boo", "lae5uZee", "Be5Eirae", "igieng5I", "1.11.2011"} {
fmt.Printf("%v: %v\n", w, rate(w))
}

// Use the training files and show the success rate of each
var hit int
inFile, _ := os.Open("de.txt")
defer inFile.Close()
scanner := bufio.NewScanner(inFile)
scanner.Split(bufio.ScanLines)
for scanner.Scan() {
if rate(scanner.Text()) > 0 {
hit++
}
}
log.Print(hit)

hit = 0
inFile, _ = os.Open("garbage.txt")
defer inFile.Close()
scanner = bufio.NewScanner(inFile)
scanner.Split(bufio.ScanLines)
for scanner.Scan() {
if rate(scanner.Text()) < 0 {
hit++
}
}
log.Print(hit)
}

// Reads a file containing ham or spam and trains the net
func trainFile(filename string, expected bool) {
inFile, _ := os.Open(filename)
defer inFile.Close()
scanner := bufio.NewScanner(inFile)
scanner.Split(bufio.ScanLines)
for scanner.Scan() {
train(scanner.Text(), expected)
}
}

// Trains the network using one given word
func train(w string, expected bool) {
data := NewData(w) // Analyze

var bestNeutron int = -1
var bestResult float32
var votes int

// Let each neuron vote for or against the word
for i, n := range Neurons {
r := n.Rate(data)

if r > 0 {
votes++
} else if r < 0 {
votes--
}

// Remember the best-of-the-worse
if (r > 0 && !expected) || (r < 0 && expected) {
if expected {
if (bestResult == 0) || (r > bestResult) {
bestNeutron = i
bestResult = r
}
} else if !expected {
if (bestResult == 0) || (r < bestResult) {
bestNeutron = i
bestResult = r
}
}
}
}

if (votes > 0 && expected) || (votes < 0 && !expected) {
// Neurons decision was correct, nothing to do
return
}

// Neurons failed, train the best of the worst
if bestNeutron >= 0 {
// log.Printf("Best fail is %v", bestNeutron)
Neurons[bestNeutron].Train(data, expected)
}
}

// Rate a word and return the voting result
func rate(w string) int {
data := NewData(w)

var votes int

for _, n := range Neurons {
r := n.Rate(data)

if r > 0 {
votes++ // Looks like a word
} else if r < 0 {
votes-- // Looks like garbage
}
}

return votes
}

neuron.go

package main

import (
"fmt"
"math/rand"
"strings"
)

const Charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 ."

var CharCode = make(map[int32]int)

// Consonants/Vowels map
var isVowel = map[int32]bool{
65: true, // A
66: false, // B
67: false, // C
68: false, // D
69: true, // E
70: false, // F
71: false, // G
72: false, // H
73: true, // I
74: false, // J
75: false, // K
76: false, // L
77: false, // M
78: false, // N
79: true, // O
80: false, // P
81: false, // Q
82: false, // R
83: false, // S
84: false, // T
85: true, // U
86: false, // V
87: false, // W
88: false, // X
89: false, // Y
90: false, // Z
}

/*
Attributes:
0 - 19 Length (one item per possible length)
20 - 20 Consonants/Vowels rate (consonants / length)
21 - 40 Char at this position is non-char/number
41 - 1484 2-Char combination count (charset: A-Z, 0-9, " ", ".") = 38²
1485 - 2244 Char position bits (A1 ... A20, B1 ... B20, ..., "."1 ... "."20)
*/
type NeuronType struct {
Weights [2244]float32 // Weight per attribute
MinWeight float32 // Minimum weight sum to return true
}

// Input data
type NeuronDataType struct {
Word string // as cleartext
Attrs [2244]float32 // as attributes
}

// Create new neuron
func NewNeuron() (n NeuronType) {
for i, _ := range n.Weights {
// Init each weight with a random value between -100 and +100
n.Weights[i] = (float32(rand.Intn(2010)) - 1000) / 10
}

return
}

func Init() {
// Convert charset for simple access
for i, c := range Charset {
CharCode[c] = i
}
}

// Parse input data
func NewData(w string) (d NeuronDataType) {
d.Word = strings.ToUpper(w)

wordlen := len(d.Word)
if wordlen > 20 {
panic(fmt.Sprintf("Input word \"%s\" is too long: %v chars", w, wordlen))
}

// Set attributes for length of word
for i := 0; i < 20; i++ {
if wordlen == (i + 1) {
d.Attrs[i] = 1
} else {
d.Attrs[i] = 0
}
}

// Counters for consonants and vowels
cons := 0
vowels := 0

// Walk though word
for i, c := range d.Word {
// Increase consonants/vowels counters
isV, ok := isVowel[c]
if isV {
vowels++
} else if ok {
cons++
} else if c < 48 || c > 57 { // Non number
d.Attrs[i+21] = 1
}

// Set char/position attribute
if _, ok = CharCode[c]; ok {
d.Attrs[1485+(CharCode[c]*20)+i] = 1
}

// Count char pairs
if i > 0 {
c1, ok1 := CharCode[int32(d.Word[i-1])]
c2, ok2 := CharCode[c]
if ok1 && ok2 {
d.Attrs[41+c1*len(Charset)+c2]++
}
}
}

// Calculate consonants/vowels ratio
if (cons + vowels) > 0 {
d.Attrs[20] = float32(cons) / float32(cons+vowels)
}

return
}

// Rate a word using the current weights and return the overall result
func (n *NeuronType) Rate(d NeuronDataType) (weightSum float32) {
weightSum = 0 // Init

for i, v := range d.Attrs {
weightSum += v * n.Weights[i]
}

return
}

// Train Neuron
func (n *NeuronType) Train(d NeuronDataType, isValid bool) {
for i, v := range d.Attrs {
r := v * n.Weights[i]
// Change weight on error
// false-negative: r is <0, -= r will increase weight
// false-positive: r is >0, -= r will decrease weight
if (r < 0 && isValid) || (r > 0 && !isValid) {
n.Weights[i] -= r / 10
}
}

return
}

Ergebnis

Nach dem Training arbeitet das neuronale Netz erstaunlich gut. Die Testdaten werden bis auf einen Ausreißer richtig eingeordnet:

Auto: 12
garbage: 10
as9Boo: -6
lae5uZee: -14
Be5Eirae: -4
igieng5I: 2
1.11.2011: -12

Lediglich beim vorletzten Beispiel stimmt eine kleine Mehrheit für ein echtes Wort. Davor wird sogar das englische Wort garbage richtig erkannt, obwohl das Netz nur mit deutschen Wörtern trainiert wurde.

Bei der Erkennung der Trainingsdaten liegt die KI fast immer richtig: 60.167 von 60.840 Wörtern werden als solche erkannt und 95.677 von 100.000 von pwgen erzeugte Zufallswerte als Müll. Das entspricht 1,1% False-Positives und 4,3% False-Negatives. Beim zweiten Wert ist allerdings nicht ganz auszuschließen, dass die generierten Daten auch tatsächlich echte Wörter enthalten.

Für rund einen Tag Arbeit bin ich mit dem Ergebnis zufrieden.

Ist das wirklich ein neuronales Netz?

Vermutlich nicht. Ein echtes neuronales Netz zu definieren kostet erfahrene Profis Wochen, manchmal Monate. Die eigentliche Trainingsphase läuft üblicherweise nochmal einige Wochen auf Grafikkarten, deren GPUs auf Vektormultiplikationen optimiert sind. Mein Programm läuft auf meinem fünf Jahre alten Laptop rund drei Minuten.

Das Projekt verfolgt die Ansätze eines neuronalen Netzes. Es lernt auf Basis existierender Daten und erkennt selbst, welche Faktoren am Ende relevant sind und welche nicht. Für mich war es eine schöne Denkaufgabe, unter Verwendung der Podcast - Informationen selbst eine Lösung zu entwickeln und eine nette Fingerübung in Go. Ich denke, dass die Grundsätze im Prinzip abgebildet sind.

Was haltet Ihr von meiner Lösung? Hinterlasst gerne einen Kommentar mit Eurer Meinung.

 

Noch keine Kommentare. Schreib was dazu

Schreib was dazu

Die folgenden HTML-Tags sind erlaubt:<a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>