diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 13:15:23 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 13:15:23 -0400 |
commit | de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055 (patch) | |
tree | 09c09a12861f0f6826cd33e3b77eba9a07076c49 /model.py | |
parent | 43e106e6630030dd34813295fe1d07bb86025402 (diff) | |
download | taxi-de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055.tar.gz taxi-de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055.zip |
Add TaxiGenerateSplits
Diffstat (limited to 'model.py')
-rw-r--r-- | model.py | 28 |
1 files changed, 15 insertions, 13 deletions
@@ -44,35 +44,35 @@ if __name__ == "__main__": sys.exit(1) config = importlib.import_module(sys.argv[1]) -def setup_stream(): + +def setup_train_stream(): # Load the training and test data train = H5PYDataset(data.H5DATA_PATH, which_set='train', - subset=slice(0, data.dataset_size - config.n_valid), + subset=slice(0, data.dataset_size), load_in_memory=True) train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid)) + train = transformers.filter_out_trips(data.valid_trips, train) + train = transformers.TaxiGenerateSplits(train) train = transformers.add_first_k(config.n_begin_end_pts, train) - train = transformers.add_random_k(config.n_begin_end_pts, train) - train = transformers.add_destination(train) + train = transformers.add_last_k(config.n_begin_end_pts, train) train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k_latitude', 'last_k_latitude', 'first_k_longitude', 'last_k_longitude', 'destination_latitude', 'destination_longitude')) train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size)) - valid = H5PYDataset(data.H5DATA_PATH, - which_set='train', - subset=slice(data.dataset_size - config.n_valid, data.dataset_size), - load_in_memory=True) - valid = DataStream(valid, iteration_scheme=SequentialExampleScheme(config.n_valid)) + return train_stream + +def setup_valid_stream(): + valid = DataStream(data.valid_data) valid = transformers.add_first_k(config.n_begin_end_pts, valid) - valid = transformers.add_random_k(config.n_begin_end_pts, valid) - valid = transformers.add_destination(valid) + valid = transformers.add_last_k(config.n_begin_end_pts, valid) valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k_latitude', 'last_k_latitude', 'first_k_longitude', 'last_k_longitude', 'destination_latitude', 'destination_longitude')) valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000)) - return (train_stream, valid_stream) + return valid_stream def setup_test_stream(): test = data.test_data @@ -86,6 +86,7 @@ def setup_test_stream(): return test_stream + def main(): # The input and the targets x_firstk_latitude = (tensor.matrix('first_k_latitude') - data.porto_center[0]) / data.data_std[0] @@ -146,7 +147,8 @@ def main(): stand_embed_table.initialize() mlp.initialize() - (train_stream, valid_stream) = setup_stream() + train_stream = setup_train_stream() + valid_stream = setup_valid_stream() # Training cg = ComputationGraph(cost) |