author     Étienne Simon <esimon@esimon.eu>    2015-05-18 16:22:00 -0400
committer  Étienne Simon <esimon@esimon.eu>    2015-05-18 16:22:00 -0400
commit     6d946f29f7548c75e97f30c4356dbac200ee6cce (patch)
tree       387e586c7ad0c1a0167d21451c9a8c877cf3ef0e /train.py
parent     1e6d08b0c9ac5983691b182631c71e9d46ee71cc (diff)
Refactor models, clean the code and separate training from testing.
Diffstat (limited to 'train.py')
-rwxr-xr-x  train.py | 174
1 file changed, 57 insertions(+), 117 deletions(-)
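
As the diff below shows, train.py now delegates both model construction and the data pipeline to the chosen config module: it instantiates config.Model(config) and config.Stream(config), reads the required source names from model.cost.inputs, and feeds stream.inputs() into the cost. The skeleton below is a hypothetical sketch of that contract, not code from this commit; the 'x'/'y' sources and the squared-error cost are illustrative, and a Blocks @application is assumed for cost because train.py reads model.cost.inputs before calling model.cost(**inputs).

    # config/example.py -- hypothetical; only the names train.py touches
    # (Model, initialize, cost, cost.inputs, Stream, inputs, train, valid)
    # are part of the inferred interface.
    from theano import tensor
    from blocks.bricks import Initializable
    from blocks.bricks.base import application
    from blocks.roles import add_role, COST

    class Model(Initializable):
        def __init__(self, config, **kwargs):
            super(Model, self).__init__(**kwargs)
            self.config = config

        # A Blocks application: model.cost.inputs yields ['x', 'y'], the
        # source names train.py requests from the stream as req_vars.
        @application(inputs=['x', 'y'], outputs=['cost'])
        def cost(self, x, y):
            mse = tensor.sqr(x - y).sum(axis=1).mean()
            mse.name = 'cost'
            add_role(mse, COST)  # assumed tagging, so COST-role filters see it
            return mse

    class Stream(object):
        def __init__(self, config):
            self.config = config

        def inputs(self):
            # One Theano variable per source in Model.cost.inputs;
            # train.py unpacks this dict into model.cost(**inputs).
            return {'x': tensor.matrix('x'), 'y': tensor.matrix('y')}

        def train(self, req_vars):
            raise NotImplementedError  # return a Fuel training stream here

        def valid(self, req_vars):
            raise NotImplementedError  # and a validation stream here

With such a module on the config package path, ./train.py example would log the configuration, train with AdaDelta plus RemoveNotFinite, and dump parameters under model_data/example every 5000 batches.
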
diff --git a/train.py b/train.py
index 96dd798..97d8bb3 100755
--- a/train.py
+++ b/train.py
@@ -1,32 +1,25 @@
 #!/usr/bin/env python
-import sys
-import logging
 import importlib
+import logging
+import operator
+import os
+import sys
+from functools import reduce
-import csv
-
-from picklable_itertools.extras import equizip
-
-from blocks.model import Model
-
-from fuel.transformers import Batch
-from fuel.streams import DataStream
-from fuel.schemes import ConstantScheme, ShuffledExampleScheme
-
-from blocks.algorithms import CompositeRule, RemoveNotFinite, GradientDescent, AdaDelta, Momentum
-from blocks.graph import ComputationGraph, apply_dropout
-from blocks.main_loop import MainLoop
+from blocks import roles
+from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite
 from blocks.extensions import Printing, FinishAfter
-from blocks.extensions.saveload import Dump, LoadFromDump, Checkpoint
 from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
 from blocks.extensions.plot import Plot
+from blocks.extensions.saveload import Dump, LoadFromDump
+from blocks.filter import VariableFilter
+from blocks.graph import ComputationGraph, apply_dropout, apply_noise
+from blocks.main_loop import MainLoop
+from blocks.model import Model
-from theano import tensor
-from data import transformers
-from data.hdf5 import TaxiDataset, TaxiStream
-import apply_model
+logger = logging.getLogger(__name__)
 if __name__ == "__main__":
     if len(sys.argv) != 2:
@@ -35,123 +28,70 @@ if __name__ == "__main__":
     model_name = sys.argv[1]
     config = importlib.import_module('.%s' % model_name, 'config')
-def compile_valid_trip_ids():
-    valid = TaxiDataset(config.valid_set, 'valid.hdf5', sources=('trip_id',))
-    ids = valid.get_data(None, slice(0, valid.num_examples))
-    return set(ids[0])
-
-def setup_train_stream(req_vars, valid_trips_ids):
-    train = TaxiDataset('train')
-    train = DataStream(train, iteration_scheme=ShuffledExampleScheme(train.num_examples))
-
-    train = transformers.TaxiExcludeTrips(valid_trips_ids, train)
-    train = transformers.TaxiGenerateSplits(train, max_splits=100)
-
-    train = transformers.TaxiAddDateTime(train)
-    train = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, train)
-    train = transformers.Select(train, tuple(req_vars))
-
-    train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
-
-    return train_stream
-
-def setup_valid_stream(req_vars):
-    valid = TaxiStream(config.valid_set, 'valid.hdf5')
+    logger.info('# Configuration: %s' % config.__name__)
+    for key in dir(config):
+        if not key.startswith('__') and isinstance(getattr(config, key), (int, str, list, tuple)):
+            logger.info(' %20s %s' % (key, str(getattr(config, key))))
-    valid = transformers.TaxiAddDateTime(valid)
-    valid = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, valid)
-    valid = transformers.Select(valid, tuple(req_vars))
+    model = config.Model(config)
+    model.initialize()
-    valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000))
-
-    return valid_stream
-
-def setup_test_stream(req_vars):
-    test = TaxiStream('test')
-
-    test = transformers.TaxiAddDateTime(test)
-    test = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, test)
-    test = transformers.Select(test, tuple(req_vars))
-
-    test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))
-
-    return test_stream
-
-
-def main():
-    model = config.model.Model(config)
-
-    cost = model.cost
-    outputs = model.outputs
-
-    req_vars = model.require_inputs + model.pred_vars
-    req_vars_test = model.require_inputs + [ 'trip_id' ]
+    stream = config.Stream(config)
+    inputs = stream.inputs()
+    req_vars = model.cost.inputs
-    valid_trips_ids = compile_valid_trip_ids()
-    train_stream = setup_train_stream(req_vars, valid_trips_ids)
-    valid_stream = setup_valid_stream(req_vars)
+    train_stream = stream.train(req_vars)
+    valid_stream = stream.valid(req_vars)
-    # Training
+    cost = model.cost(**inputs)
     cg = ComputationGraph(cost)
+    unmonitor = set()
+    if hasattr(config, 'dropout') and config.dropout < 1.0:
+        unmonitor.update(VariableFilter(roles=[roles.COST])(cg.variables))
+        cg = apply_dropout(cg, config.dropout_inputs(cg), config.dropout)
+    if hasattr(config, 'noise') and config.noise > 0.0:
+        unmonitor.update(VariableFilter(roles=[roles.COST])(cg.variables))
+        cg = apply_noise(cg, config.noise_inputs(cg), config.noise)
+    cost = cg.outputs[0]
+    cg = Model(cost)
+
+    logger.info('# Parameter shapes:')
+    parameters_size = 0
+    for key, value in cg.get_params().iteritems():
+        logger.info(' %20s %s' % (value.get_value().shape, key))
+        parameters_size += reduce(operator.mul, value.get_value().shape, 1)
+    logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.get_params())))
     params = cg.parameters
-
     algorithm = GradientDescent(
         cost=cost,
         step_rule=CompositeRule([
             RemoveNotFinite(),
-            #AdaDelta(decay_rate=0.95),
-            Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
-            ]),
+            AdaDelta(),
+        ]),
         params=params)
-    plot_vars = [['valid_' + x.name for x in model.monitor]]
-    print "Plot: ", plot_vars
-
-    extensions=[TrainingDataMonitoring(model.monitor, prefix='train', every_n_batches=1000),
-                DataStreamMonitoring(model.monitor, valid_stream,
+    monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables)) - unmonitor
+    plot_vars = [['valid_' + x.name for x in monitored]]
+    logger.info('Plotted variables: %s' % str(plot_vars))
+
+    dump_path = os.path.join('model_data', model_name)
+    logger.info('Dump path: %s' % dump_path)
+    extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=1000),
+                DataStreamMonitoring(monitored, valid_stream,
                                      prefix='valid',
-                                     every_n_batches=500),
-                Printing(every_n_batches=500),
+                                     every_n_batches=1000),
+                Printing(every_n_batches=1000),
                 Plot(model_name, channels=plot_vars, every_n_batches=500),
-                # Checkpoint('model.pkl', every_n_batches=100),
-                Dump('model_data/' + model_name, every_n_batches=500),
-                LoadFromDump('model_data/' + model_name),
-                # FinishAfter(after_epoch=4),
+                Dump(dump_path, every_n_batches=5000),
+                LoadFromDump(dump_path),
+                #FinishAfter(after_n_batches=2),
                 ]
     main_loop = MainLoop(
-        model=Model([cost]),
+        model=cg,
         data_stream=train_stream,
         algorithm=algorithm,
         extensions=extensions)
     main_loop.run()
     main_loop.profile.report()
-
-    # Produce an output on the test data
-    test_stream = setup_test_stream(req_vars_test)
-
-    if 'destination_longitude' in model.pred_vars:
-        dest_outfile = open("output/test-dest-output-%s.csv" % model_name, "w")
-        dest_outcsv = csv.writer(dest_outfile)
-        dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
-    if 'travel_time' in model.pred_vars:
-        time_outfile = open("output/test-time-output-%s.csv" % model_name, "w")
-        time_outcsv = csv.writer(time_outfile)
-        time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"])
-
-    for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']):
-        outputs = out['outputs']
-        for i, trip in enumerate(out['trip_id']):
-            if model.pred_vars == ['travel_time']:
-                time_outcsv.writerow([trip, int(outputs[i])])
-            else:
-                dest_outcsv.writerow([trip, repr(outputs[i, 0]), repr(outputs[i, 1])])
-                if 'travel_time' in model.pred_vars:
-                    time_outcsv.writerow([trip, int(outputs[i, 2])])
-
-
-if __name__ == "__main__":
-    logging.basicConfig(level=logging.INFO)
-    main()
-
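
A side note on the regularization block introduced above: Blocks' apply_dropout and apply_noise return a rebuilt ComputationGraph in which the selected variables have been replaced, so the cost variable of the original graph is no longer the one being optimized. That is why train.py records the pre-transformation COST variables in unmonitor and subtracts them from the monitored set. A minimal, self-contained illustration, using a toy graph rather than anything from this repository:

    import numpy
    import theano
    from theano import tensor
    from blocks.graph import ComputationGraph, apply_dropout

    x = tensor.matrix('x')
    w = theano.shared(numpy.ones((3, 1), dtype=theano.config.floatX), name='w')
    cost = tensor.dot(x, w).sum()
    cost.name = 'cost'

    cg = ComputationGraph(cost)
    # Drop the input with probability 0.5, as train.py does with the
    # variables returned by config.dropout_inputs(cg).
    cg_dropout = apply_dropout(cg, [x], 0.5)
    dropped_cost = cg_dropout.outputs[0]

    # The rebuilt output is a different Theano variable; monitoring the
    # old `cost` would silently bypass dropout, hence `unmonitor`.
    assert dropped_cost is not cost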