aboutsummaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'model.py')
-rw-r--r--model.py36
1 files changed, 21 insertions, 15 deletions
diff --git a/model.py b/model.py
index 6a8be9d..065db7e 100644
--- a/model.py
+++ b/model.py
@@ -29,7 +29,7 @@ from fuel.schemes import ConstantScheme, SequentialExampleScheme
from blocks.algorithms import GradientDescent, Scale, AdaDelta, Momentum
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
-from blocks.extensions import Printing
+from blocks.extensions import Printing, FinishAfter
from blocks.extensions.saveload import Dump, LoadFromDump, Checkpoint
from blocks.extensions.monitoring import DataStreamMonitoring
@@ -48,9 +48,9 @@ def setup_stream():
# Load the training and test data
train = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5',
which_set='train',
- subset=slice(0, config.train_size - config.n_valid),
+ subset=slice(0, data.dataset_size - config.n_valid),
load_in_memory=True)
- train = DataStream(train, iteration_scheme=SequentialExampleScheme(config.train_size - config.n_valid))
+ train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid))
train = transformers.add_first_k(config.n_begin_end_pts, train)
train = transformers.add_random_k(config.n_begin_end_pts, train)
train = transformers.add_destination(train)
@@ -61,7 +61,7 @@ def setup_stream():
valid = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5',
which_set='train',
- subset=slice(config.train_size - config.n_valid, config.train_size),
+ subset=slice(data.dataset_size - config.n_valid, data.dataset_size),
load_in_memory=True)
valid = DataStream(valid, iteration_scheme=SequentialExampleScheme(config.n_valid))
valid = transformers.add_first_k(config.n_begin_end_pts, valid)
@@ -74,6 +74,18 @@ def setup_stream():
return (train_stream, valid_stream)
+def setup_test_stream():
+ test = data.test_data
+
+ test = DataStream(test)
+ test = transformers.add_first_k(config.n_begin_end_pts, test)
+ test = transformers.add_last_k(config.n_begin_end_pts, test)
+ test = transformers.Select(test, ('trip_id', 'origin_stand', 'origin_call', 'first_k_latitude',
+ 'last_k_latitude', 'first_k_longitude', 'last_k_longitude'))
+ test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))
+
+ return test_stream
+
def main():
# The input and the targets
x_firstk_latitude = (tensor.matrix('first_k_latitude') - data.porto_center[0]) / data.data_std[0]
@@ -94,8 +106,8 @@ def main():
# x_lastk_longitude = theano.printing.Print("x_lastk_longitude")(x_lastk_longitude)
# Define the model
- client_embed_table = LookupTable(length=config.n_clients+1, dim=config.dim_embed, name='client_lookup')
- stand_embed_table = LookupTable(length=config.n_stands+1, dim=config.dim_embed, name='stand_lookup')
+ client_embed_table = LookupTable(length=data.n_clients+1, dim=config.dim_embed, name='client_lookup')
+ stand_embed_table = LookupTable(length=data.n_stands+1, dim=config.dim_embed, name='stand_lookup')
mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()],
dims=[config.dim_input] + config.dim_hidden + [config.dim_output])
@@ -152,6 +164,7 @@ def main():
# Checkpoint('model.pkl', every_n_batches=100),
Dump('taxi_model', every_n_batches=1000),
LoadFromDump('taxi_model'),
+ FinishAfter(after_epoch=1)
]
main_loop = MainLoop(
@@ -163,13 +176,7 @@ def main():
main_loop.profile.report()
# Produce an output on the test data
- '''
- test = data.test_data
- test = DataStream(test)
- test = transformers.add_first_k(conifg.n_begin_end_pts, test)
- test = transformers.add_last_k(config.n_begin_end_pts, test)
- test = transformers.Select(test, ('trip_id', 'origin_stand', 'origin_call', 'first_k', 'last_k'))
- test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))
+ test_stream = setup_test_stream()
outfile = open("test-output.csv", "w")
outcsv = csv.writer(outfile)
@@ -177,9 +184,8 @@ def main():
for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']):
dest = out['outputs']
for i, trip in enumerate(out['trip_id']):
- outcsv.writerow([trip, repr(dest[i, 1]), repr(dest[i, 0])])
+ outcsv.writerow([trip, repr(dest[i, 0]), repr(dest[i, 1])])
outfile.close()
- '''
if __name__ == "__main__":