-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy path train_network_dense.py
More file actions
31 lines (26 loc) · 1.31 KB
/
train_network_dense.py
File metadata and controls
31 lines (26 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Train a fully connected MNIST classifier using the hyperparameters in
# ConfigNetworkDense.
#
# Keep the statement order below: both RNG seeds are set before the MNIST
# data sets are built and before the project modules are imported, so any
# randomness used by those steps is presumably covered by the seeds.
import tensorflow as tf
tf.set_random_seed(123)  # TensorFlow graph-level seed
import numpy as np
np.random.seed(123)  # global NumPy seed

from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from the "MNIST_data/" directory (presumably downloading it on
# first use) and expose the three splits as separate data providers.
mnist = input_data.read_data_sets("MNIST_data/")
train_data_provider, validation_data_provider, test_data_provider = (
    mnist.train,
    mnist.validation,
    mnist.test,
)

from networks import network_dense
from configs import ConfigNetworkDense as config

# Build the classifier; every constructor argument is read from the config.
classifier = network_dense.FullyConnectedClassifier(
    input_size=config.input_size,
    n_classes=config.n_classes,
    layer_sizes=config.layer_sizes,
    model_path=config.model_path,
    dropout=config.dropout,
    weight_decay=config.weight_decay,
    activation_fn=config.activation_fn,
)

# Then train it, handing fit() all three data splits along with the
# epoch count, batch size and learning-rate schedule from the config.
classifier.fit(
    n_epochs=config.n_epochs,
    batch_size=config.batch_size,
    learning_rate_schedule=config.learning_rate_schedule,
    train_data_provider=train_data_provider,
    validation_data_provider=validation_data_provider,
    test_data_provider=test_data_provider,
)