diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-11 15:26:21 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-11 15:26:21 -0400 |
commit | c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4 (patch) | |
tree | 6fb3fcc9425948a13fdc421534b7a00574c1e595 | |
parent | 7e8dbac77ce712846954bdd5f4bfb62b6efaf7df (diff) | |
download | taxi-c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4.tar.gz taxi-c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4.zip |
Add TaxiRemoveTestOnlyClients ; custom dumpmanager enabling multiprocessing
-rw-r--r-- | data/cut.py | 1 | ||||
-rw-r--r-- | data/transformers.py | 10 | ||||
-rw-r--r-- | model/mlp.py | 3 | ||||
-rwxr-xr-x | train.py | 33 |
4 files changed, 43 insertions, 4 deletions
```diff
diff --git a/data/cut.py b/data/cut.py
index 7853030..fc0b3f9 100644
--- a/data/cut.py
+++ b/data/cut.py
@@ -28,6 +28,7 @@ class TaxiTimeCutScheme(IterationScheme):
                     c.execute('SELECT trip FROM trip_times WHERE begin >= ? AND begin <= ? AND end >= ?', (cut - 40000, cut, cut))]
                 l = l + part
+        random.shuffle(l)
         return iter_(l)
diff --git a/data/transformers.py b/data/transformers.py
index 1b82dae..4814e9b 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -119,3 +119,13 @@ class TaxiExcludeTrips(Transformer):
             if not data[self.id_trip_id] in self.exclude:
                 break
         return data
+class TaxiRemoveTestOnlyClients(Transformer):
+    def __init__(self, stream):
+        super(TaxiRemoveTestOnlyClients, self).__init__(stream)
+        self.id_origin_call = stream.sources.index('origin_call')
+    def get_data(self, request=None):
+        if request is not None: raise ValueError
+        x = list(next(self.child_epoch_iterator))
+        if x[self.id_origin_call] >= data.origin_call_train_size:
+            x[self.id_origin_call] = numpy.int32(0)
+        return tuple(x)
diff --git a/model/mlp.py b/model/mlp.py
index 2e0b9e5..b1e9163 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -70,6 +70,8 @@ class Stream(object):
         stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
+        stream = MultiProcessing(stream)
+
         return stream
     def valid(self, req_vars):
@@ -85,6 +87,7 @@ class Stream(object):
         stream = transformers.TaxiAddDateTime(stream)
         stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
+        stream = transformers.TaxiRemoveTestOnlyClients(stream)
         return Batch(stream, iteration_scheme=ConstantScheme(1))
diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -14,7 +14,8 @@ from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNo
 from blocks.extensions import Printing, FinishAfter
 from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
 from blocks.extensions.plot import Plot
-from blocks.extensions.saveload import Dump, LoadFromDump
+from blocks.extensions.saveload import LoadFromDump, Dump
+from blocks.dump import MainLoopDumpManager
 from blocks.filter import VariableFilter
 from blocks.graph import ComputationGraph, apply_dropout, apply_noise
 from blocks.main_loop import MainLoop
@@ -53,6 +54,25 @@ class ElementwiseRemoveNotFinite(StepRule):
         return step, []
+class CustomDumpManager(MainLoopDumpManager):
+    def dump(self, main_loop):
+        """Dumps the main loop to the root folder.
+        See :mod:`blocks.dump`.
+        Overwrites the old data if present.
+        """
+        if not os.path.exists(self.folder):
+            os.mkdir(self.folder)
+        self.dump_parameters(main_loop)
+        self.dump_log(main_loop)
+
+    def load(self):
+        return (self.load_parameters(), self.load_log())
+
+    def load_to(self, main_loop):
+        """Loads the dump from the root folder into the main loop."""
+        parameters, log = self.load()
+        main_loop.model.set_param_values(parameters)
+        main_loop.log = log
 if __name__ == "__main__":
     if len(sys.argv) != 2:
@@ -109,15 +129,20 @@ if __name__ == "__main__":
     dump_path = os.path.join('model_data', model_name)
     logger.info('Dump path: %s' % dump_path)
+
+    load_dump_ext = LoadFromDump(dump_path)
+    dump_ext = Dump(dump_path, every_n_batches=1000)
+    load_dump_ext.manager = CustomDumpManager(dump_path)
+    dump_ext.manager = CustomDumpManager(dump_path)
+
     extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=1000),
                 DataStreamMonitoring(monitored, valid_stream, prefix='valid', every_n_batches=1000),
                 Printing(every_n_batches=1000),
                 Plot(model_name, channels=plot_vars, every_n_batches=500),
-                Dump(dump_path, every_n_batches=5000),
-                LoadFromDump(dump_path),
-                #FinishAfter(after_n_batches=2),
+                load_dump_ext,
+                dump_ext
                 ]
     main_loop = MainLoop(
```