From b07bf7612b07a08bd1298b87347889a71d390012 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 29 Apr 2015 16:59:30 -0400 Subject: Fixes. --- model.py | 24 ++++++++++++++++++------ transformers.py | 5 +++-- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/model.py b/model.py index 5caad8b..c5c75d3 100644 --- a/model.py +++ b/model.py @@ -65,7 +65,7 @@ def setup_stream(): load_in_memory=True) valid = DataStream(valid, iteration_scheme=SequentialExampleScheme(config.n_valid)) valid = transformers.add_first_k(config.n_begin_end_pts, valid) - valid = transformers.add_last_k(config.n_begin_end_pts, valid) + valid = transformers.add_random_k(config.n_begin_end_pts, valid) valid = transformers.add_destination(valid) valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k_latitude', 'last_k_latitude', 'first_k_longitude', 'last_k_longitude', @@ -88,23 +88,34 @@ def main(): y = tensor.concatenate((tensor.vector('destination_latitude')[:, None], tensor.vector('destination_longitude')[:, None]), axis=1) + # x_firstk_latitude = theano.printing.Print("x_firstk_latitude")(x_firstk_latitude) + # x_firstk_longitude = theano.printing.Print("x_firstk_longitude")(x_firstk_longitude) + # x_lastk_latitude = theano.printing.Print("x_lastk_latitude")(x_lastk_latitude) + # x_lastk_longitude = theano.printing.Print("x_lastk_longitude")(x_lastk_longitude) + # Define the model client_embed_table = LookupTable(length=config.n_clients+1, dim=config.dim_embed, name='client_lookup') stand_embed_table = LookupTable(length=config.n_stands+1, dim=config.dim_embed, name='stand_lookup') - mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [None], + mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()], dims=[config.dim_input] + config.dim_hidden + [config.dim_output]) # Create the Theano variables - client_embed = client_embed_table.apply(x_client).flatten(ndim=2) - stand_embed = stand_embed_table.apply(x_stand).flatten(ndim=2) + client_embed = client_embed_table.apply(x_client) + stand_embed = stand_embed_table.apply(x_stand) inputs = tensor.concatenate([x_firstk_latitude, x_firstk_longitude, - x_lastk_latitude, x_lastk_longitude, client_embed, stand_embed], + x_lastk_latitude, x_lastk_longitude, + client_embed, stand_embed], axis=1) # inputs = theano.printing.Print("inputs")(inputs) outputs = mlp.apply(inputs) # Normalize & Center + # outputs = theano.printing.Print("normal_outputs")(outputs) outputs = data.data_std * outputs + data.porto_center + + # outputs = theano.printing.Print("outputs")(outputs) + # y = theano.printing.Print("y")(y) + outputs.name = 'outputs' # Calculate the cost @@ -127,11 +138,12 @@ def main(): # Training cg = ComputationGraph(cost) + params = cg.parameters # VariableFilter(bricks=[Linear])(cg.parameters) algorithm = GradientDescent( cost=cost, # step_rule=AdaDelta(decay_rate=0.5), step_rule=Momentum(learning_rate=config.learning_rate, momentum=config.momentum), - params=cg.parameters) + params=params) extensions=[DataStreamMonitoring([cost, hcost], valid_stream, prefix='valid', diff --git a/transformers.py b/transformers.py index 3473fd5..c5b8d87 100644 --- a/transformers.py +++ b/transformers.py @@ -2,10 +2,11 @@ from fuel.transformers import Transformer, Filter, Mapping import numpy import theano import random +import data def at_least_k(k, v, pad_at_begin, is_longitude): if len(v) == 0: - v = numpy.array([41.1573 if is_longitude else -8.61612], dtype=theano.config.floatX) + v = numpy.array([data.porto_center[1 if is_longitude else 0]], dtype=theano.config.floatX) if len(v) < k: if pad_at_begin: v = numpy.concatenate((numpy.full((k - len(v),), v[0]), v)) @@ -40,7 +41,7 @@ def add_random_k(k, stream): id_longitude = stream.sources.index('longitude') def random_k(x): lat = at_least_k(k, x[id_latitude], True, False) - lon = at_least_k(k, x[id_latitude], True, True) + lon = at_least_k(k, x[id_longitude], True, True) loc = random.randrange(len(lat)-k+1) return (numpy.array(lat[loc:loc+k], dtype=theano.config.floatX), numpy.array(lon[loc:loc+k], dtype=theano.config.floatX)) -- cgit v1.2.3