aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py18
1 files changed, 11 insertions, 7 deletions
diff --git a/train.py b/train.py
index 238803a..2c9522e 100644
--- a/train.py
+++ b/train.py
@@ -49,11 +49,13 @@ def setup_train_stream(req_vars):
subset=slice(0, data.dataset_size),
load_in_memory=True)
train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid))
- train = transformers.filter_out_trips(data.valid_trips, train)
+
+ train = transformers.TaxiExcludeTrips(data.valid_trips, train)
train = transformers.TaxiGenerateSplits(train, max_splits=100)
- train = transformers.add_first_k(config.n_begin_end_pts, train)
- train = transformers.add_last_k(config.n_begin_end_pts, train)
+ train = transformers.TaxiAddDateTime(train)
+ train = transformers.TaxiAddFirstK(config.n_begin_end_pts, train)
+ train = transformers.TaxiAddLastK(config.n_begin_end_pts, train)
train = transformers.Select(train, tuple(req_vars))
train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
@@ -63,8 +65,9 @@ def setup_train_stream(req_vars):
def setup_valid_stream(req_vars):
valid = DataStream(data.valid_data)
- valid = transformers.add_first_k(config.n_begin_end_pts, valid)
- valid = transformers.add_last_k(config.n_begin_end_pts, valid)
+ valid = transformers.TaxiAddDateTime(valid)
+ valid = transformers.TaxiAddFirstK(config.n_begin_end_pts, valid)
+ valid = transformers.TaxiAddLastK(config.n_begin_end_pts, valid)
valid = transformers.Select(valid, tuple(req_vars))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000))
@@ -74,8 +77,9 @@ def setup_valid_stream(req_vars):
def setup_test_stream(req_vars):
test = DataStream(data.test_data)
- test = transformers.add_first_k(config.n_begin_end_pts, test)
- test = transformers.add_last_k(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddDateTime(test)
+ test = transformers.TaxiAddFirstK(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddLastK(config.n_begin_end_pts, test)
test = transformers.Select(test, tuple(req_vars))
test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))