diff options
Diffstat (limited to 'model.py')
-rw-r--r-- | model.py | 40 |
1 files changed, 31 insertions, 9 deletions
@@ -2,6 +2,8 @@ import logging import os from argparse import ArgumentParser +import csv + import numpy import theano @@ -31,6 +33,7 @@ from blocks.extensions.monitoring import DataStreamMonitoring import data import transformers import hdist +import apply_model n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday n_dom = 31 @@ -43,7 +46,9 @@ n_begin_end_pts = 5 # how many points we consider at the beginning and end o n_end_pts = 5 dim_embed = 50 -dim_hidden = 200 +dim_input = n_begin_end_pts * 2 * 2 + dim_embed + dim_embed +dim_hidden = [200] +dim_output = 2 learning_rate = 0.002 momentum = 0.9 @@ -68,16 +73,15 @@ def main(): # Define the model client_embed_table = LookupTable(length=n_clients+1, dim=dim_embed, name='client_lookup') stand_embed_table = LookupTable(length=n_stands+1, dim=dim_embed, name='stand_lookup') - hidden_layer = MLP(activations=[Rectifier()], - dims=[n_begin_end_pts * 2 * 2 + dim_embed + dim_embed, dim_hidden]) - output_layer = Linear(input_dim=dim_hidden, output_dim=2) + hidden_layer = MLP(activations=[Rectifier() for _ in dim_hidden], + dims=[dim_input] + dim_hidden) + output_layer = Linear(input_dim=dim_hidden[-1], output_dim=dim_output) # Create the Theano variables client_embed = client_embed_table.apply(x_client).flatten(ndim=2) stand_embed = stand_embed_table.apply(x_stand).flatten(ndim=2) - inputs = tensor.concatenate([x_firstk, x_lastk, - client_embed.zeros_like(), stand_embed.zeros_like()], + inputs = tensor.concatenate([x_firstk, x_lastk, client_embed, stand_embed], axis=1) # inputs = theano.printing.Print("inputs")(inputs) hidden = hidden_layer.apply(inputs) @@ -86,6 +90,7 @@ def main(): # Normalize & Center outputs = data.data_std * outputs + data.porto_center + outputs.name = 'outputs' # Calculate the cost cost = (outputs - y).norm(2, axis=1).mean() @@ -121,7 +126,7 @@ def main(): valid = transformers.add_last_k(n_begin_end_pts, valid) valid = transformers.concat_destination_xy(valid) valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination')) - valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size)) + valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000)) # Training @@ -135,8 +140,8 @@ def main(): extensions=[DataStreamMonitoring([cost, hcost], valid_stream, prefix='valid', - every_n_batches=1000), - Printing(every_n_batches=1000), + every_n_batches=1), + Printing(every_n_batches=1), # Dump('taxi_model', every_n_batches=100), # LoadFromDump('taxi_model'), ] @@ -148,6 +153,23 @@ def main(): extensions=extensions) main_loop.run() + # Produce an output on the test data + test = data.test_data + test = DataStream(test) + test = transformers.add_first_k(n_begin_end_pts, test) + test = transformers.add_last_k(n_begin_end_pts, test) + test = transformers.Select(test, ('trip_id', 'origin_stand', 'origin_call', 'first_k', 'last_k')) + test_stream = Batch(test, iteration_scheme=ConstantScheme(1000)) + + outfile = open("test-output.csv", "w") + outcsv = csv.writer(outfile) + for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): + dest = out['outputs'] + for i, trip in enumerate(out['trip_id']): + outcsv.writerow([trip, repr(dest[i, 1]), repr(dest[i, 0])]) + outfile.close() + + if __name__ == "__main__": logging.basicConfig(level=logging.INFO) main() |