diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-09 13:42:08 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-09 13:42:08 -0400 |
commit | 6a5ed2e43a5885eeb3c5e202ed5bb473f6065401 (patch) | |
tree | f5541c64fe0fc4eaf3567d716eb71c2dc5fb1ea8 | |
download | text-rnn-6a5ed2e43a5885eeb3c5e202ed5bb473f6065401.tar.gz text-rnn-6a5ed2e43a5885eeb3c5e202ed5bb473f6065401.zip |
First commit
-rw-r--r-- | .gitignore | 3 | ||||
-rw-r--r-- | __init__.py | 0 | ||||
-rw-r--r-- | datastream.py | 76 | ||||
-rw-r--r-- | lstm.py | 76 | ||||
-rwxr-xr-x | train.py | 91 |
5 files changed, 246 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cec3f1e --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +*.pyc +*.swp +data/* diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/__init__.py diff --git a/datastream.py b/datastream.py new file mode 100644 index 0000000..5d9441f --- /dev/null +++ b/datastream.py @@ -0,0 +1,76 @@ +import logging +import random +import numpy + +import cPickle + +from picklable_itertools import iter_ + +from fuel.datasets import Dataset +from fuel.streams import DataStream +from fuel.schemes import IterationScheme +from fuel.transformers import Transformer + +import sys +import os + +logging.basicConfig(level='INFO') +logger = logging.getLogger(__name__) + + +class BinaryFileDataset(Dataset): + def __init__(self, filename, **kwargs): + self.provides_sources= ('bytes',) + + self.f = open(filename, "rb") + + super(BinaryFileDataset, self).__init__(**kwargs) + + def get_data(self, state=None, request=None): + if request is None: + raise ValueError("Expected a request: begin, length") + + bg, ln = request + self.f.seek(bg) + return (self.f.read(ln),) + + def num_examples(self): + return os.fstat(self.f.fileno()).st_size + +class RandomBlockIterator(IterationScheme): + def __init__(self, item_range, seq_len, num_seqs_per_epoch, **kwargs): + self.seq_len = seq_len + self.num_seqs = num_seqs_per_epoch + self.item_range = item_range + + super(RandomBlockIterator, self).__init__(**kwargs) + + def get_request_iterator(self): + l = [(random.randrange(0, self.item_range - self.seq_len + 1), self.seq_len) + for _ in xrange(self.num_seqs)] + return iter_(l) + +class BytesToIndices(Transformer): + def __init__(self, stream, **kwargs): + self.sources = ('bytes',) + super(BytesToIndices, self).__init__(stream, **kwargs) + + def get_data(self, request=None): + if request is not None: + raise ValueError('Unsupported: request') + data = next(self.child_epoch_iterator) + return numpy.array([ord(i) for i in data[0]], dtype='int16'), + +def setup_datastream(filename, seq_len, num_seqs_per_epoch=100): + ds = BinaryFileDataset(filename) + it = RandomBlockIterator(ds.num_examples(), seq_len, num_seqs_per_epoch) + stream = DataStream(ds, iteration_scheme=it) + stream = BytesToIndices(stream) + + return stream + +if __name__ == "__main__": + # Test + stream = setup_datastream("data/logcompil.txt", 100) + print(next(stream.get_epoch_iterator())) + @@ -0,0 +1,76 @@ +import theano +from theano import tensor + +from blocks.algorithms import Momentum, AdaDelta +from blocks.bricks import Tanh, Softmax, Linear, MLP +from blocks.bricks.recurrent import LSTM +from blocks.initialization import IsotropicGaussian, Constant + +from blocks.filter import VariableFilter +from blocks.roles import WEIGHT +from blocks.graph import ComputationGraph, apply_noise + +chars_per_seq = 100 +seqs_per_epoch = 1 + +io_dim = 256 + +hidden_dims = [200, 500] +activation_function = Tanh() + +w_noise_std = 0.01 + +step_rule = AdaDelta() + +pt_freq = 1 + +param_desc = '' # todo + +class Model(): + def __init__(self): + inp = tensor.lvector('bytes') + + in_onehot = tensor.eq(tensor.arange(io_dim, dtype='int16').reshape((1, io_dim)), + inp[:, None]) + + dims = [io_dim] + hidden_dims + prev = in_onehot[None, :, :] + bricks = [] + for i in xrange(1, len(dims)): + linear = Linear(input_dim=dims[i-1], output_dim=4*dims[i], + name="lstm_in_%d"%i) + lstm = LSTM(dim=dims[i], activation=activation_function, + name="lstm_rec_%d"%i) + prev = lstm.apply(linear.apply(prev))[0] + bricks = bricks + [linear, lstm] + + top_linear = MLP(dims=[hidden_dims[-1], io_dim], + activations=[Softmax()], + name="pred_mlp") + bricks.append(top_linear) + + out = top_linear.apply(prev.reshape((inp.shape[0], hidden_dims[-1]))) + + pred = out.argmax(axis=1) + + cost = Softmax().categorical_cross_entropy(inp[:-1], out[1:]) + error_rate = tensor.neq(inp[:-1], pred[1:]).mean() + + # Initialize + for brick in bricks: + brick.weights_init = IsotropicGaussian(0.1) + brick.biases_init = Constant(0.) + brick.initialize() + + # apply noise + cg = ComputationGraph([cost, error_rate]) + noise_vars = VariableFilter(roles=[WEIGHT])(cg) + cg = apply_noise(cg, noise_vars, w_noise_std) + [cost_reg, error_rate_reg] = cg.outputs + + self.cost = cost + self.error_rate = error_rate + self.cost_reg = cost_reg + self.error_rate_reg = error_rate_reg + self.pred = pred + diff --git a/train.py b/train.py new file mode 100755 index 0000000..ab973a1 --- /dev/null +++ b/train.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python + +import logging +import numpy +import sys +import importlib + +from blocks.dump import load_parameter_values +from blocks.dump import MainLoopDumpManager +from blocks.extensions import Printing +from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring +from blocks.extensions.plot import Plot +from blocks.graph import ComputationGraph +from blocks.main_loop import MainLoop +from blocks.model import Model +from blocks.algorithms import GradientDescent +from theano import tensor + +import datastream +# from apply_model import Apply + +logging.basicConfig(level='INFO') +logger = logging.getLogger(__name__) + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, 'Usage: %s config' % sys.argv[0] + sys.exit(1) + model_name = sys.argv[1] + config = importlib.import_module('%s' % model_name) + + +def train_model(m, train_stream, load_location=None, save_location=None): + + # Define the model + model = Model(m.cost) + + # Load the parameters from a dumped model + if load_location is not None: + logger.info('Loading parameters...') + model.set_param_values(load_parameter_values(load_location)) + + cg = ComputationGraph(m.cost_reg) + algorithm = GradientDescent(cost=m.cost_reg, + step_rule=config.step_rule, + params=cg.parameters) + main_loop = MainLoop( + model=model, + data_stream=train_stream, + algorithm=algorithm, + extensions=[ + TrainingDataMonitoring( + [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate], + prefix='train', every_n_epochs=1*config.pt_freq), + Printing(every_n_epochs=1*config.pt_freq, after_epoch=False), + Plot(document='tr_'+model_name+'_'+config.param_desc, + channels=[['train_cost', 'train_cost_reg'], + ['train_error_rate', 'train_error_rate_reg']], + every_n_epochs=1*config.pt_freq, after_epoch=False) + ] + ) + main_loop.run() + + # Save the main loop + if save_location is not None: + logger.info('Saving the main loop...') + dump_manager = MainLoopDumpManager(save_location) + dump_manager.dump(main_loop) + logger.info('Saved') + + +if __name__ == "__main__": + # Build datastream + train_stream = datastream.setup_datastream('data/logcompil.txt', + config.chars_per_seq, + config.seqs_per_epoch) + + # Build model + m = config.Model() + m.cost.name = 'cost' + m.cost_reg.name = 'cost_reg' + m.error_rate.name = 'error_rate' + m.error_rate_reg.name = 'error_rate_reg' + m.pred.name = 'pred' + + # Train the model + saveloc = 'model_data/%s' % model_name + train_model(m, train_stream, + load_location=None, + save_location=None) + |