author     Alex Auvolat <alex.auvolat@ens.fr>  2015-05-04 13:15:23 -0400
committer  Alex Auvolat <alex.auvolat@ens.fr>  2015-05-04 13:15:23 -0400
commit     de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055 (patch)
tree       09c09a12861f0f6826cd33e3b77eba9a07076c49 /model.py
parent     43e106e6630030dd34813295fe1d07bb86025402 (diff)
download   taxi-de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055.tar.gz
           taxi-de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055.zip
Add TaxiGenerateSplits
Diffstat (limited to 'model.py')
-rw-r--r--  model.py  28
1 file changed, 15 insertions, 13 deletions
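This commit splits the old setup_stream() into separate setup_train_stream() and setup_valid_stream() functions, filters the held-out validation trips out of the training data, and inserts a TaxiGenerateSplits transformer into the training pipeline. transformers.py is not part of this diff, so the sketch below is only a guess at what such a Fuel transformer could look like (random trip prefixes, with the trip's end point kept as the destination target); the class body, attribute names, and source layout are illustrative assumptions, not the repository's code.

import random

from fuel.transformers import Transformer


class TaxiGenerateSplits(Transformer):
    """Illustrative sketch only -- not the implementation from transformers.py.

    Assumes each example from the wrapped stream carries full 'latitude' and
    'longitude' arrays for one trip; emits a randomly truncated prefix and
    appends the trip's end point as the destination targets.
    """

    def __init__(self, data_stream, **kwargs):
        super(TaxiGenerateSplits, self).__init__(data_stream, **kwargs)
        self.sources = data_stream.sources + ('destination_latitude',
                                              'destination_longitude')
        self.id_lat = data_stream.sources.index('latitude')
        self.id_lon = data_stream.sources.index('longitude')

    def get_data(self, request=None):
        data = list(next(self.child_epoch_iterator))
        lat, lon = data[self.id_lat], data[self.id_lon]
        # The last GPS point of the trip becomes the prediction target.
        dest_lat, dest_lon = lat[-1], lon[-1]
        # Truncate the trajectory at a random cut point (keep at least one point).
        cut = random.randrange(1, len(lat)) if len(lat) > 1 else 1
        data[self.id_lat] = lat[:cut]
        data[self.id_lon] = lon[:cut]
        return tuple(data) + (dest_lat, dest_lon)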
diff --git a/model.py b/model.py
index aff9fd7..753fb01 100644
--- a/model.py
+++ b/model.py
@@ -44,35 +44,35 @@ if __name__ == "__main__":
         sys.exit(1)
     config = importlib.import_module(sys.argv[1])
 
-def setup_stream():
+
+def setup_train_stream():
     # Load the training and test data
     train = H5PYDataset(data.H5DATA_PATH,
                         which_set='train',
-                        subset=slice(0, data.dataset_size - config.n_valid),
+                        subset=slice(0, data.dataset_size),
                         load_in_memory=True)
     train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid))
+    train = transformers.filter_out_trips(data.valid_trips, train)
+    train = transformers.TaxiGenerateSplits(train)
     train = transformers.add_first_k(config.n_begin_end_pts, train)
-    train = transformers.add_random_k(config.n_begin_end_pts, train)
-    train = transformers.add_destination(train)
+    train = transformers.add_last_k(config.n_begin_end_pts, train)
     train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k_latitude',
                                         'last_k_latitude', 'first_k_longitude', 'last_k_longitude',
                                         'destination_latitude', 'destination_longitude'))
     train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
 
-    valid = H5PYDataset(data.H5DATA_PATH,
-                        which_set='train',
-                        subset=slice(data.dataset_size - config.n_valid, data.dataset_size),
-                        load_in_memory=True)
-    valid = DataStream(valid, iteration_scheme=SequentialExampleScheme(config.n_valid))
+    return train_stream
+
+def setup_valid_stream():
+    valid = DataStream(data.valid_data)
     valid = transformers.add_first_k(config.n_begin_end_pts, valid)
-    valid = transformers.add_random_k(config.n_begin_end_pts, valid)
-    valid = transformers.add_destination(valid)
+    valid = transformers.add_last_k(config.n_begin_end_pts, valid)
     valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k_latitude',
                                         'last_k_latitude', 'first_k_longitude', 'last_k_longitude',
                                         'destination_latitude', 'destination_longitude'))
     valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000))
 
-    return (train_stream, valid_stream)
+    return valid_stream
 
 def setup_test_stream():
     test = data.test_data
@@ -86,6 +86,7 @@ def setup_test_stream():
 
     return test_stream
 
+
 def main():
     # The input and the targets
     x_firstk_latitude = (tensor.matrix('first_k_latitude') - data.porto_center[0]) / data.data_std[0]
@@ -146,7 +147,8 @@ def main():
     stand_embed_table.initialize()
     mlp.initialize()
 
-    (train_stream, valid_stream) = setup_stream()
+    train_stream = setup_train_stream()
+    valid_stream = setup_valid_stream()
 
     # Training
     cg = ComputationGraph(cost)
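With the split in place, main() builds each stream on its own (last hunk above). A quick way to check what the rebuilt pipelines emit is to pull a single batch from each; the snippet below is a hypothetical sanity check using Fuel's standard epoch-iterator API, not code from this repository.

# Hypothetical sanity check (run after importing this module's setup functions).
train_stream = setup_train_stream()
valid_stream = setup_valid_stream()

# Batches arrive as dicts keyed by the Select'ed source names; the training
# stream yields config.batch_size examples per batch, the validation stream 1000.
train_batch = next(train_stream.get_epoch_iterator(as_dict=True))
print(sorted(train_batch.keys()))

valid_batch = next(valid_stream.get_epoch_iterator(as_dict=True))
print(len(valid_batch['destination_latitude']))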