aboutsummaryrefslogtreecommitdiff
path: root/model
diff options
context:
space:
mode:
Diffstat (limited to 'model')
-rw-r--r--model/bidirectional.py15
-rw-r--r--model/memory_network.py11
-rw-r--r--model/mlp.py15
-rw-r--r--model/rnn.py16
4 files changed, 30 insertions, 27 deletions
diff --git a/model/bidirectional.py b/model/bidirectional.py
index 483ea54..ba6440a 100644
--- a/model/bidirectional.py
+++ b/model/bidirectional.py
@@ -98,17 +98,18 @@ class Stream(object):
self.config = config
def train(self, req_vars):
- valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
- valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
-
- stream = TaxiDataset('train')
+ stream = TaxiDataset('train', data.traintest_ds)
if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
else:
stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
- stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
+ if not data.tvt:
+ valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
+ valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
+ stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
+
stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
if hasattr(self.config, 'shuffle_batch_size'):
@@ -128,7 +129,7 @@ class Stream(object):
return stream
def valid(self, req_vars):
- stream = TaxiStream(self.config.valid_set, 'valid.hdf5')
+ stream = TaxiStream(data.valid_set, data.valid_ds)
stream = transformers.taxi_add_datetime(stream)
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
@@ -139,7 +140,7 @@ class Stream(object):
return stream
def test(self, req_vars):
- stream = TaxiStream('test')
+ stream = TaxiStream('test', data.traintest_ds)
stream = transformers.taxi_add_datetime(stream)
stream = transformers.taxi_remove_test_only_clients(stream)
diff --git a/model/memory_network.py b/model/memory_network.py
index 14d1e07..e7ba51c 100644
--- a/model/memory_network.py
+++ b/model/memory_network.py
@@ -53,20 +53,20 @@ class StreamBase(object):
@property
def valid_dataset(self):
- return TaxiDataset(self.config.valid_set, 'valid.hdf5')
+ return TaxiDataset(data.valid_set, data.valid_ds)
@property
def valid_trips_ids(self):
- valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
+ valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
return valid.get_data(None, slice(0, valid.num_examples))[0]
@property
def train_dataset(self):
- return TaxiDataset('train')
+ return TaxiDataset('train', data.traintest_ds)
@property
def test_dataset(self):
- return TaxiDataset('test')
+ return TaxiDataset('test', data.traintest_ds)
class StreamSimple(StreamBase):
@@ -96,7 +96,8 @@ class StreamSimple(StreamBase):
prefix_stream = DataStream(self.train_dataset,
iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
- prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
+ if not data.tvt:
+ prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
prefix_stream = transformers.TaxiGenerateSplits(prefix_stream,
max_splits=self.config.max_splits)
diff --git a/model/mlp.py b/model/mlp.py
index d24b2cc..0336be1 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -63,17 +63,18 @@ class Stream(object):
self.config = config
def train(self, req_vars):
- valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
- valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
-
- stream = TaxiDataset('train')
+ stream = TaxiDataset('train', data.traintest_ds)
if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
else:
stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
- stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
+ if not data.tvt:
+ valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
+ valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
+ stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
+
stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
if hasattr(self.config, 'shuffle_batch_size'):
@@ -92,7 +93,7 @@ class Stream(object):
return stream
def valid(self, req_vars):
- stream = TaxiStream(self.config.valid_set, 'valid.hdf5')
+ stream = TaxiStream(data.valid_set, data.valid_ds)
stream = transformers.taxi_add_datetime(stream)
stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
@@ -100,7 +101,7 @@ class Stream(object):
return Batch(stream, iteration_scheme=ConstantScheme(1000))
def test(self, req_vars):
- stream = TaxiStream('test')
+ stream = TaxiStream('test', data.traintest_ds)
stream = transformers.taxi_add_datetime(stream)
stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
diff --git a/model/rnn.py b/model/rnn.py
index 70d054a..7bdba3b 100644
--- a/model/rnn.py
+++ b/model/rnn.py
@@ -146,15 +146,15 @@ class Stream(object):
self.config = config
def train(self, req_vars):
- valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
- valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
-
- stream = TaxiDataset('train')
+ stream = TaxiDataset('train', data.traintest_ds)
stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
- stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
+ if not data.tvt:
+ valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
+ valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
+ stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
+
stream = transformers.TaxiExcludeEmptyTrips(stream)
stream = transformers.taxi_add_datetime(stream)
- stream = transformers.add_destination(stream)
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size)
@@ -163,7 +163,7 @@ class Stream(object):
return stream
def valid(self, req_vars):
- stream = TaxiStream(self.config.valid_set, 'valid.hdf5')
+ stream = TaxiStream(data.valid_set, data.valid_ds)
stream = transformers.taxi_add_datetime(stream)
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
@@ -173,7 +173,7 @@ class Stream(object):
return stream
def test(self, req_vars):
- stream = TaxiStream('test')
+ stream = TaxiStream('test', data.traintest_ds)
stream = transformers.taxi_add_datetime(stream)
stream = transformers.taxi_remove_test_only_clients(stream)
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))