This repository was archived by the owner on Dec 24, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsom.py
More file actions
115 lines (88 loc) · 3.6 KB
/
som.py
File metadata and controls
115 lines (88 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from random import choice as random_choice
from argparse import ArgumentParser
import numpy as np
import tensorflow as tf
import cv2
from clustering.som import SelfOrganizingMap
def _read_tfrecord_image(path, image_shape):
    """Build a TF1 queue-based op that yields one uint8 image per ``sess.run``.

    Reads GZIP-compressed TFRecord examples from ``path``. Each example must
    carry a ``raw`` bytes feature holding exactly ``prod(image_shape)`` uint8
    values. Because ``num_epochs`` is not set, the filename queue cycles
    through the file indefinitely.

    Must be called while the target graph is the default graph. The returned
    tensor only produces data once queue runners have been started.
    """
    # TODO: glob the actual TFRecord files
    queue = tf.train.string_input_producer([path])
    options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
    reader = tf.TFRecordReader(options=options)
    _, example = reader.read(queue)
    features = tf.parse_single_example(
        example, features={'raw': tf.FixedLenFeature([], tf.string)})
    pixels = tf.decode_raw(features['raw'], tf.uint8)
    return tf.reshape(pixels, image_shape)


def _tile_grid(grid, height, width, y_stride=2, x_stride=2, tile_shape=(32, 32, 3)):
    """Assemble every (y_stride, x_stride)-th SOM cell into one mosaic image.

    Args:
        grid: Indexable as ``grid[y, x]``; each cell must hold
            ``prod(tile_shape)`` values. These are presumably the node's
            weight vector (TODO confirm against SelfOrganizingMap.get_grid).
        height, width: Number of SOM rows / columns to walk over.
        y_stride, x_stride: Step between displayed rows / columns.
        tile_shape: ``(rows, cols, channels)`` that each cell is reshaped into.

    Returns:
        A float ndarray of shape
        ``(ceil(height / y_stride) * rows, ceil(width / x_stride) * cols, channels)``.
        Tiles are laid out contiguously for any stride, not only strides
        that divide the tile size.
    """
    tile_h, tile_w, channels = tile_shape
    ys = range(0, height, y_stride)
    xs = range(0, width, x_stride)
    # Every pixel is overwritten by exactly one tile below, so np.empty is safe.
    mosaic = np.empty((len(ys) * tile_h, len(xs) * tile_w, channels))
    for row, y in enumerate(ys):
        top = row * tile_h
        for col, x in enumerate(xs):
            left = col * tile_w
            mosaic[top:top + tile_h, left:left + tile_w, :] = np.reshape(grid[y, x], tile_shape)
    return mosaic


def main():
    """Train a self-organizing map on 32x32x3 images and visualize it live.

    Takes one positional command-line argument, ``path``: a GZIP-compressed
    TFRecord file (format described in ``_read_tfrecord_image``).

    Feeds one image per step, scaled to [0, 1], to a ``SelfOrganizingMap``
    until the value returned by ``som.run`` reaches ``total_iterations``.
    Every ``report_every`` steps it redraws an OpenCV window with a subsampled
    mosaic of the SOM grid. Pressing ESC in that window cancels training
    early. Otherwise the final mosaic stays on screen until any key is pressed.
    """
    parser = ArgumentParser()
    parser.add_argument('file', metavar='path', help='The TFRecord file to load.')
    args = parser.parse_args()

    random_seed = None  # None -> NumPy seeds from OS entropy (non-reproducible)
    width = 128
    height = 128
    tile_shape = (32, 32, 3)
    # Length of each flattened training vector (32*32*3 = 3072); also the SOM input size.
    dimensions = int(np.prod(tile_shape))
    total_iterations = 1000
    report_every = 10
    y_stride = 2  # draw only every y_stride-th SOM row ...
    x_stride = 2  # ... and every x_stride-th column, to keep the window manageable
    window_title = 'Grid'

    np.random.seed(random_seed)

    with tf.Graph().as_default() as graph:
        with tf.variable_scope('input'):
            raw_image = _read_tfrecord_image(args.file, tile_shape)

        with tf.variable_scope('som'):
            som = SelfOrganizingMap(height, width, dimensions, total_iterations,
                                    random_seed=random_seed)

        coord = tf.train.Coordinator()

        with tf.Session(graph=graph) as sess:
            # initialize all variables
            sess.run(tf.group(tf.global_variables_initializer(),
                              tf.local_variables_initializer()))

            cv2.namedWindow(window_title, cv2.WINDOW_NORMAL)

            # fire up the input queue runners
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            canceled = False
            try:
                while not coord.should_stop():
                    # fetch one example, flatten it and scale uint8 -> [0, 1]
                    target = sess.run(raw_image)
                    target = np.reshape(target, newshape=(dimensions,)) / 255.

                    # run one iteration; som.run is assumed to return the
                    # updated iteration counter (TODO confirm in clustering.som)
                    iteration = som.run(target, sess)

                    if iteration % report_every == 0:
                        print('Displaying iteration %i ...' % iteration)
                        grid = som.get_grid(sess)
                        # NOTE(review): OpenCV treats the last axis as BGR; if the
                        # TFRecords hold RGB, the colors shown here are swapped.
                        cv2.imshow(window_title,
                                   _tile_grid(grid, height, width, y_stride, x_stride, tile_shape))
                        # waitKey also pumps the GUI event loop; 27 is the ESC key code.
                        if (cv2.waitKey(10) & 0xff) == 27:
                            canceled = True
                            break

                    if iteration >= total_iterations:
                        # request_stop (not should_stop, which only queries) ends the loop.
                        coord.request_stop()
            finally:
                # Always stop and reap the reader threads, even on error or Ctrl-C.
                coord.request_stop()
                coord.join(threads)

            if not canceled:
                print('Done.')
                cv2.waitKey(0)

    cv2.destroyAllWindows()
# Run the training/visualization script only when executed directly, not on import.
if __name__ == "__main__":
    main()