aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py10
1 files changed, 4 insertions, 6 deletions
diff --git a/train.py b/train.py
index 4e2b983..8d9f4ad 100755
--- a/train.py
+++ b/train.py
@@ -44,8 +44,7 @@ def setup_train_stream(req_vars, valid_trips_ids):
train = transformers.TaxiGenerateSplits(train, max_splits=100)
train = transformers.TaxiAddDateTime(train)
- train = transformers.TaxiAddFirstK(config.n_begin_end_pts, train)
- train = transformers.TaxiAddLastK(config.n_begin_end_pts, train)
+ train = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, train)
train = transformers.Select(train, tuple(req_vars))
train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
@@ -56,8 +55,7 @@ def setup_valid_stream(req_vars):
valid = TaxiStream(config.valid_set, 'valid.hdf5')
valid = transformers.TaxiAddDateTime(valid)
- valid = transformers.TaxiAddFirstK(config.n_begin_end_pts, valid)
- valid = transformers.TaxiAddLastK(config.n_begin_end_pts, valid)
+ valid = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, valid)
valid = transformers.Select(valid, tuple(req_vars))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000))
@@ -68,8 +66,8 @@ def setup_test_stream(req_vars):
test = TaxiStream('test')
test = transformers.TaxiAddDateTime(test)
- test = transformers.TaxiAddFirstK(config.n_begin_end_pts, test)
- test = transformers.TaxiAddLastK(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddLast(config.n_begin_end_pts, test)
test = transformers.Select(test, tuple(req_vars))
test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))