aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-08 13:59:44 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-08 13:59:44 -0400
commit0ecac7973fd02f44af9c8bc5765f7c159c94b23a (patch)
treeec28ec6253d84b3c55f8fa88adce4e720d26bd07 /train.py
parent1ffd1fc355f6fddcb6cd3d93c0df58513d064472 (diff)
downloadtaxi-0ecac7973fd02f44af9c8bc5765f7c159c94b23a.tar.gz
taxi-0ecac7973fd02f44af9c8bc5765f7c159c94b23a.zip
Fusion AddFirstK and AddLastK transformers, and add 'input_time' field.
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py10
1 files changed, 4 insertions, 6 deletions
diff --git a/train.py b/train.py
index 4e2b983..8d9f4ad 100755
--- a/train.py
+++ b/train.py
@@ -44,8 +44,7 @@ def setup_train_stream(req_vars, valid_trips_ids):
train = transformers.TaxiGenerateSplits(train, max_splits=100)
train = transformers.TaxiAddDateTime(train)
- train = transformers.TaxiAddFirstK(config.n_begin_end_pts, train)
- train = transformers.TaxiAddLastK(config.n_begin_end_pts, train)
+ train = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, train)
train = transformers.Select(train, tuple(req_vars))
train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
@@ -56,8 +55,7 @@ def setup_valid_stream(req_vars):
valid = TaxiStream(config.valid_set, 'valid.hdf5')
valid = transformers.TaxiAddDateTime(valid)
- valid = transformers.TaxiAddFirstK(config.n_begin_end_pts, valid)
- valid = transformers.TaxiAddLastK(config.n_begin_end_pts, valid)
+ valid = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, valid)
valid = transformers.Select(valid, tuple(req_vars))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000))
@@ -68,8 +66,8 @@ def setup_test_stream(req_vars):
test = TaxiStream('test')
test = transformers.TaxiAddDateTime(test)
- test = transformers.TaxiAddFirstK(config.n_begin_end_pts, test)
- test = transformers.TaxiAddLastK(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddLast(config.n_begin_end_pts, test)
test = transformers.Select(test, tuple(req_vars))
test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))