-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path mcintpdiff_wgan.py
More file actions
97 lines (82 loc) · 2.97 KB
/
mcintpdiff_wgan.py
File metadata and controls
97 lines (82 loc) · 2.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import division

import os

import numpy as np
import scipy.io as sio
#import matplotlib.pyplot as plt
#from mpl_toolkits.mplot3d import Axes3D

import wgan_gp as wgp
#import test_functions as te_funs
# load fem_dof from matlab data generated by IFISS
# All input .mat files live in one directory; switch DATA_DIR to change machines.
DATA_DIR = '/home/tkj/rap_prog/pyfile/t3f_UQ/data'
#DATA_DIR = '/root/tkj/data'  # alternative data location
# Number of leading samples (rows) used to train the WGAN.
N_TRAIN = 4000

xi_path = os.path.join(DATA_DIR, 'pdiffxi_2_4_289_20000.mat')
U_path = os.path.join(DATA_DIR, 'pdiff_2_4_289_20000.mat')
mean_path = os.path.join(DATA_DIR, 'pdiff_2_4_289_mean.mat')
var_path = os.path.join(DATA_DIR, 'pdiff_2_4_289_var.mat')

# train data
# NOTE(review): 'XI' is presumably the random-input samples paired row-by-row
# with the solution snapshots in 'U' -- confirm against the IFISS export.
x = sio.loadmat(xi_path)['XI']
# Transposed so that each row is a sample for the WGAN model; in the original
# MATLAB data each column is a sample.
data = sio.loadmat(U_path)['U'].T
# Keep only the first N_TRAIN rows of both arrays so inputs and outputs stay paired.
x = x[:N_TRAIN]
data = data[:N_TRAIN]

# groundtruth data: reference mean ('um') and variance ('uv') fields, transposed
# like `data`; used to score statistics of the generated samples.
x_mean = sio.loadmat(mean_path)['um'].T
x_var = sio.loadmat(var_path)['uv'].T
# WGAN-GP training configuration. These values are handed to wgp.Gan
# positionally below, so their order there must match its signature
# (defined in wgan_gp, which is not visible here).
# NOTE(review): lam is presumably the gradient-penalty weight and
# lr_rateg / lr_rated the generator / discriminator learning rates -- confirm.
lam, num_hidden = 10, 512
batch_size, num_epochs = 400, 20000
lr_rateg = lr_rated = 1e-4
lr_decay = 1.0
to_restore = False
output_path, net_type = 'ForwardUQ', 'conv'

# Build the model on the training pairs (x, data) and train it.
model = wgp.Gan(
    x, data,
    lam, num_hidden, batch_size, num_epochs,
    lr_rateg, lr_rated, lr_decay,
    to_restore, output_path, net_type,
)
model.train()

# Persist the model's discriminator-loss history.
disc_loss = model.disc_loss_history
np.save('disc_loss.npy', disc_loss)

# Diagnostics read back from the working directory.
# TODO(review): these files are presumably written by model.train() -- confirm in wgan_gp.
singular_values, weightmat_ranks = (
    np.load(fname) for fname in ('singular_values.npy', 'weightmat_ranks.npy')
)
#plt.figure(101)
#plt.subplot(511)
#plt.semilogy(singular_values[0])
#plt.subplot(512)
#plt.semilogy(singular_values[1])
#plt.subplot(513)
#plt.semilogy(singular_values[2])
#plt.subplot(514)
#plt.semilogy(singular_values[3])
#plt.subplot(515)
#plt.semilogy(singular_values[4])
#plt.figure(201)
#plt.plot(disc_loss)
def _relative_l2_error(estimate, reference):
    """Return the relative error ||estimate - reference|| / ||reference||.

    Both norms use ``np.linalg.norm`` with its default ``ord``, i.e. the
    2-norm of the flattened array for 1-D or 2-D input. ``estimate`` and
    ``reference`` are broadcast against each other before subtracting.
    """
    return np.linalg.norm(estimate - reference) / np.linalg.norm(reference)


# Compare the sample mean / (population, ddof=0) variance of GAN-generated
# samples with the reference statistics, for increasing numbers of batches.
batch_times = [20, 40, 60, 80, 100, 120, 140, 160]
errorm_list = []
errorv_list = []
#terrorm_list = []
#terrorv_list = []
for n_batches in batch_times:
    mcf_samples = model.generate_sample(n_batches)
    # Statistics are taken over axis 0, i.e. across the generated samples (rows).
    errorm_mcf = _relative_l2_error(np.mean(mcf_samples, 0), x_mean)
    errorv_mcf = _relative_l2_error(np.var(mcf_samples, 0), x_var)
    errorm_list.append(errorm_mcf)
    errorv_list.append(errorv_mcf)
np.save('mean_error_f.npy', errorm_list)
np.save('var_error_f.npy', errorv_list)
# Only the samples from the largest (last) run are kept on disk.
np.save('mcf_samples.npy', mcf_samples)
#nodes = np.linspace(-1, 1, 17)
#Y_nodes, X_nodes = np.meshgrid(nodes, nodes) # FEM mesh structure
#for i in range(5):
# f_value = np.reshape(mcf_samples[i,:], (17,17), order='F')
# fig = plt.figure(i+1)
# ax = Axes3D(fig)
# ax.plot_surface(X_nodes, Y_nodes, f_value, rstride=1, cstride=1, cmap=plt.cm.coolwarm)
# ax.set_xlabel('x label', color='r')
# ax.set_ylabel('y label', color='g')
# ax.set_zlabel('z label', color='b')
#plt.figure(6)
#plt.plot(np.array(batch_times)*batch_size, errorm_list, '-ro')
#plt.figure(7)
#plt.plot(np.array(batch_times)*batch_size, errorv_list,'-ro')
#plt.show()