-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathvisualize.py
More file actions
56 lines (39 loc) · 1.76 KB
/
visualize.py
File metadata and controls
56 lines (39 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import print_function
import os
import yaml
import argparse
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from data import DataLoader
from model import get_model
def scatter(x, labels, config):
    """Save a 2-D scatter plot of embedded points, coloured by class label.

    One colour per class is taken from seaborn's "hls" palette. Every class
    that has at least one point is annotated with its index, placed at the
    median of that class's points. The figure is written to
    ``os.path.join(config["paths"]["save"], "tsne.png")``.

    Args:
        x: 2-D array whose column 0 is plotted on the x axis and column 1 on
            the y axis (one row per sample).
        labels: Class label per row of ``x``. It is converted to integers and
            used to index the palette, so values must lie in
            ``[0, num_classes)``. Presumably a 1-D integer vector rather than
            one-hot; TODO confirm against ``DataLoader.get_random_batch``.
        config: Parsed config dict. Reads ``["data"]["num_classes"]``,
            ``["run-title"]`` and ``["paths"]["save"]``.
    """
    num_classes = config["data"]["num_classes"]
    palette = np.array(sns.color_palette("hls", num_classes))
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
    # documented replacement. Use the same integer labels for colouring and
    # for per-class selection so the two always agree.
    labels = np.asarray(labels).astype(int)

    # Switch to the non-interactive Agg backend; presumably so the script can
    # run without a display.
    plt.switch_backend('agg')
    fig, ax = plt.subplots()
    ax.scatter(x[:, 0], x[:, 1], lw=0, s=40, alpha=0.2, c=palette[labels])

    for idx in range(num_classes):
        members = x[labels == idx, :]
        # A class absent from this batch would yield a NaN median (plus a
        # RuntimeWarning) and a label at a non-finite position, so skip it.
        if len(members) == 0:
            continue
        xtext, ytext = np.median(members, axis=0)
        ax.text(xtext, ytext, str(idx), fontsize=20)

    ax.set_title("{} T-SNE".format(config["run-title"]))
    fig.savefig(os.path.join(config["paths"]["save"], "tsne.png"))
    # pyplot keeps every figure it created alive until it is closed.
    plt.close(fig)
def _make_tsne(perplexity, n_iter):
    """Build a 2-component, verbose ``TSNE`` that works across scikit-learn versions.

    scikit-learn 1.5 renamed the ``n_iter`` constructor argument to
    ``max_iter``, and 1.7 removed ``n_iter``. Older releases accept only
    ``n_iter``. The iteration budget is passed under whichever name the
    installed ``TSNE`` constructor declares.

    Args:
        perplexity: Forwarded to ``TSNE(perplexity=...)``.
        n_iter: Maximum number of optimisation iterations.

    Returns:
        An unfitted ``TSNE`` instance.
    """
    import inspect

    kwargs = dict(n_components=2, perplexity=perplexity, verbose=1)
    if "max_iter" in inspect.signature(TSNE).parameters:
        kwargs["max_iter"] = n_iter
    else:
        kwargs["n_iter"] = n_iter
    return TSNE(**kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Model Parameters')
    parser.add_argument('-c', '--config', type=str, default="config.yaml",
                        help='path of config file')
    args = parser.parse_args()

    # yaml.load() without an explicit Loader is a TypeError on PyYAML >= 6
    # and can build arbitrary Python objects on older versions. safe_load
    # handles plain mappings and scalars, which is all this script reads.
    with open(args.config, 'r') as file:
        config = yaml.safe_load(file)
    paths = config["paths"]
    data = config["data"]

    dataloader = DataLoader(config)
    dataloader.load()

    # top=False presumably drops the classifier head, so predict() returns
    # embeddings. Confirm in model.get_model.
    input_shape = (data["imsize"], data["imsize"], data["imchannel"])
    model = get_model(input_shape, config, top=False)
    model.load_weights(paths["load"], by_name=True)

    # NOTE(review): k=-1 presumably requests every sample rather than a random
    # subset; confirm in DataLoader.get_random_batch.
    X_batch, y_batch = dataloader.get_random_batch(k=-1)
    # To plot raw pixels instead of learned embeddings, use:
    #     embeddings = X_batch.reshape(len(X_batch), -1)
    embeddings = model.predict(X_batch,
                               batch_size=config["train"]["batch-size"],
                               verbose=1)

    tsne = _make_tsne(config["tsne"]["perplexity"], config["tsne"]["n_iter"])
    tsne_embeds = tsne.fit_transform(embeddings)
    scatter(tsne_embeds, y_batch, config)