diff options
38 files changed, 80 insertions, 83 deletions
diff --git a/config/bidirectional_1.py b/config/bidirectional_1.py index 8691357..35039d2 100644 --- a/config/bidirectional_1.py +++ b/config/bidirectional_1.py @@ -27,5 +27,4 @@ mlp_biases_init = Constant(0.01) batch_size = 20 batch_sort_size = 20 -valid_set = 'cuts/large_valid' max_splits = 100 diff --git a/config/bidirectional_tgtcls_1.py b/config/bidirectional_tgtcls_1.py index 4c9ed3e..88328e4 100644 --- a/config/bidirectional_tgtcls_1.py +++ b/config/bidirectional_tgtcls_1.py @@ -32,5 +32,4 @@ mlp_biases_init = Constant(0.01) batch_size = 20 batch_sort_size = 20 -valid_set = 'cuts/large_valid' max_splits = 100 diff --git a/config/dest_mlp_2_cs.py b/config/dest_mlp_2_cs.py index ca1ee39..d5e3092 100644 --- a/config/dest_mlp_2_cs.py +++ b/config/dest_mlp_2_cs.py @@ -23,5 +23,4 @@ learning_rate = 0.0001 momentum = 0.99 batch_size = 32 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/dest_mlp_2_cswdt.py b/config/dest_mlp_2_cswdt.py index 592378a..77f4bb3 100644 --- a/config/dest_mlp_2_cswdt.py +++ b/config/dest_mlp_2_cswdt.py @@ -27,5 +27,4 @@ learning_rate = 0.0001 momentum = 0.99 batch_size = 32 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/dest_mlp_2_noembed.py b/config/dest_mlp_2_noembed.py index d7582fe..94464fe 100644 --- a/config/dest_mlp_2_noembed.py +++ b/config/dest_mlp_2_noembed.py @@ -20,5 +20,4 @@ learning_rate = 0.0001 momentum = 0.99 batch_size = 32 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/dest_mlp_emb_only.py b/config/dest_mlp_emb_only.py index 76acdfa..921fc1d 100644 --- a/config/dest_mlp_emb_only.py +++ b/config/dest_mlp_emb_only.py @@ -27,5 +27,4 @@ learning_rate = 0.001 momentum = 0.9 batch_size = 100 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/dest_mlp_tgtcls_0_cs.py b/config/dest_mlp_tgtcls_0_cs.py index 684e653..2156dda 100644 --- a/config/dest_mlp_tgtcls_0_cs.py +++ b/config/dest_mlp_tgtcls_0_cs.py @@ -28,5 +28,4 @@ learning_rate = 0.0001 momentum = 0.99 batch_size = 32 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/dest_mlp_tgtcls_1_cs.py b/config/dest_mlp_tgtcls_1_cs.py index 1d28c1a..51c158b 100644 --- a/config/dest_mlp_tgtcls_1_cs.py +++ b/config/dest_mlp_tgtcls_1_cs.py @@ -28,5 +28,4 @@ learning_rate = 0.0001 momentum = 0.99 batch_size = 32 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/dest_mlp_tgtcls_1_cswdt.py b/config/dest_mlp_tgtcls_1_cswdt.py index 13abd90..5a66224 100644 --- a/config/dest_mlp_tgtcls_1_cswdt.py +++ b/config/dest_mlp_tgtcls_1_cswdt.py @@ -32,5 +32,4 @@ learning_rate = 0.0001 momentum = 0.99 batch_size = 32 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/dest_mlp_tgtcls_1_cswdtx.py b/config/dest_mlp_tgtcls_1_cswdtx.py index 1a39dfd..37bd52e 100644 --- a/config/dest_mlp_tgtcls_1_cswdtx.py +++ b/config/dest_mlp_tgtcls_1_cswdtx.py @@ -35,6 +35,3 @@ batch_size = 100 use_cuts_for_training = True max_splits = 1 - -valid_set = 'cuts/test_times_0' - diff --git a/config/dest_mlp_tgtcls_1_cswdtx_alexandre.py b/config/dest_mlp_tgtcls_1_cswdtx_alexandre.py index db69cd3..42adee7 100644 --- a/config/dest_mlp_tgtcls_1_cswdtx_alexandre.py +++ b/config/dest_mlp_tgtcls_1_cswdtx_alexandre.py @@ -34,5 +34,4 @@ step_rule = Momentum(learning_rate=0.01, momentum=0.9) batch_size = 200 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/dest_mlp_tgtcls_1_cswdtx_batchshuffle.py b/config/dest_mlp_tgtcls_1_cswdtx_batchshuffle.py index b816930..5795d89 100644 --- a/config/dest_mlp_tgtcls_1_cswdtx_batchshuffle.py +++ b/config/dest_mlp_tgtcls_1_cswdtx_batchshuffle.py @@ -36,5 +36,4 @@ batch_size = 200 shuffle_batch_size = 5000 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/dest_mlp_tgtcls_1_cswdtx_small.py b/config/dest_mlp_tgtcls_1_cswdtx_small.py index e6135ce..572f7df 100644 --- a/config/dest_mlp_tgtcls_1_cswdtx_small.py +++ b/config/dest_mlp_tgtcls_1_cswdtx_small.py @@ -34,5 +34,4 @@ step_rule = Momentum(learning_rate=0.01, momentum=0.9) batch_size = 200 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/dest_mlp_tgtcls_2_cswdtx_small.py b/config/dest_mlp_tgtcls_2_cswdtx_small.py index a4658b7..1e61293 100644 --- a/config/dest_mlp_tgtcls_2_cswdtx_small.py +++ b/config/dest_mlp_tgtcls_2_cswdtx_small.py @@ -34,5 +34,4 @@ step_rule = Momentum(learning_rate=0.01, momentum=0.9) batch_size = 200 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/joint_mlp_tgtcls_111_cswdtx.py b/config/joint_mlp_tgtcls_111_cswdtx.py index 83d3d11..89b22f6 100644 --- a/config/joint_mlp_tgtcls_111_cswdtx.py +++ b/config/joint_mlp_tgtcls_111_cswdtx.py @@ -50,5 +50,4 @@ learning_rate = 0.0001 momentum = 0.99 batch_size = 200 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/joint_mlp_tgtcls_111_cswdtx_bigger.py b/config/joint_mlp_tgtcls_111_cswdtx_bigger.py index 7ebe864..52b7b04 100644 --- a/config/joint_mlp_tgtcls_111_cswdtx_bigger.py +++ b/config/joint_mlp_tgtcls_111_cswdtx_bigger.py @@ -49,6 +49,4 @@ mlp_biases_init = Constant(0.01) # use adadelta, so no learning_rate or momentum batch_size = 200 -valid_set = 'cuts/test_times_0' - max_splits = 100 diff --git a/config/joint_mlp_tgtcls_111_cswdtx_bigger_dropout.py b/config/joint_mlp_tgtcls_111_cswdtx_bigger_dropout.py index e0448cc..26f7eba 100644 --- a/config/joint_mlp_tgtcls_111_cswdtx_bigger_dropout.py +++ b/config/joint_mlp_tgtcls_111_cswdtx_bigger_dropout.py @@ -55,6 +55,4 @@ dropout_inputs = VariableFilter(bricks=[Rectifier], name='output') # use adadelta, so no learning_rate or momentum batch_size = 200 -valid_set = 'cuts/test_times_0' - max_splits = 100 diff --git a/config/joint_mlp_tgtcls_111_cswdtx_noise_dout.py b/config/joint_mlp_tgtcls_111_cswdtx_noise_dout.py index fbc88a1..6dfca71 100644 --- a/config/joint_mlp_tgtcls_111_cswdtx_noise_dout.py +++ b/config/joint_mlp_tgtcls_111_cswdtx_noise_dout.py @@ -57,5 +57,4 @@ dropout_inputs = VariableFilter(bricks=[Rectifier], name='output') noise = 0.01 noise_inputs = VariableFilter(roles=[roles.PARAMETER]) -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/joint_mlp_tgtcls_1_cswdtx.py b/config/joint_mlp_tgtcls_1_cswdtx.py index f53e41c..0a9bd32 100644 --- a/config/joint_mlp_tgtcls_1_cswdtx.py +++ b/config/joint_mlp_tgtcls_1_cswdtx.py @@ -50,5 +50,4 @@ learning_rate = 0.0001 momentum = 0.99 batch_size = 200 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/joint_mlp_tgtcls_1_cswdtx_bigger.py b/config/joint_mlp_tgtcls_1_cswdtx_bigger.py index 72d3c83..282cb42 100644 --- a/config/joint_mlp_tgtcls_1_cswdtx_bigger.py +++ b/config/joint_mlp_tgtcls_1_cswdtx_bigger.py @@ -50,5 +50,4 @@ learning_rate = 0.0001 momentum = 0.99 batch_size = 200 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/memory_network_1.py b/config/memory_network_1.py index 813c9d2..70b0f3e 100644 --- a/config/memory_network_1.py +++ b/config/memory_network_1.py @@ -37,7 +37,6 @@ embed_weights_init = IsotropicGaussian(0.001) batch_size = 32 -valid_set = 'cuts/test_times_0' max_splits = 1 num_cuts = 1000 diff --git a/config/memory_network_2.py b/config/memory_network_2.py index 92a9675..cdd2bc1 100644 --- a/config/memory_network_2.py +++ b/config/memory_network_2.py @@ -48,7 +48,6 @@ noise_inputs = VariableFilter(roles=[roles.PARAMETER]) batch_size = 512 -valid_set = 'cuts/test_times_0' max_splits = 1 num_cuts = 1000 diff --git a/config/memory_network_3.py b/config/memory_network_3.py index 66cddbc..aa1fecb 100644 --- a/config/memory_network_3.py +++ b/config/memory_network_3.py @@ -48,7 +48,6 @@ noise_inputs = VariableFilter(roles=[roles.PARAMETER]) batch_size = 512 -valid_set = 'cuts/test_times_0' max_splits = 1 num_cuts = 1000 diff --git a/config/memory_network_adeb.py b/config/memory_network_adeb.py index a9d6fef..1f0a271 100644 --- a/config/memory_network_adeb.py +++ b/config/memory_network_adeb.py @@ -38,7 +38,6 @@ embed_weights_init = IsotropicGaussian(0.001) step_rule = Momentum(learning_rate=0.001, momentum=0.9) batch_size = 32 -valid_set = 'cuts/test_times_0' max_splits = 1 num_cuts = 1000 diff --git a/config/memory_network_bidir.py b/config/memory_network_bidir.py index dc0824c..beba242 100644 --- a/config/memory_network_bidir.py +++ b/config/memory_network_bidir.py @@ -47,7 +47,6 @@ normalize_representation = True batch_size = 32 batch_sort_size = 20 -valid_set = 'cuts/test_times_0' max_splits = 100 num_cuts = 1000 diff --git a/config/rnn_1.py b/config/rnn_1.py index 6e148c4..0947bc5 100644 --- a/config/rnn_1.py +++ b/config/rnn_1.py @@ -30,4 +30,3 @@ biases_init = Constant(0.001) batch_size = 10 batch_sort_size = 10 -valid_set = 'cuts/test_times_0' diff --git a/config/rnn_lag_tgtcls_1.py b/config/rnn_lag_tgtcls_1.py index 7a41b70..01827bf 100644 --- a/config/rnn_lag_tgtcls_1.py +++ b/config/rnn_lag_tgtcls_1.py @@ -46,4 +46,3 @@ noise_inputs = VariableFilter(roles=[roles.PARAMETER]) batch_size = 10 batch_sort_size = 10 -valid_set = 'cuts/test_times_0' diff --git a/config/rnn_tgtcls_1.py b/config/rnn_tgtcls_1.py index 3204559..ed9b654 100644 --- a/config/rnn_tgtcls_1.py +++ b/config/rnn_tgtcls_1.py @@ -34,4 +34,3 @@ biases_init = Constant(0.001) batch_size = 10 batch_sort_size = 10 -valid_set = 'cuts/test_times_0' diff --git a/config/time_mlp_1.py b/config/time_mlp_1.py index 4c2bffb..805d60b 100644 --- a/config/time_mlp_1.py +++ b/config/time_mlp_1.py @@ -23,5 +23,4 @@ learning_rate = 0.00001 momentum = 0.99 batch_size = 32 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/time_mlp_2_cswdtx.py b/config/time_mlp_2_cswdtx.py index 30dda87..d69da09 100644 --- a/config/time_mlp_2_cswdtx.py +++ b/config/time_mlp_2_cswdtx.py @@ -30,5 +30,4 @@ learning_rate = 0.00001 momentum = 0.99 batch_size = 32 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/config/time_mlp_tgtcls_2_cswdtx.py b/config/time_mlp_tgtcls_2_cswdtx.py index 809a808..c6be124 100644 --- a/config/time_mlp_tgtcls_2_cswdtx.py +++ b/config/time_mlp_tgtcls_2_cswdtx.py @@ -33,5 +33,4 @@ learning_rate = 0.0001 momentum = 0.99 batch_size = 32 -valid_set = 'cuts/test_times_0' max_splits = 100 diff --git a/data/__init__.py b/data/__init__.py index 2121033..9d01d2a 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -1,4 +1,5 @@ import os +import sys import h5py import numpy @@ -8,23 +9,46 @@ path = os.environ.get('TAXI_PATH', '/data/lisatmp3/auvolat/taxikaggle') Polyline = h5py.special_dtype(vlen=numpy.float32) -# `wc -l test.csv` - 1 # Minus 1 to ignore the header -test_size = 320 - -# `wc -l train.csv` - 1 -train_size = 1710670 - # `wc -l metaData_taxistandsID_name_GPSlocation.csv` stands_size = 64 # include 0 ("no origin_stands") # `cut -d, -f 5 train.csv test.csv | sort -u | wc -l` - 1 taxi_id_size = 448 -# `cut -d, -f 3 train.csv test.csv | sort -u | wc -l` - 2 -origin_call_size = 57125 # include 0 ("no origin_call") - -# As printed by csv_to_hdf5.py -origin_call_train_size = 57106 - train_gps_mean = numpy.array([41.1573, -8.61612], dtype=numpy.float32) train_gps_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=numpy.float32)) + +tvt = '--tvt' in sys.argv + +if tvt: + test_size = 19770 + valid_size = 19427 + train_size = 1671473 + + origin_call_size = 57106 + origin_call_train_size = 57106 + + valid_set = 'valid' + valid_ds = 'tvt.hdf5' + traintest_ds = 'tvt.hdf5' + +else: + # `wc -l test.csv` - 1 # Minus 1 to ignore the header + test_size = 320 + + # `wc -l train.csv` - 1 + train_size = 1710670 + + # `cut -d, -f 3 train.csv test.csv | sort -u | wc -l` - 2 + origin_call_size = 57125 # include 0 ("no origin_call") + + # As printed by csv_to_hdf5.py + origin_call_train_size = 57106 + + if '--largevalid' in sys.argv: + valid_set = 'cuts/large_valid' + else: + valid_set = 'cuts/test_times_0' + + valid_ds = 'valid.hdf5' + traintest_ds = 'data.hdf5' diff --git a/data/transformers.py b/data/transformers.py index c2eb97e..b3a8486 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -70,7 +70,9 @@ class TaxiGenerateSplits(Transformer): def __init__(self, data_stream, max_splits=-1): super(TaxiGenerateSplits, self).__init__(data_stream) - self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude', 'travel_time') + self.sources = data_stream.sources + if not data.tvt: + self.sources += ('destination_latitude', 'destination_longitude', 'travel_time') self.max_splits = max_splits self.data = None self.splits = [] @@ -100,12 +102,15 @@ class TaxiGenerateSplits(Transformer): r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX) r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX) - dlat = numpy.float32(self.data[self.id_latitude][-1]) - dlon = numpy.float32(self.data[self.id_longitude][-1]) - ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1)) - - return tuple(r + [dlat, dlon, ttime]) + r = tuple(r) + if data.tvt: + return r + else: + dlat = numpy.float32(self.data[self.id_latitude][-1]) + dlon = numpy.float32(self.data[self.id_longitude][-1]) + ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1)) + return r + (dlat, dlon, ttime) class _taxi_add_first_last_len_helper(object): def __init__(self, k, id_latitude, id_longitude): diff --git a/model/bidirectional.py b/model/bidirectional.py index 483ea54..ba6440a 100644 --- a/model/bidirectional.py +++ b/model/bidirectional.py @@ -98,17 +98,18 @@ class Stream(object): self.config = config def train(self, req_vars): - valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) - valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] - - stream = TaxiDataset('train') + stream = TaxiDataset('train', data.traintest_ds) if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training: stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme()) else: stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + if not data.tvt: + valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) + valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] + stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) if hasattr(self.config, 'shuffle_batch_size'): @@ -128,7 +129,7 @@ class Stream(object): return stream def valid(self, req_vars): - stream = TaxiStream(self.config.valid_set, 'valid.hdf5') + stream = TaxiStream(data.valid_set, data.valid_ds) stream = transformers.taxi_add_datetime(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) @@ -139,7 +140,7 @@ class Stream(object): return stream def test(self, req_vars): - stream = TaxiStream('test') + stream = TaxiStream('test', data.traintest_ds) stream = transformers.taxi_add_datetime(stream) stream = transformers.taxi_remove_test_only_clients(stream) diff --git a/model/memory_network.py b/model/memory_network.py index 14d1e07..e7ba51c 100644 --- a/model/memory_network.py +++ b/model/memory_network.py @@ -53,20 +53,20 @@ class StreamBase(object): @property def valid_dataset(self): - return TaxiDataset(self.config.valid_set, 'valid.hdf5') + return TaxiDataset(data.valid_set, data.valid_ds) @property def valid_trips_ids(self): - valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) + valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) return valid.get_data(None, slice(0, valid.num_examples))[0] @property def train_dataset(self): - return TaxiDataset('train') + return TaxiDataset('train', data.traintest_ds) @property def test_dataset(self): - return TaxiDataset('test') + return TaxiDataset('test', data.traintest_ds) class StreamSimple(StreamBase): @@ -96,7 +96,8 @@ class StreamSimple(StreamBase): prefix_stream = DataStream(self.train_dataset, iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples)) - prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) + if not data.tvt: + prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream) prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits) diff --git a/model/mlp.py b/model/mlp.py index d24b2cc..0336be1 100644 --- a/model/mlp.py +++ b/model/mlp.py @@ -63,17 +63,18 @@ class Stream(object): self.config = config def train(self, req_vars): - valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) - valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] - - stream = TaxiDataset('train') + stream = TaxiDataset('train', data.traintest_ds) if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training: stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme()) else: stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + if not data.tvt: + valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) + valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] + stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) if hasattr(self.config, 'shuffle_batch_size'): @@ -92,7 +93,7 @@ class Stream(object): return stream def valid(self, req_vars): - stream = TaxiStream(self.config.valid_set, 'valid.hdf5') + stream = TaxiStream(data.valid_set, data.valid_ds) stream = transformers.taxi_add_datetime(stream) stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) @@ -100,7 +101,7 @@ class Stream(object): return Batch(stream, iteration_scheme=ConstantScheme(1000)) def test(self, req_vars): - stream = TaxiStream('test') + stream = TaxiStream('test', data.traintest_ds) stream = transformers.taxi_add_datetime(stream) stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) diff --git a/model/rnn.py b/model/rnn.py index 70d054a..7bdba3b 100644 --- a/model/rnn.py +++ b/model/rnn.py @@ -146,15 +146,15 @@ class Stream(object): self.config = config def train(self, req_vars): - valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) - valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] - - stream = TaxiDataset('train') + stream = TaxiDataset('train', data.traintest_ds) stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + if not data.tvt: + valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) + valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] + stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + stream = transformers.TaxiExcludeEmptyTrips(stream) stream = transformers.taxi_add_datetime(stream) - stream = transformers.add_destination(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size) @@ -163,7 +163,7 @@ class Stream(object): return stream def valid(self, req_vars): - stream = TaxiStream(self.config.valid_set, 'valid.hdf5') + stream = TaxiStream(data.valid_set, data.valid_ds) stream = transformers.taxi_add_datetime(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) @@ -173,7 +173,7 @@ class Stream(object): return stream def test(self, req_vars): - stream = TaxiStream('test') + stream = TaxiStream('test', data.traintest_ds) stream = transformers.taxi_add_datetime(stream) stream = transformers.taxi_remove_test_only_clients(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) @@ -37,10 +37,10 @@ from ext_test import RunOnTest logger = logging.getLogger(__name__) if __name__ == "__main__": - if len(sys.argv) != 2: - print >> sys.stderr, 'Usage: %s config' % sys.argv[0] + if len(sys.argv) < 2 or len(sys.argv) > 3: + print >> sys.stderr, 'Usage: %s [--tvt | --largevalid] config' % sys.argv[0] sys.exit(1) - model_name = sys.argv[1] + model_name = sys.argv[-1] config = importlib.import_module('.%s' % model_name, 'config') logger.info('# Configuration: %s' % config.__name__) |