-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtrain_process.go
More file actions
103 lines (88 loc) · 2.37 KB
/
train_process.go
File metadata and controls
103 lines (88 loc) · 2.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package cnns
import (
"fmt"
"log"
"math/rand"
"time"
"gonum.org/v1/gonum/mat"
)
// Train trains the network with stochastic gradient descent: it shuffles the
// training set, runs FeedForward+Backpropagate over every sample for epochsNum
// epochs, then evaluates the accumulated squared-error loss on both the
// training set and the test set.
/*
inputs - input data for training
desired - target outputs for input
testData - input data for doing tests
testDesired - target outputs for testing
epochsNum - number of epochs
*/
// Returns (trainError, testError, err). The error totals are sums of
// per-sample squared errors (see mse), not averages.
func (n *WholeNet) Train(inputs []*mat.Dense, desired []*mat.Dense, testData []*mat.Dense, testDesired []*mat.Dense, epochsNum int) (float64, float64, error) {
	trainError := 0.0
	testError := 0.0
	if len(inputs) != len(desired) {
		return trainError, testError, fmt.Errorf("number of inputs not equal to number of desired")
	}
	if len(testData) != len(testDesired) {
		return trainError, testError, fmt.Errorf("number of inputs for test not equal to number of desired for test")
	}
	// Fisher-Yates shuffle that keeps inputs and desired aligned pairwise.
	shuffle := func() {
		for i := range inputs {
			j := rand.Intn(i + 1)
			inputs[i], inputs[j] = inputs[j], inputs[i]
			desired[i], desired[j] = desired[j], desired[i]
		}
	}
	// Initial shuffling of input data
	shuffle()
	start := time.Now()
	for e := 0; e < epochsNum; e++ {
		// Shuffle training data every epoch
		shuffle()
		st := time.Now()
		for i := range inputs {
			if err := n.FeedForward(inputs[i]); err != nil {
				log.Printf("Feedforward caused error: %s", err.Error())
				return 0.0, 0.0, err
			}
			if err := n.Backpropagate(desired[i]); err != nil {
				log.Printf("Backpropagate caused error: %s", err.Error())
				return 0.0, 0.0, err
			}
		}
		log.Printf("Epoch #%v done in %v", e, time.Since(st))
	}
	log.Printf("Training %v epochs done in %v", epochsNum, time.Since(start))
	fmt.Println("Evaluating errors...")
	// Accumulate loss over the (shuffled) training set.
	for i := range inputs {
		if err := n.FeedForward(inputs[i]); err != nil {
			log.Printf("Feedforward (testing) caused error: %s", err.Error())
			return 0.0, 0.0, err
		}
		trainError += mse(desired[i], n.GetOutput())
	}
	// Accumulate loss over the test set. The FeedForward error was silently
	// ignored here in the original code; check it like every other call.
	for i := range testData {
		if err := n.FeedForward(testData[i]); err != nil {
			log.Printf("Feedforward (testing) caused error: %s", err.Error())
			return 0.0, 0.0, err
		}
		testError += mse(testDesired[i], n.GetOutput())
	}
	// Return literal nil: the original returned an outer `err` variable that
	// was never assigned (inner errors were scoped to the loops above).
	return trainError, testError, nil
}
// mse returns the sum of squared element-wise differences between t1 and t2.
// NOTE(review): despite the name this is a *sum* of squared errors, not a
// mean — callers (Train) accumulate it as-is, so the value must not change.
func mse(t1, t2 *mat.Dense) float64 {
	var diff mat.Dense
	diff.Sub(t1, t2)
	var squared mat.Dense
	squared.MulElem(&diff, &diff)
	return mat.Sum(&squared)
}