-rw-r--r--  .gitignore      3
-rw-r--r--  __init__.py     0
-rw-r--r--  datastream.py  76
-rw-r--r--  lstm.py        76
-rwxr-xr-x  train.py       91
5 files changed, 246 insertions(+), 0 deletions(-)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..cec3f1e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+*.pyc
+*.swp
+data/*
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/__init__.py
diff --git a/datastream.py b/datastream.py
new file mode 100644
index 0000000..5d9441f
--- /dev/null
+++ b/datastream.py
@@ -0,0 +1,76 @@
+import logging
+import random
+import numpy
+
+import cPickle
+
+from picklable_itertools import iter_
+
+from fuel.datasets import Dataset
+from fuel.streams import DataStream
+from fuel.schemes import IterationScheme
+from fuel.transformers import Transformer
+
+import sys
+import os
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+
+class BinaryFileDataset(Dataset):
+ def __init__(self, filename, **kwargs):
+        self.provides_sources = ('bytes',)
+
+ self.f = open(filename, "rb")
+
+ super(BinaryFileDataset, self).__init__(**kwargs)
+
+ def get_data(self, state=None, request=None):
+ if request is None:
+ raise ValueError("Expected a request: begin, length")
+
+ bg, ln = request
+ self.f.seek(bg)
+ return (self.f.read(ln),)
+
+ def num_examples(self):
+ return os.fstat(self.f.fileno()).st_size
+
+class RandomBlockIterator(IterationScheme):
+ def __init__(self, item_range, seq_len, num_seqs_per_epoch, **kwargs):
+ self.seq_len = seq_len
+ self.num_seqs = num_seqs_per_epoch
+ self.item_range = item_range
+
+ super(RandomBlockIterator, self).__init__(**kwargs)
+
+ def get_request_iterator(self):
+ l = [(random.randrange(0, self.item_range - self.seq_len + 1), self.seq_len)
+ for _ in xrange(self.num_seqs)]
+ return iter_(l)
+
+class BytesToIndices(Transformer):
+ def __init__(self, stream, **kwargs):
+ self.sources = ('bytes',)
+ super(BytesToIndices, self).__init__(stream, **kwargs)
+
+ def get_data(self, request=None):
+ if request is not None:
+ raise ValueError('Unsupported: request')
+ data = next(self.child_epoch_iterator)
+        return numpy.array([ord(i) for i in data[0]], dtype='int64'),  # int64 to match the 'bytes' lvector in lstm.py
+
+def setup_datastream(filename, seq_len, num_seqs_per_epoch=100):
+ ds = BinaryFileDataset(filename)
+ it = RandomBlockIterator(ds.num_examples(), seq_len, num_seqs_per_epoch)
+ stream = DataStream(ds, iteration_scheme=it)
+ stream = BytesToIndices(stream)
+
+ return stream
+
+if __name__ == "__main__":
+ # Test
+ stream = setup_datastream("data/logcompil.txt", 100)
+ print(next(stream.get_epoch_iterator()))
+
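
The two classes above implement Fuel's request-driven iteration: RandomBlockIterator emits (begin, length) pairs and BinaryFileDataset answers each one with a slice of the file. A minimal plain-Python sketch of that protocol, outside Fuel, assuming the same data/logcompil.txt used in the test above:

    import os
    import random

    def requests(file_size, seq_len, n):
        # what RandomBlockIterator.get_request_iterator yields
        for _ in range(n):
            yield (random.randrange(0, file_size - seq_len + 1), seq_len)

    with open("data/logcompil.txt", "rb") as f:
        size = os.fstat(f.fileno()).st_size        # as in num_examples()
        for begin, length in requests(size, 100, 3):
            f.seek(begin)
            print(repr(f.read(length)[:20]))       # what get_data() returns
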
diff --git a/lstm.py b/lstm.py
new file mode 100644
index 0000000..72b67f1
--- /dev/null
+++ b/lstm.py
@@ -0,0 +1,76 @@
+import theano
+from theano import tensor
+
+from blocks.algorithms import Momentum, AdaDelta
+from blocks.bricks import Tanh, Softmax, Linear, MLP
+from blocks.bricks.recurrent import LSTM
+from blocks.initialization import IsotropicGaussian, Constant
+
+from blocks.filter import VariableFilter
+from blocks.roles import WEIGHT
+from blocks.graph import ComputationGraph, apply_noise
+
+chars_per_seq = 100
+seqs_per_epoch = 1
+
+io_dim = 256
+
+hidden_dims = [200, 500]
+activation_function = Tanh()
+
+w_noise_std = 0.01
+
+step_rule = AdaDelta()
+
+pt_freq = 1
+
+param_desc = '' # todo
+
+class Model(object):
+ def __init__(self):
+ inp = tensor.lvector('bytes')
+
+ in_onehot = tensor.eq(tensor.arange(io_dim, dtype='int16').reshape((1, io_dim)),
+ inp[:, None])
+
+ dims = [io_dim] + hidden_dims
+        prev = in_onehot[:, None, :]  # (time, batch, features): recur over the whole sequence
+ bricks = []
+ for i in xrange(1, len(dims)):
+ linear = Linear(input_dim=dims[i-1], output_dim=4*dims[i],
+ name="lstm_in_%d"%i)
+ lstm = LSTM(dim=dims[i], activation=activation_function,
+ name="lstm_rec_%d"%i)
+ prev = lstm.apply(linear.apply(prev))[0]
+ bricks = bricks + [linear, lstm]
+
+        top_linear = MLP(dims=[hidden_dims[-1], io_dim],
+                         activations=[None],  # energies: softmax is applied by the cost below
+                         name="pred_mlp")
+ bricks.append(top_linear)
+
+ out = top_linear.apply(prev.reshape((inp.shape[0], hidden_dims[-1])))
+
+ pred = out.argmax(axis=1)
+
+        cost = Softmax().categorical_cross_entropy(inp[1:], out[:-1])  # each output predicts the next byte
+        error_rate = tensor.neq(inp[1:], pred[:-1]).mean()
+
+ # Initialize
+ for brick in bricks:
+ brick.weights_init = IsotropicGaussian(0.1)
+ brick.biases_init = Constant(0.)
+ brick.initialize()
+
+ # apply noise
+ cg = ComputationGraph([cost, error_rate])
+ noise_vars = VariableFilter(roles=[WEIGHT])(cg)
+ cg = apply_noise(cg, noise_vars, w_noise_std)
+ [cost_reg, error_rate_reg] = cg.outputs
+
+ self.cost = cost
+ self.error_rate = error_rate
+ self.cost_reg = cost_reg
+ self.error_rate_reg = error_rate_reg
+ self.pred = pred
+
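
Model only builds symbolic variables; pred has to be compiled before it can emit bytes. A sketch, assuming the class above; the input variable is recovered through ComputationGraph, since __init__ keeps no reference to it, and the seed string is just an illustrative placeholder:

    import numpy
    import theano
    from blocks.graph import ComputationGraph

    m = Model()
    inp_var, = ComputationGraph(m.pred).inputs    # the 'bytes' lvector
    predict = theano.function([inp_var], m.pred)

    seed = numpy.array([ord(c) for c in "gcc -O2 "], dtype='int64')
    out = predict(seed)                           # next-byte index per position
    print(''.join(chr(b) for b in out))
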
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..ab973a1
--- /dev/null
+++ b/train.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python
+
+import logging
+import numpy
+import sys
+import importlib
+
+from blocks.dump import load_parameter_values
+from blocks.dump import MainLoopDumpManager
+from blocks.extensions import Printing
+from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
+from blocks.extensions.plot import Plot
+from blocks.graph import ComputationGraph
+from blocks.main_loop import MainLoop
+from blocks.model import Model
+from blocks.algorithms import GradientDescent
+from theano import tensor
+
+import datastream
+# from apply_model import Apply
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+if __name__ == "__main__":
+ if len(sys.argv) != 2:
+ print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
+ sys.exit(1)
+ model_name = sys.argv[1]
+ config = importlib.import_module('%s' % model_name)
+
+
+def train_model(m, train_stream, load_location=None, save_location=None):
+
+ # Define the model
+ model = Model(m.cost)
+
+ # Load the parameters from a dumped model
+ if load_location is not None:
+ logger.info('Loading parameters...')
+ model.set_param_values(load_parameter_values(load_location))
+
+ cg = ComputationGraph(m.cost_reg)
+ algorithm = GradientDescent(cost=m.cost_reg,
+ step_rule=config.step_rule,
+ params=cg.parameters)
+ main_loop = MainLoop(
+ model=model,
+ data_stream=train_stream,
+ algorithm=algorithm,
+ extensions=[
+ TrainingDataMonitoring(
+ [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
+ prefix='train', every_n_epochs=1*config.pt_freq),
+ Printing(every_n_epochs=1*config.pt_freq, after_epoch=False),
+ Plot(document='tr_'+model_name+'_'+config.param_desc,
+ channels=[['train_cost', 'train_cost_reg'],
+ ['train_error_rate', 'train_error_rate_reg']],
+ every_n_epochs=1*config.pt_freq, after_epoch=False)
+ ]
+ )
+ main_loop.run()
+
+ # Save the main loop
+ if save_location is not None:
+ logger.info('Saving the main loop...')
+ dump_manager = MainLoopDumpManager(save_location)
+ dump_manager.dump(main_loop)
+ logger.info('Saved')
+
+
+if __name__ == "__main__":
+ # Build datastream
+ train_stream = datastream.setup_datastream('data/logcompil.txt',
+ config.chars_per_seq,
+ config.seqs_per_epoch)
+
+ # Build model
+ m = config.Model()
+ m.cost.name = 'cost'
+ m.cost_reg.name = 'cost_reg'
+ m.error_rate.name = 'error_rate'
+ m.error_rate_reg.name = 'error_rate_reg'
+ m.pred.name = 'pred'
+
+ # Train the model
+ saveloc = 'model_data/%s' % model_name
+    train_model(m, train_stream,
+                load_location=None,
+                save_location=saveloc)
+
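
train.py takes a config module name on the command line (./train.py lstm) and reads hyper-parameters from it. A sketch of the interface such a module must expose, with the attribute names taken from the references above (the module name myconfig is hypothetical):

    # hypothetical config module, e.g. saved as myconfig.py
    from blocks.algorithms import AdaDelta

    chars_per_seq = 100      # sequence length for setup_datastream
    seqs_per_epoch = 100     # sequences per epoch for setup_datastream
    step_rule = AdaDelta()   # optimizer handed to GradientDescent
    pt_freq = 1              # print/plot every pt_freq epochs
    param_desc = ''          # suffix of the Plot document name

    class Model(object):     # must expose cost, cost_reg, error_rate,
        pass                 # error_rate_reg and pred, as lstm.py does
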