aboutsummaryrefslogtreecommitdiff
path: root/model/simple_mlp.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/simple_mlp.py')
-rw-r--r--model/simple_mlp.py30
1 files changed, 16 insertions, 14 deletions
diff --git a/model/simple_mlp.py b/model/simple_mlp.py
index 896ccd3..fc065f7 100644
--- a/model/simple_mlp.py
+++ b/model/simple_mlp.py
@@ -17,25 +17,27 @@ class Model(object):
x_lastk_latitude = (tensor.matrix('last_k_latitude') - data.porto_center[0]) / data.data_std[0]
x_lastk_longitude = (tensor.matrix('last_k_longitude') - data.porto_center[1]) / data.data_std[1]
- x_client = tensor.lvector('origin_call')
- x_stand = tensor.lvector('origin_stand')
+ input_list = [x_firstk_latitude, x_firstk_longitude, x_lastk_latitude, x_lastk_longitude]
+ embed_tables = []
+
+ self.require_inputs = ['first_k_latitude', 'first_k_longitude', 'last_k_latitude', 'last_k_longitude']
+
+ for (varname, num, dim) in config.dim_embeddings:
+ self.require_inputs.append(varname)
+ vardata = tensor.lvector(varname)
+ tbl = LookupTable(length=num, dim=dim, name='%s_lookup'%varname)
+ embed_tables.append(tbl)
+ input_list.append(tbl.apply(vardata))
y = tensor.concatenate((tensor.vector('destination_latitude')[:, None],
tensor.vector('destination_longitude')[:, None]), axis=1)
# Define the model
- client_embed_table = LookupTable(length=data.n_train_clients+1, dim=config.dim_embed, name='client_lookup')
- stand_embed_table = LookupTable(length=data.n_stands+1, dim=config.dim_embed, name='stand_lookup')
mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()],
dims=[config.dim_input] + config.dim_hidden + [config.dim_output])
# Create the Theano variables
- client_embed = client_embed_table.apply(x_client)
- stand_embed = stand_embed_table.apply(x_stand)
- inputs = tensor.concatenate([x_firstk_latitude, x_firstk_longitude,
- x_lastk_latitude, x_lastk_longitude,
- client_embed, stand_embed],
- axis=1)
+ inputs = tensor.concatenate(input_list, axis=1)
# inputs = theano.printing.Print("inputs")(inputs)
outputs = mlp.apply(inputs)
@@ -55,13 +57,13 @@ class Model(object):
hcost.name = 'hcost'
# Initialization
- client_embed_table.weights_init = IsotropicGaussian(0.001)
- stand_embed_table.weights_init = IsotropicGaussian(0.001)
+ for tbl in embed_tables:
+ tbl.weights_init = IsotropicGaussian(0.001)
mlp.weights_init = IsotropicGaussian(0.01)
mlp.biases_init = Constant(0.001)
- client_embed_table.initialize()
- stand_embed_table.initialize()
+ for tbl in embed_tables:
+ tbl.initialize()
mlp.initialize()
self.cost = cost