diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 19:09:37 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 19:09:37 -0400 |
commit | fe704521e6cb4d7f32414b55044b6e2240524bf5 (patch) | |
tree | c6b02de77fc274a02f135ad6a71b586d0cb591a0 /model.py | |
parent | e28390de61b23882f6e1069d565a2137825c2662 (diff) | |
download | taxi-fe704521e6cb4d7f32414b55044b6e2240524bf5.tar.gz taxi-fe704521e6cb4d7f32414b55044b6e2240524bf5.zip |
Fix CSV import (partially)
Diffstat (limited to 'model.py')
-rw-r--r-- | model.py | 36 |
1 files changed, 21 insertions, 15 deletions
@@ -29,7 +29,7 @@ from fuel.schemes import ConstantScheme, SequentialExampleScheme from blocks.algorithms import GradientDescent, Scale, AdaDelta, Momentum from blocks.graph import ComputationGraph from blocks.main_loop import MainLoop -from blocks.extensions import Printing +from blocks.extensions import Printing, FinishAfter from blocks.extensions.saveload import Dump, LoadFromDump, Checkpoint from blocks.extensions.monitoring import DataStreamMonitoring @@ -48,9 +48,9 @@ def setup_stream(): # Load the training and test data train = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5', which_set='train', - subset=slice(0, config.train_size - config.n_valid), + subset=slice(0, data.dataset_size - config.n_valid), load_in_memory=True) - train = DataStream(train, iteration_scheme=SequentialExampleScheme(config.train_size - config.n_valid)) + train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid)) train = transformers.add_first_k(config.n_begin_end_pts, train) train = transformers.add_random_k(config.n_begin_end_pts, train) train = transformers.add_destination(train) @@ -61,7 +61,7 @@ def setup_stream(): valid = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5', which_set='train', - subset=slice(config.train_size - config.n_valid, config.train_size), + subset=slice(data.dataset_size - config.n_valid, data.dataset_size), load_in_memory=True) valid = DataStream(valid, iteration_scheme=SequentialExampleScheme(config.n_valid)) valid = transformers.add_first_k(config.n_begin_end_pts, valid) @@ -74,6 +74,18 @@ def setup_stream(): return (train_stream, valid_stream) +def setup_test_stream(): + test = data.test_data + + test = DataStream(test) + test = transformers.add_first_k(config.n_begin_end_pts, test) + test = transformers.add_last_k(config.n_begin_end_pts, test) + test = transformers.Select(test, ('trip_id', 'origin_stand', 'origin_call', 'first_k_latitude', + 'last_k_latitude', 'first_k_longitude', 'last_k_longitude')) + test_stream = Batch(test, iteration_scheme=ConstantScheme(1000)) + + return test_stream + def main(): # The input and the targets x_firstk_latitude = (tensor.matrix('first_k_latitude') - data.porto_center[0]) / data.data_std[0] @@ -94,8 +106,8 @@ def main(): # x_lastk_longitude = theano.printing.Print("x_lastk_longitude")(x_lastk_longitude) # Define the model - client_embed_table = LookupTable(length=config.n_clients+1, dim=config.dim_embed, name='client_lookup') - stand_embed_table = LookupTable(length=config.n_stands+1, dim=config.dim_embed, name='stand_lookup') + client_embed_table = LookupTable(length=data.n_clients+1, dim=config.dim_embed, name='client_lookup') + stand_embed_table = LookupTable(length=data.n_stands+1, dim=config.dim_embed, name='stand_lookup') mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()], dims=[config.dim_input] + config.dim_hidden + [config.dim_output]) @@ -152,6 +164,7 @@ def main(): # Checkpoint('model.pkl', every_n_batches=100), Dump('taxi_model', every_n_batches=1000), LoadFromDump('taxi_model'), + FinishAfter(after_epoch=1) ] main_loop = MainLoop( @@ -163,13 +176,7 @@ def main(): main_loop.profile.report() # Produce an output on the test data - ''' - test = data.test_data - test = DataStream(test) - test = transformers.add_first_k(conifg.n_begin_end_pts, test) - test = transformers.add_last_k(config.n_begin_end_pts, test) - test = transformers.Select(test, ('trip_id', 'origin_stand', 'origin_call', 'first_k', 'last_k')) - test_stream = Batch(test, iteration_scheme=ConstantScheme(1000)) + test_stream = setup_test_stream() outfile = open("test-output.csv", "w") outcsv = csv.writer(outfile) @@ -177,9 +184,8 @@ def main(): for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): dest = out['outputs'] for i, trip in enumerate(out['trip_id']): - outcsv.writerow([trip, repr(dest[i, 1]), repr(dest[i, 0])]) + outcsv.writerow([trip, repr(dest[i, 0]), repr(dest[i, 1])]) outfile.close() - ''' if __name__ == "__main__": |