author     Alex Auvolat <alex.auvolat@ens.fr>  2015-06-11 15:26:21 -0400
committer  Alex Auvolat <alex.auvolat@ens.fr>  2015-06-11 15:26:21 -0400
commit     c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4 (patch)
tree       6fb3fcc9425948a13fdc421534b7a00574c1e595
parent     7e8dbac77ce712846954bdd5f4bfb62b6efaf7df (diff)
download   taxi-c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4.tar.gz
           taxi-c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4.zip
Add TaxiRemoveTestOnlyClients; custom dump manager enabling multiprocessing
-rw-r--r--  data/cut.py            1
-rw-r--r--  data/transformers.py  10
-rw-r--r--  model/mlp.py           3
-rwxr-xr-x  train.py              33
4 files changed, 43 insertions(+), 4 deletions(-)
diff --git a/data/cut.py b/data/cut.py
index 7853030..fc0b3f9 100644
--- a/data/cut.py
+++ b/data/cut.py
@@ -28,6 +28,7 @@ class TaxiTimeCutScheme(IterationScheme):
                     c.execute('SELECT trip FROM trip_times WHERE begin >= ? AND begin <= ? AND end >= ?',
                               (cut - 40000, cut, cut))]
            l = l + part
+        random.shuffle(l)
        return iter_(l)
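
The added shuffle only affects ordering: trips collected cut by cut would otherwise reach the iterator grouped by cut. A minimal standalone sketch of the effect (not part of the patch; the trip-id lists below are made up):

import random

l = []
for part in ([101, 102, 103], [201, 202]):   # hypothetical per-cut trip id lists
    l = l + part
random.shuffle(l)   # mixes trips from different cuts, e.g. [202, 101, 103, 201, 102]
trips = iter(l)
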
diff --git a/data/transformers.py b/data/transformers.py
index 1b82dae..4814e9b 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -119,3 +119,13 @@ class TaxiExcludeTrips(Transformer):
            if not data[self.id_trip_id] in self.exclude: break
        return data
+class TaxiRemoveTestOnlyClients(Transformer):
+    def __init__(self, stream):
+        super(TaxiRemoveTestOnlyClients, self).__init__(stream)
+        self.id_origin_call = stream.sources.index('origin_call')
+    def get_data(self, request=None):
+        if request is not None: raise ValueError
+        x = list(next(self.child_epoch_iterator))
+        if x[self.id_origin_call] >= data.origin_call_train_size:
+            x[self.id_origin_call] = numpy.int32(0)
+        return tuple(x)
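
What the new transformer does to a single example, as a standalone sketch (not part of the patch): origin_call ids that occur only in the test set, i.e. ids at or above data.origin_call_train_size, are mapped to 0, presumably the id reserved for an unknown client. The threshold value below is made up; the real one comes from the project's data module.

import numpy

origin_call_train_size = 57000   # hypothetical value standing in for data.origin_call_train_size

def remove_test_only_client(origin_call):
    # clients never seen during training are collapsed onto id 0
    if origin_call >= origin_call_train_size:
        return numpy.int32(0)
    return origin_call

print(remove_test_only_client(numpy.int32(60000)))   # -> 0
print(remove_test_only_client(numpy.int32(1234)))    # -> 1234
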
diff --git a/model/mlp.py b/model/mlp.py
index 2e0b9e5..b1e9163 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -70,6 +70,8 @@ class Stream(object):
        stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
+        stream = MultiProcessing(stream)
+
        return stream
def valid(self, req_vars):
@@ -85,6 +87,7 @@ class Stream(object):
        stream = transformers.TaxiAddDateTime(stream)
        stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
+        stream = transformers.TaxiRemoveTestOnlyClients(stream)
        return Batch(stream, iteration_scheme=ConstantScheme(1))
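
For context, a minimal sketch (not part of the patch) of how the training stream ends up being assembled after this change. The import paths are an assumption, since model/mlp.py's own imports are not shown in this diff; Batch, ConstantScheme and MultiProcessing are the names used in the hunk above.

from fuel.schemes import ConstantScheme
from fuel.transformers import Batch, MultiProcessing

def build_train_stream(stream, batch_size):
    # group single examples into fixed-size batches
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
    # iterate the upstream transformers in a separate process so that
    # data preparation overlaps with training
    return MultiProcessing(stream)
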
diff --git a/train.py b/train.py
index 65f9050..876fcba 100755
--- a/train.py
+++ b/train.py
@@ -14,7 +14,8 @@ from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNo
from blocks.extensions import Printing, FinishAfter
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
from blocks.extensions.plot import Plot
-from blocks.extensions.saveload import Dump, LoadFromDump
+from blocks.extensions.saveload import LoadFromDump, Dump
+from blocks.dump import MainLoopDumpManager
from blocks.filter import VariableFilter
from blocks.graph import ComputationGraph, apply_dropout, apply_noise
from blocks.main_loop import MainLoop
@@ -53,6 +54,25 @@ class ElementwiseRemoveNotFinite(StepRule):
        return step, []

+class CustomDumpManager(MainLoopDumpManager):
+    def dump(self, main_loop):
+        """Dumps the main loop to the root folder.
+        See :mod:`blocks.dump`.
+        Overwrites the old data if present.
+        """
+        if not os.path.exists(self.folder):
+            os.mkdir(self.folder)
+        self.dump_parameters(main_loop)
+        self.dump_log(main_loop)
+
+    def load(self):
+        return (self.load_parameters(), self.load_log())
+
+    def load_to(self, main_loop):
+        """Loads the dump from the root folder into the main loop."""
+        parameters, log = self.load()
+        main_loop.model.set_param_values(parameters)
+        main_loop.log = log

if __name__ == "__main__":
    if len(sys.argv) != 2:
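
A standalone sketch (not part of the patch) of the dump/load round trip this manager provides, using only the methods defined above; main_loop and dump_path stand in for the objects built later in train.py.

manager = CustomDumpManager(dump_path)
manager.dump(main_loop)              # writes parameters and log under dump_path
parameters, log = manager.load()     # reads them back as a (parameters, log) pair
main_loop.model.set_param_values(parameters)   # the same call load_to() makes
main_loop.log = log
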
@@ -109,15 +129,20 @@ if __name__ == "__main__":
    dump_path = os.path.join('model_data', model_name)
    logger.info('Dump path: %s' % dump_path)
+
+    load_dump_ext = LoadFromDump(dump_path)
+    dump_ext = Dump(dump_path, every_n_batches=1000)
+    load_dump_ext.manager = CustomDumpManager(dump_path)
+    dump_ext.manager = CustomDumpManager(dump_path)
+
    extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=1000),
                DataStreamMonitoring(monitored, valid_stream,
                                     prefix='valid',
                                     every_n_batches=1000),
                Printing(every_n_batches=1000),
                Plot(model_name, channels=plot_vars, every_n_batches=500),
-               Dump(dump_path, every_n_batches=5000),
-               LoadFromDump(dump_path),
-               #FinishAfter(after_n_batches=2),
+               load_dump_ext,
+               dump_ext
                ]

    main_loop = MainLoop(