summary | refs | log | tree | commit | diff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-x train.py 44
1 files changed, 37 insertions, 7 deletions
diff --git a/train.py b/train.py
index ab973a1..7857f3f 100755
--- a/train.py
+++ b/train.py
@@ -5,16 +5,18 @@ import numpy
import sys
import importlib
+import theano
+from theano import tensor
+
from blocks.dump import load_parameter_values
from blocks.dump import MainLoopDumpManager
-from blocks.extensions import Printing
+from blocks.extensions import Printing, SimpleExtension
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
from blocks.extensions.plot import Plot
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
from blocks.algorithms import GradientDescent
-from theano import tensor
import datastream
# from apply_model import Apply
@@ -30,6 +32,29 @@ if __name__ == "__main__":
config = importlib.import_module('%s' % model_name)
+class GenText(SimpleExtension):
+ def __init__(self, model, init_text, max_bytes, **kwargs):
+ self.init_text = init_text
+ self.max_bytes = max_bytes
+
+ cg = ComputationGraph([model.pred])
+ assert(len(cg.inputs) == 1)
+ assert(cg.inputs[0].name == 'bytes')
+ self.f = theano.function(inputs=cg.inputs, outputs=[model.pred])
+
+ super(GenText, self).__init__(**kwargs)
+
+ def do(self, which_callback, *args):
+ v = numpy.array([ord(i) for i in self.init_text],
+ dtype='int16')[None, :].repeat(axis=0, repeats=config.num_seqs)
+
+ while v.shape[1] < self.max_bytes:
+ pred, = self.f(v)
+ v = numpy.concatenate([v, pred[:, -1:]], axis=1)
+
+ for i in range(v.shape[0]):
+ print "Sample:", ''.join([chr(int(v[i, j])) for j in range(v.shape[1])])
+
def train_model(m, train_stream, load_location=None, save_location=None):
# Define the model
@@ -44,6 +69,9 @@ def train_model(m, train_stream, load_location=None, save_location=None):
algorithm = GradientDescent(cost=m.cost_reg,
step_rule=config.step_rule,
params=cg.parameters)
+
+ algorithm.add_updates(m.updates)
+
main_loop = MainLoop(
model=model,
data_stream=train_stream,
@@ -51,12 +79,13 @@ def train_model(m, train_stream, load_location=None, save_location=None):
extensions=[
TrainingDataMonitoring(
[m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
- prefix='train', every_n_epochs=1*config.pt_freq),
- Printing(every_n_epochs=1*config.pt_freq, after_epoch=False),
+ prefix='train', every_n_epochs=1),
+ Printing(every_n_epochs=1, after_epoch=False),
Plot(document='tr_'+model_name+'_'+config.param_desc,
channels=[['train_cost', 'train_cost_reg'],
['train_error_rate', 'train_error_rate_reg']],
- every_n_epochs=1*config.pt_freq, after_epoch=False)
+ every_n_epochs=1, after_epoch=False),
+ GenText(m, '\t', 20, every_n_epochs=1, after_epoch=False)
]
)
main_loop.run()
@@ -72,8 +101,9 @@ def train_model(m, train_stream, load_location=None, save_location=None):
if __name__ == "__main__":
# Build datastream
train_stream = datastream.setup_datastream('data/logcompil.txt',
- config.chars_per_seq,
- config.seqs_per_epoch)
+ config.num_seqs,
+ config.seq_len,
+ config.seq_div_size)
# Build model
m = config.Model()