-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathutils.py
More file actions
67 lines (52 loc) · 1.31 KB
/
utils.py
File metadata and controls
67 lines (52 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import logging
import pickle

import numpy as np
import torch
import torch.distributions as dist
def save(d, filename):
    """Pickle ``d`` to ``filename``, overwriting any existing file.

    Parameters
    ----------
    d : object
        Any picklable object.
    filename : str or os.PathLike
        Destination path; opened in binary write mode.
    """
    # Use a context manager so the handle is flushed and closed immediately,
    # instead of relying on garbage collection to close it.
    with open(filename, 'wb') as f:
        pickle.dump(d, f)
def get_logger(name):
    """Return the logger named ``name`` after ensuring basic INFO-level logging.

    ``logging.basicConfig`` does nothing if the root logger already has
    handlers, so calling this repeatedly is harmless.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(name)
    return logger
def split_dataset(dataset, ratio=0.8):
    """Split a dataset into a train set and a test set without shuffling.

    The first ``int(len(inputs) * ratio)`` points go to the train set and the
    remainder to the test set, preserving order.

    Parameters
    ----------
    dataset : tuple
        ``(inputs, outputs)`` pair of sliceable sequences (lists or arrays)
        of matching length.
    ratio : float
        Fraction of points assigned to the train set, in ``[0, 1]``.

    Returns
    -------
    tuple
        ``((train_inputs, train_outputs), (test_inputs, test_outputs))``.

    Raises
    ------
    ValueError
        If ``ratio`` is outside ``[0, 1]``.
    """
    # A ratio outside [0, 1] would yield a negative or oversized cut point,
    # which slicing silently turns into a meaningless split.
    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be in [0, 1], got {ratio!r}")
    inputs, outputs = dataset
    n = len(inputs)
    m = int(n * ratio)
    return (inputs[:m], outputs[:m]), (inputs[m:], outputs[m:])
def shuffle(dataset):
    """Return a jointly shuffled copy of a dataset as float32 arrays.

    Inputs and outputs are permuted with the same random order, so each
    input stays paired with its output. Uses NumPy's global RNG.

    Parameters
    ----------
    dataset : tuple
        ``(inputs, outputs)`` pair of array-like sequences.

    Returns
    -------
    tuple
        ``(inputs, outputs)`` as float32 ``numpy.ndarray`` objects in shuffled order.
    """
    xs, ys = dataset
    order = np.arange(len(xs))
    np.random.shuffle(order)
    xs_arr = np.array(xs, dtype='float32')
    ys_arr = np.array(ys, dtype='float32')
    return xs_arr[order], ys_arr[order]
def bell(mu, sigma, x):
    """Evaluate a Gaussian bell curve at ``x``, rescaled so its largest entry is 1.

    Parameters
    ----------
    mu : float or torch.Tensor
        Mean of the normal distribution.
    sigma : float or torch.Tensor
        Standard deviation of the normal distribution (must be positive).
    x : torch.Tensor or array-like
        Points at which to evaluate the curve. Tensors are used as-is;
        other array-likes are converted with ``torch.as_tensor``.

    Returns
    -------
    numpy.ndarray
        float32 array equal to ``pdf(x) / max(pdf(x))``.

    Raises
    ------
    ValueError
        If ``x`` is empty (there is no maximum to normalise by).
    """
    norm = dist.Normal(mu, sigma)
    log_pdf = (
        norm.log_prob(torch.as_tensor(x))
        .float()
        .detach()  # allow mu/sigma/x that require grad
        .cpu()     # .numpy() needs a host tensor
        .numpy()
    )
    # Normalise in log space. This is mathematically the same as pdf / pdf.max(),
    # but exponentiating the raw float32 pdf underflows to 0 when every x is
    # far from mu, which turned the old division into 0/0 = nan.
    return np.exp(log_pdf - log_pdf.max())