diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-08 13:59:44 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-08 13:59:44 -0400 |
commit | 0ecac7973fd02f44af9c8bc5765f7c159c94b23a (patch) | |
tree | ec28ec6253d84b3c55f8fa88adce4e720d26bd07 /train.py | |
parent | 1ffd1fc355f6fddcb6cd3d93c0df58513d064472 (diff) | |
download | taxi-0ecac7973fd02f44af9c8bc5765f7c159c94b23a.tar.gz taxi-0ecac7973fd02f44af9c8bc5765f7c159c94b23a.zip |
Fusion AddFirstK and AddLastK transformers, and add 'input_time' field.
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 10 |
1 files changed, 4 insertions, 6 deletions
@@ -44,8 +44,7 @@ def setup_train_stream(req_vars, valid_trips_ids): train = transformers.TaxiGenerateSplits(train, max_splits=100) train = transformers.TaxiAddDateTime(train) - train = transformers.TaxiAddFirstK(config.n_begin_end_pts, train) - train = transformers.TaxiAddLastK(config.n_begin_end_pts, train) + train = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, train) train = transformers.Select(train, tuple(req_vars)) train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size)) @@ -56,8 +55,7 @@ def setup_valid_stream(req_vars): valid = TaxiStream(config.valid_set, 'valid.hdf5') valid = transformers.TaxiAddDateTime(valid) - valid = transformers.TaxiAddFirstK(config.n_begin_end_pts, valid) - valid = transformers.TaxiAddLastK(config.n_begin_end_pts, valid) + valid = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, valid) valid = transformers.Select(valid, tuple(req_vars)) valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000)) @@ -68,8 +66,8 @@ def setup_test_stream(req_vars): test = TaxiStream('test') test = transformers.TaxiAddDateTime(test) - test = transformers.TaxiAddFirstK(config.n_begin_end_pts, test) - test = transformers.TaxiAddLastK(config.n_begin_end_pts, test) + test = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, test) + test = transformers.TaxiAddLast(config.n_begin_end_pts, test) test = transformers.Select(test, tuple(req_vars)) test_stream = Batch(test, iteration_scheme=ConstantScheme(1000)) |