aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-22 15:51:26 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-22 15:51:47 -0400
commit57fe795d14e70c06c9bdbe6fe903588b5f75474e (patch)
treed7b0de1569a67dfc55dc6481c35e976d22572ebb
parent448e848796757ad9f0a2f681886f868b8f22e81f (diff)
downloadtaxi-57fe795d14e70c06c9bdbe6fe903588b5f75474e.tar.gz
taxi-57fe795d14e70c06c9bdbe6fe903588b5f75474e.zip
Add parametrizability for how the training data is presented
-rw-r--r--config/dest_simple_mlp_2_cs.py1
-rw-r--r--config/dest_simple_mlp_2_cswdt.py1
-rw-r--r--config/dest_simple_mlp_2_noembed.py1
-rw-r--r--config/dest_simple_mlp_tgtcls_0_cs.py1
-rw-r--r--config/dest_simple_mlp_tgtcls_1_cs.py1
-rw-r--r--config/dest_simple_mlp_tgtcls_1_cswdt.py1
-rw-r--r--config/dest_simple_mlp_tgtcls_1_cswdtx.py6
-rw-r--r--config/dest_simple_mlp_tgtcls_1_cswdtx_alexandre.py1
-rw-r--r--config/joint_simple_mlp_tgtcls_111_cswdtx.py1
-rw-r--r--config/joint_simple_mlp_tgtcls_111_cswdtx_bigger.py1
-rw-r--r--config/joint_simple_mlp_tgtcls_111_cswdtx_bigger_dropout.py1
-rw-r--r--config/joint_simple_mlp_tgtcls_111_cswdtx_noise_dout.py1
-rw-r--r--config/joint_simple_mlp_tgtcls_1_cswdtx.py1
-rw-r--r--config/joint_simple_mlp_tgtcls_1_cswdtx_bigger.py1
-rw-r--r--config/time_simple_mlp_1.py1
-rw-r--r--config/time_simple_mlp_2_cswdtx.py1
-rw-r--r--config/time_simple_mlp_tgtcls_2_cswdtx.py1
-rw-r--r--model/mlp.py12
18 files changed, 29 insertions, 5 deletions
diff --git a/config/dest_simple_mlp_2_cs.py b/config/dest_simple_mlp_2_cs.py
index 7c121a4..4fd8c4a 100644
--- a/config/dest_simple_mlp_2_cs.py
+++ b/config/dest_simple_mlp_2_cs.py
@@ -26,3 +26,4 @@ momentum = 0.99
batch_size = 32
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/dest_simple_mlp_2_cswdt.py b/config/dest_simple_mlp_2_cswdt.py
index d904b8c..4563176 100644
--- a/config/dest_simple_mlp_2_cswdt.py
+++ b/config/dest_simple_mlp_2_cswdt.py
@@ -30,3 +30,4 @@ momentum = 0.99
batch_size = 32
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/dest_simple_mlp_2_noembed.py b/config/dest_simple_mlp_2_noembed.py
index 9173a8e..de3b3e4 100644
--- a/config/dest_simple_mlp_2_noembed.py
+++ b/config/dest_simple_mlp_2_noembed.py
@@ -23,3 +23,4 @@ momentum = 0.99
batch_size = 32
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/dest_simple_mlp_tgtcls_0_cs.py b/config/dest_simple_mlp_tgtcls_0_cs.py
index 15c7259..0808661 100644
--- a/config/dest_simple_mlp_tgtcls_0_cs.py
+++ b/config/dest_simple_mlp_tgtcls_0_cs.py
@@ -31,3 +31,4 @@ momentum = 0.99
batch_size = 32
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/dest_simple_mlp_tgtcls_1_cs.py b/config/dest_simple_mlp_tgtcls_1_cs.py
index d047bcf..8bdd028 100644
--- a/config/dest_simple_mlp_tgtcls_1_cs.py
+++ b/config/dest_simple_mlp_tgtcls_1_cs.py
@@ -31,3 +31,4 @@ momentum = 0.99
batch_size = 32
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/dest_simple_mlp_tgtcls_1_cswdt.py b/config/dest_simple_mlp_tgtcls_1_cswdt.py
index 8811993..38fa62d 100644
--- a/config/dest_simple_mlp_tgtcls_1_cswdt.py
+++ b/config/dest_simple_mlp_tgtcls_1_cswdt.py
@@ -35,3 +35,4 @@ momentum = 0.99
batch_size = 32
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/dest_simple_mlp_tgtcls_1_cswdtx.py b/config/dest_simple_mlp_tgtcls_1_cswdtx.py
index de076ef..f350457 100644
--- a/config/dest_simple_mlp_tgtcls_1_cswdtx.py
+++ b/config/dest_simple_mlp_tgtcls_1_cswdtx.py
@@ -33,6 +33,10 @@ mlp_biases_init = Constant(0.001)
learning_rate = 0.0001
momentum = 0.99
-batch_size = 32
+batch_size = 100
+
+use_cuts_for_training = True
+max_splits = 1
valid_set = 'cuts/test_times_0'
+
diff --git a/config/dest_simple_mlp_tgtcls_1_cswdtx_alexandre.py b/config/dest_simple_mlp_tgtcls_1_cswdtx_alexandre.py
index 90200e9..c9748ad 100644
--- a/config/dest_simple_mlp_tgtcls_1_cswdtx_alexandre.py
+++ b/config/dest_simple_mlp_tgtcls_1_cswdtx_alexandre.py
@@ -36,3 +36,4 @@ momentum = 0.9
batch_size = 200
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/joint_simple_mlp_tgtcls_111_cswdtx.py b/config/joint_simple_mlp_tgtcls_111_cswdtx.py
index a18e4ff..b969da1 100644
--- a/config/joint_simple_mlp_tgtcls_111_cswdtx.py
+++ b/config/joint_simple_mlp_tgtcls_111_cswdtx.py
@@ -53,3 +53,4 @@ momentum = 0.99
batch_size = 200
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger.py b/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger.py
index 0403197..0077881 100644
--- a/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger.py
+++ b/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger.py
@@ -54,3 +54,4 @@ batch_size = 200
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger_dropout.py b/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger_dropout.py
index 937122d..bc5121b 100644
--- a/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger_dropout.py
+++ b/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger_dropout.py
@@ -57,3 +57,4 @@ batch_size = 200
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/joint_simple_mlp_tgtcls_111_cswdtx_noise_dout.py b/config/joint_simple_mlp_tgtcls_111_cswdtx_noise_dout.py
index 6d44c10..fd4dabf 100644
--- a/config/joint_simple_mlp_tgtcls_111_cswdtx_noise_dout.py
+++ b/config/joint_simple_mlp_tgtcls_111_cswdtx_noise_dout.py
@@ -60,3 +60,4 @@ noise = 0.01
noise_inputs = VariableFilter(roles=[roles.PARAMETER])
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/joint_simple_mlp_tgtcls_1_cswdtx.py b/config/joint_simple_mlp_tgtcls_1_cswdtx.py
index 1874444..1e21000 100644
--- a/config/joint_simple_mlp_tgtcls_1_cswdtx.py
+++ b/config/joint_simple_mlp_tgtcls_1_cswdtx.py
@@ -53,3 +53,4 @@ momentum = 0.99
batch_size = 200
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/joint_simple_mlp_tgtcls_1_cswdtx_bigger.py b/config/joint_simple_mlp_tgtcls_1_cswdtx_bigger.py
index e96b44c..a8242a7 100644
--- a/config/joint_simple_mlp_tgtcls_1_cswdtx_bigger.py
+++ b/config/joint_simple_mlp_tgtcls_1_cswdtx_bigger.py
@@ -53,3 +53,4 @@ momentum = 0.99
batch_size = 200
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/time_simple_mlp_1.py b/config/time_simple_mlp_1.py
index 172e098..6cfe510 100644
--- a/config/time_simple_mlp_1.py
+++ b/config/time_simple_mlp_1.py
@@ -26,3 +26,4 @@ momentum = 0.99
batch_size = 32
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/time_simple_mlp_2_cswdtx.py b/config/time_simple_mlp_2_cswdtx.py
index 2ec28c6..4a1a92c 100644
--- a/config/time_simple_mlp_2_cswdtx.py
+++ b/config/time_simple_mlp_2_cswdtx.py
@@ -33,3 +33,4 @@ momentum = 0.99
batch_size = 32
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/config/time_simple_mlp_tgtcls_2_cswdtx.py b/config/time_simple_mlp_tgtcls_2_cswdtx.py
index 608ed7e..d8a1281 100644
--- a/config/time_simple_mlp_tgtcls_2_cswdtx.py
+++ b/config/time_simple_mlp_tgtcls_2_cswdtx.py
@@ -36,3 +36,4 @@ momentum = 0.99
batch_size = 32
valid_set = 'cuts/test_times_0'
+max_splits = 100
diff --git a/model/mlp.py b/model/mlp.py
index 05898a5..fc86b7b 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -51,14 +51,18 @@ class Stream(object):
self.config = config
def train(self, req_vars):
- stream = TaxiDataset('train')
- stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
-
valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
+ stream = TaxiDataset('train')
+
+ if hasattr(self.config, 'use_cuts_for_trainig') and self.config.use_cuts_for_training:
+ stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
+ else:
+ stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
+
stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream)
- stream = transformers.TaxiGenerateSplits(stream, max_splits=1)
+ stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
stream = transformers.TaxiAddDateTime(stream)
stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)