-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLSTM_network.py
More file actions
42 lines (38 loc) · 1.48 KB
/
Copy pathLSTM_network.py
File metadata and controls
42 lines (38 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Long short-term memory (LSTM) neural network that learns from a text file of US city names and generates new ones.
from __future__ import absolute_import,division, print_function
import os
from six import moves
import ssl
import tflearn
from tflearn.data_utils import *
# Step 1 - retrieve the data.
# The dataset is downloaded once and reused on later runs.
path = "US_cities.txt"
if not os.path.isfile(path):
    # NOTE(review): certificate verification is disabled for this download,
    # so the server's identity is not checked. Prefer
    # ssl.create_default_context() wherever the local CA store allows it.
    context = ssl._create_unverified_context()
    # Get the dataset. urlretrieve() has no `context` parameter on Python 3
    # (passing one raises TypeError), whereas urlopen() accepts it on both
    # Python 2.7.9+ and Python 3.
    response = moves.urllib.request.urlopen(
        "https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/US_Cities.txt",
        context=context)
    try:
        payload = response.read()
    finally:
        response.close()
    # Write only after the whole body has been read, so a failed download
    # never leaves a partial file that the isfile() check above would accept.
    with open(path, "wb") as dataset_file:
        dataset_file.write(payload)
# Sequence length shared by the vectorizer, the input layer and the generator.
maxlen = 20

# Vectorize the text file. X and Y are the arrays later handed to fit();
# char_idx is the character dictionary used to size the network.
X, Y, char_idx = textfile_to_semi_redundant_sequences(
    path, seq_maxlen=maxlen, redun_step=3)

# Create the LSTM: two stacked 512-unit LSTM layers, each followed by a
# dropout layer, then a softmax layer with one output per dictionary entry.
vocab_size = len(char_idx)
g = tflearn.input_data(shape=[None, maxlen, vocab_size])
for emit_sequence in (True, False):
    # The first LSTM returns its full sequence so the second LSTM can
    # consume it; the second does not return a sequence.
    g = tflearn.lstm(g, 512, return_seq=emit_sequence)
    g = tflearn.dropout(g, 0.5)
g = tflearn.fully_connected(g, vocab_size, activation='softmax')
g = tflearn.regression(
    g,
    optimizer='adam',
    loss='categorical_crossentropy',
    learning_rate=0.001)
# Generate cities: wrap the network in a sequence generator, then alternate
# between one epoch of training and sampling at three temperatures.
m = tflearn.SequenceGenerator(
    g,
    dictionary=char_idx,
    seq_maxlen=maxlen,
    clip_gradients=5.0,
    checkpoint_path='model_us_cities')

# (banner, temperature) pairs, listed in the order they are sampled.
sampling_runs = (
    ("testing with temp 1.2", 1.2),
    ("testing with temp 1.0", 1.0),
    ("testing with temp 0.5", 0.5),
)

for _ in range(40):
    # A fresh seed sequence is drawn from the text file on every round and
    # reused for all three temperatures of that round.
    seed = random_sequence_from_textfile(path, maxlen)
    m.fit(X, Y, validation_set=0.1, batch_size=128, n_epoch=1,
          run_id='us_cities')
    for banner, temperature in sampling_runs:
        print(banner)
        print(m.generate(30, temperature=temperature, seq_seed=seed))