diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-05 11:36:59 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-05 11:36:59 -0400 |
commit | ab1076e00d6a92120e46d4a0085911b4425a0d60 (patch) | |
tree | 83bf857eb61b2df332362b95600686f8d90f508c /train.py | |
parent | 5b496677ea1db59a6718e5c9b2958177c76cb25f (diff) | |
download | taxi-ab1076e00d6a92120e46d4a0085911b4425a0d60.tar.gz taxi-ab1076e00d6a92120e46d4a0085911b4425a0d60.zip |
Add date/time transformer and new model that uses it
Diffstat (limited to 'train.py')
-rw-r--r-- | train.py | 18 |
1 files changed, 11 insertions, 7 deletions
@@ -49,11 +49,13 @@ def setup_train_stream(req_vars): subset=slice(0, data.dataset_size), load_in_memory=True) train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid)) - train = transformers.filter_out_trips(data.valid_trips, train) + + train = transformers.TaxiExcludeTrips(data.valid_trips, train) train = transformers.TaxiGenerateSplits(train, max_splits=100) - train = transformers.add_first_k(config.n_begin_end_pts, train) - train = transformers.add_last_k(config.n_begin_end_pts, train) + train = transformers.TaxiAddDateTime(train) + train = transformers.TaxiAddFirstK(config.n_begin_end_pts, train) + train = transformers.TaxiAddLastK(config.n_begin_end_pts, train) train = transformers.Select(train, tuple(req_vars)) train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size)) @@ -63,8 +65,9 @@ def setup_train_stream(req_vars): def setup_valid_stream(req_vars): valid = DataStream(data.valid_data) - valid = transformers.add_first_k(config.n_begin_end_pts, valid) - valid = transformers.add_last_k(config.n_begin_end_pts, valid) + valid = transformers.TaxiAddDateTime(valid) + valid = transformers.TaxiAddFirstK(config.n_begin_end_pts, valid) + valid = transformers.TaxiAddLastK(config.n_begin_end_pts, valid) valid = transformers.Select(valid, tuple(req_vars)) valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000)) @@ -74,8 +77,9 @@ def setup_valid_stream(req_vars): def setup_test_stream(req_vars): test = DataStream(data.test_data) - test = transformers.add_first_k(config.n_begin_end_pts, test) - test = transformers.add_last_k(config.n_begin_end_pts, test) + test = transformers.TaxiAddDateTime(test) + test = transformers.TaxiAddFirstK(config.n_begin_end_pts, test) + test = transformers.TaxiAddLastK(config.n_begin_end_pts, test) test = transformers.Select(test, tuple(req_vars)) test_stream = Batch(test, iteration_scheme=ConstantScheme(1000)) |