summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-06-10 15:22:18 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-06-10 15:22:18 -0400
commit9be0db7523abdfa59c19115585f1ee96d73d08c6 (patch)
treecabc1b8db78ad56fde1fa6a3bc2760a2883d4d5c /train.py
parentc5e1cd9c8c896096ad1630909a655b06eb398abb (diff)
downloadtext-rnn-9be0db7523abdfa59c19115585f1ee96d73d08c6.tar.gz
text-rnn-9be0db7523abdfa59c19115585f1ee96d73d08c6.zip
Changes
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py64
1 files changed, 39 insertions, 25 deletions
diff --git a/train.py b/train.py
index 7857f3f..a8e9ef2 100755
--- a/train.py
+++ b/train.py
@@ -5,14 +5,17 @@ import numpy
import sys
import importlib
+from contextlib import closing
+
import theano
from theano import tensor
+from theano.tensor.shared_randomstreams import RandomStreams
-from blocks.dump import load_parameter_values
-from blocks.dump import MainLoopDumpManager
+from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER
from blocks.extensions import Printing, SimpleExtension
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-from blocks.extensions.plot import Plot
+from blocks.extras.extensions.plot import Plot
+from blocks.extensions.saveload import Checkpoint, Load
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
@@ -37,10 +40,14 @@ class GenText(SimpleExtension):
self.init_text = init_text
self.max_bytes = max_bytes
- cg = ComputationGraph([model.pred])
+
+ out = model.out[:, -1, :] / numpy.float32(config.sample_temperature)
+ prob = tensor.nnet.softmax(out)
+
+ cg = ComputationGraph([prob])
assert(len(cg.inputs) == 1)
assert(cg.inputs[0].name == 'bytes')
- self.f = theano.function(inputs=cg.inputs, outputs=[model.pred])
+ self.f = theano.function(inputs=cg.inputs, outputs=[prob])
super(GenText, self).__init__(**kwargs)
@@ -49,22 +56,21 @@ class GenText(SimpleExtension):
dtype='int16')[None, :].repeat(axis=0, repeats=config.num_seqs)
while v.shape[1] < self.max_bytes:
- pred, = self.f(v)
- v = numpy.concatenate([v, pred[:, -1:]], axis=1)
+ prob, = self.f(v)
+ prob = prob / 1.00001
+ pred = numpy.zeros((prob.shape[0],), dtype='int16')
+ for i in range(prob.shape[0]):
+ pred[i] = numpy.random.multinomial(1, prob[i, :]).nonzero()[0][0]
+ v = numpy.concatenate([v, pred[:, None]], axis=1)
for i in range(v.shape[0]):
print "Sample:", ''.join([chr(int(v[i, j])) for j in range(v.shape[1])])
-def train_model(m, train_stream, load_location=None, save_location=None):
+def train_model(m, train_stream, dump_path=None):
# Define the model
model = Model(m.cost)
- # Load the parameters from a dumped model
- if load_location is not None:
- logger.info('Loading parameters...')
- model.set_param_values(load_parameter_values(load_location))
-
cg = ComputationGraph(m.cost_reg)
algorithm = GradientDescent(cost=m.cost_reg,
step_rule=config.step_rule,
@@ -72,11 +78,26 @@ def train_model(m, train_stream, load_location=None, save_location=None):
algorithm.add_updates(m.updates)
+ # Load the parameters from a dumped model
+ if dump_path is not None:
+ try:
+ logger.info('Loading parameters...')
+ with closing(numpy.load(dump_path)) as source:
+ param_values = {'/' + name.replace(BRICK_DELIMITER, '/'): source[name]
+ for name in source.keys()
+ if name != 'pkl' and not 'None' in name}
+ model.set_param_values(param_values)
+ except IOError:
+ pass
+
main_loop = MainLoop(
model=model,
data_stream=train_stream,
algorithm=algorithm,
extensions=[
+ Checkpoint(path=dump_path,
+ after_epoch=False, every_n_epochs=config.save_freq),
+
TrainingDataMonitoring(
[m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
prefix='train', every_n_epochs=1),
@@ -84,19 +105,14 @@ def train_model(m, train_stream, load_location=None, save_location=None):
Plot(document='tr_'+model_name+'_'+config.param_desc,
channels=[['train_cost', 'train_cost_reg'],
['train_error_rate', 'train_error_rate_reg']],
+ server_url='http://eos21:4201/',
every_n_epochs=1, after_epoch=False),
- GenText(m, '\t', 20, every_n_epochs=1, after_epoch=False)
+
+ GenText(m, ' ', config.sample_len, every_n_epochs=1, after_epoch=False)
]
)
main_loop.run()
- # Save the main loop
- if save_location is not None:
- logger.info('Saving the main loop...')
- dump_manager = MainLoopDumpManager(save_location)
- dump_manager.dump(main_loop)
- logger.info('Saved')
-
if __name__ == "__main__":
# Build datastream
@@ -114,8 +130,6 @@ if __name__ == "__main__":
m.pred.name = 'pred'
# Train the model
- saveloc = 'model_data/%s' % model_name
- train_model(m, train_stream,
- load_location=None,
- save_location=None)
+ saveloc = 'model_data/%s-%s' % (model_name, config.param_desc)
+ train_model(m, train_stream, dump_path=saveloc)