diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-05-18 16:22:00 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-05-18 16:22:00 -0400 |
commit | 6d946f29f7548c75e97f30c4356dbac200ee6cce (patch) | |
tree | 387e586c7ad0c1a0167d21451c9a8c877cf3ef0e /train.py | |
parent | 1e6d08b0c9ac5983691b182631c71e9d46ee71cc (diff) | |
download | taxi-6d946f29f7548c75e97f30c4356dbac200ee6cce.tar.gz taxi-6d946f29f7548c75e97f30c4356dbac200ee6cce.zip |
Refactor models, clean the code and separate training from testing.
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 174 |
1 files changed, 57 insertions, 117 deletions
@@ -1,32 +1,25 @@ #!/usr/bin/env python -import sys -import logging import importlib +import logging +import operator +import os +import sys +from functools import reduce -import csv - -from picklable_itertools.extras import equizip - -from blocks.model import Model - -from fuel.transformers import Batch -from fuel.streams import DataStream -from fuel.schemes import ConstantScheme, ShuffledExampleScheme - -from blocks.algorithms import CompositeRule, RemoveNotFinite, GradientDescent, AdaDelta, Momentum -from blocks.graph import ComputationGraph, apply_dropout -from blocks.main_loop import MainLoop +from blocks import roles +from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite from blocks.extensions import Printing, FinishAfter -from blocks.extensions.saveload import Dump, LoadFromDump, Checkpoint from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring from blocks.extensions.plot import Plot +from blocks.extensions.saveload import Dump, LoadFromDump +from blocks.filter import VariableFilter +from blocks.graph import ComputationGraph, apply_dropout, apply_noise +from blocks.main_loop import MainLoop +from blocks.model import Model -from theano import tensor -from data import transformers -from data.hdf5 import TaxiDataset, TaxiStream -import apply_model +logger = logging.getLogger(__name__) if __name__ == "__main__": if len(sys.argv) != 2: @@ -35,123 +28,70 @@ if __name__ == "__main__": model_name = sys.argv[1] config = importlib.import_module('.%s' % model_name, 'config') -def compile_valid_trip_ids(): - valid = TaxiDataset(config.valid_set, 'valid.hdf5', sources=('trip_id',)) - ids = valid.get_data(None, slice(0, valid.num_examples)) - return set(ids[0]) - -def setup_train_stream(req_vars, valid_trips_ids): - train = TaxiDataset('train') - train = DataStream(train, iteration_scheme=ShuffledExampleScheme(train.num_examples)) - - train = transformers.TaxiExcludeTrips(valid_trips_ids, train) - train = transformers.TaxiGenerateSplits(train, max_splits=100) - - train = transformers.TaxiAddDateTime(train) - train = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, train) - train = transformers.Select(train, tuple(req_vars)) - - train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size)) - - return train_stream - -def setup_valid_stream(req_vars): - valid = TaxiStream(config.valid_set, 'valid.hdf5') + logger.info('# Configuration: %s' % config.__name__) + for key in dir(config): + if not key.startswith('__') and isinstance(getattr(config, key), (int, str, list, tuple)): + logger.info(' %20s %s' % (key, str(getattr(config, key)))) - valid = transformers.TaxiAddDateTime(valid) - valid = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, valid) - valid = transformers.Select(valid, tuple(req_vars)) + model = config.Model(config) + model.initialize() - valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000)) - - return valid_stream - -def setup_test_stream(req_vars): - test = TaxiStream('test') - - test = transformers.TaxiAddDateTime(test) - test = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, test) - test = transformers.Select(test, tuple(req_vars)) - - test_stream = Batch(test, iteration_scheme=ConstantScheme(1000)) - - return test_stream - - -def main(): - model = config.model.Model(config) - - cost = model.cost - outputs = model.outputs - - req_vars = model.require_inputs + model.pred_vars - req_vars_test = model.require_inputs + [ 'trip_id' ] + stream = config.Stream(config) + inputs = stream.inputs() + req_vars = model.cost.inputs - valid_trips_ids = compile_valid_trip_ids() - train_stream = setup_train_stream(req_vars, valid_trips_ids) - valid_stream = setup_valid_stream(req_vars) + train_stream = stream.train(req_vars) + valid_stream = stream.valid(req_vars) - # Training + cost = model.cost(**inputs) cg = ComputationGraph(cost) + unmonitor = set() + if hasattr(config, 'dropout') and config.dropout < 1.0: + unmonitor.update(VariableFilter(roles=[roles.COST])(cg.variables)) + cg = apply_dropout(cg, config.dropout_inputs(cg), config.dropout) + if hasattr(config, 'noise') and config.noise > 0.0: + unmonitor.update(VariableFilter(roles=[roles.COST])(cg.variables)) + cg = apply_noise(cg, config.noise_inputs(cg), config.noise) + cost = cg.outputs[0] + cg = Model(cost) + + logger.info('# Parameter shapes:') + parameters_size = 0 + for key, value in cg.get_params().iteritems(): + logger.info(' %20s %s' % (value.get_value().shape, key)) + parameters_size += reduce(operator.mul, value.get_value().shape, 1) + logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.get_params()))) params = cg.parameters - algorithm = GradientDescent( cost=cost, step_rule=CompositeRule([ RemoveNotFinite(), - #AdaDelta(decay_rate=0.95), - Momentum(learning_rate=config.learning_rate, momentum=config.momentum), - ]), + AdaDelta(), + ]), params=params) - plot_vars = [['valid_' + x.name for x in model.monitor]] - print "Plot: ", plot_vars - - extensions=[TrainingDataMonitoring(model.monitor, prefix='train', every_n_batches=1000), - DataStreamMonitoring(model.monitor, valid_stream, + monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables)) - unmonitor + plot_vars = [['valid_' + x.name for x in monitored]] + logger.info('Plotted variables: %s' % str(plot_vars)) + + dump_path = os.path.join('model_data', model_name) + logger.info('Dump path: %s' % dump_path) + extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=1000), + DataStreamMonitoring(monitored, valid_stream, prefix='valid', - every_n_batches=500), - Printing(every_n_batches=500), + every_n_batches=1000), + Printing(every_n_batches=1000), Plot(model_name, channels=plot_vars, every_n_batches=500), - # Checkpoint('model.pkl', every_n_batches=100), - Dump('model_data/' + model_name, every_n_batches=500), - LoadFromDump('model_data/' + model_name), - # FinishAfter(after_epoch=4), + Dump(dump_path, every_n_batches=5000), + LoadFromDump(dump_path), + #FinishAfter(after_n_batches=2), ] main_loop = MainLoop( - model=Model([cost]), + model=cg, data_stream=train_stream, algorithm=algorithm, extensions=extensions) main_loop.run() main_loop.profile.report() - - # Produce an output on the test data - test_stream = setup_test_stream(req_vars_test) - - if 'destination_longitude' in model.pred_vars: - dest_outfile = open("output/test-dest-output-%s.csv" % model_name, "w") - dest_outcsv = csv.writer(dest_outfile) - dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) - if 'travel_time' in model.pred_vars: - time_outfile = open("output/test-time-output-%s.csv" % model_name, "w") - time_outcsv = csv.writer(time_outfile) - time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"]) - - for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): - outputs = out['outputs'] - for i, trip in enumerate(out['trip_id']): - if model.pred_vars == ['travel_time']: - time_outcsv.writerow([trip, int(outputs[i])]) - else: - dest_outcsv.writerow([trip, repr(outputs[i, 0]), repr(outputs[i, 1])]) - if 'travel_time' in model.pred_vars: - time_outcsv.writerow([trip, int(outputs[i, 2])]) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - main() - |