diff options
Diffstat (limited to 'model.py')
-rw-r--r-- | model.py | 22 |
1 files changed, 14 insertions, 8 deletions
@@ -29,7 +29,7 @@ n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday n_dom = 31 n_hour = 24 -n_clients = 57106 +n_clients = 57105 n_stands = 63 n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory @@ -45,9 +45,9 @@ def main(): # The input and the targets x_firstk = tensor.matrix('first_k') x_lastk = tensor.matrix('last_k') - x_client = tensor.lmatrix('origin_call') - x_stand = tensor.lmatrix('origin_stand') - y = tensor.vector('destination') + x_client = tensor.lvector('origin_call') + x_stand = tensor.lvector('origin_stand') + y = tensor.matrix('destination') # Define the model client_embed_table = LookupTable(length=n_clients+1, dim=dim_embed, name='client_lookup') @@ -60,12 +60,15 @@ def main(): client_embed = client_embed_table.apply(x_client).flatten(ndim=2) stand_embed = stand_embed_table.apply(x_stand).flatten(ndim=2) - inputs = tensor.concatenate([x_firstk, x_lastk, client_embed, stand_embed], axis=1) + inputs = tensor.concatenate([x_firstk, x_lastk, + client_embed, stand_embed], + axis=1) hidden = hidden_layer.apply(inputs) outputs = output_layer.apply(hidden) # Calculate the cost cost = (outputs - y).norm(2, axis=1).mean() + cost.name = 'cost' # Initialization client_embed_table.weights_init = IsotropicGaussian(0.001) @@ -83,12 +86,14 @@ def main(): train = DataStream(train) train = transformers.add_extremities(train, n_begin_end_pts) train = transformers.add_destination(train) + train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination')) train_stream = Batch(train, iteration_scheme=ConstantScheme(batch_size)) valid = data.valid_data valid = DataStream(valid) valid = transformers.add_extremities(valid, n_begin_end_pts) valid = transformers.add_destination(valid) + valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination')) valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size)) @@ -103,9 +108,10 @@ def main(): extensions=[DataStreamMonitoring([cost], valid_stream, prefix='valid', every_n_batches=100), - Printing(every_n_batches=100), - Dump('ngram_blocks_model', every_n_batches=100), - LoadFromDump('ngram_blocks_model')] + Printing(every_n_batches=100), + # Dump('taxi_model', every_n_batches=100), + # LoadFromDump('taxi_model'), + ] main_loop = MainLoop( model=Model([cost]), |