aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-05 11:36:59 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-05 11:36:59 -0400
commitab1076e00d6a92120e46d4a0085911b4425a0d60 (patch)
tree83bf857eb61b2df332362b95600686f8d90f508c /train.py
parent5b496677ea1db59a6718e5c9b2958177c76cb25f (diff)
downloadtaxi-ab1076e00d6a92120e46d4a0085911b4425a0d60.tar.gz
taxi-ab1076e00d6a92120e46d4a0085911b4425a0d60.zip
Add date/time transformer and new model that uses it
Diffstat (limited to 'train.py')
-rw-r--r--train.py18
1 files changed, 11 insertions, 7 deletions
diff --git a/train.py b/train.py
index 238803a..2c9522e 100644
--- a/train.py
+++ b/train.py
@@ -49,11 +49,13 @@ def setup_train_stream(req_vars):
subset=slice(0, data.dataset_size),
load_in_memory=True)
train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid))
- train = transformers.filter_out_trips(data.valid_trips, train)
+
+ train = transformers.TaxiExcludeTrips(data.valid_trips, train)
train = transformers.TaxiGenerateSplits(train, max_splits=100)
- train = transformers.add_first_k(config.n_begin_end_pts, train)
- train = transformers.add_last_k(config.n_begin_end_pts, train)
+ train = transformers.TaxiAddDateTime(train)
+ train = transformers.TaxiAddFirstK(config.n_begin_end_pts, train)
+ train = transformers.TaxiAddLastK(config.n_begin_end_pts, train)
train = transformers.Select(train, tuple(req_vars))
train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
@@ -63,8 +65,9 @@ def setup_train_stream(req_vars):
def setup_valid_stream(req_vars):
valid = DataStream(data.valid_data)
- valid = transformers.add_first_k(config.n_begin_end_pts, valid)
- valid = transformers.add_last_k(config.n_begin_end_pts, valid)
+ valid = transformers.TaxiAddDateTime(valid)
+ valid = transformers.TaxiAddFirstK(config.n_begin_end_pts, valid)
+ valid = transformers.TaxiAddLastK(config.n_begin_end_pts, valid)
valid = transformers.Select(valid, tuple(req_vars))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000))
@@ -74,8 +77,9 @@ def setup_valid_stream(req_vars):
def setup_test_stream(req_vars):
test = DataStream(data.test_data)
- test = transformers.add_first_k(config.n_begin_end_pts, test)
- test = transformers.add_last_k(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddDateTime(test)
+ test = transformers.TaxiAddFirstK(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddLastK(config.n_begin_end_pts, test)
test = transformers.Select(test, tuple(req_vars))
test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))