aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-x[-rw-r--r--]train.py51
1 files changed, 21 insertions, 30 deletions
diff --git a/train.py b/train.py
index 4cbd526..9e915ed 100644..100755
--- a/train.py
+++ b/train.py
@@ -1,36 +1,26 @@
-import logging
-import os
+#!/usr/bin/env python
+
import sys
+import logging
import importlib
-from argparse import ArgumentParser
import csv
-import numpy
-
-import theano
-from theano import printing
-from theano import tensor
-from theano.ifelse import ifelse
-
-from blocks.filter import VariableFilter
-
from blocks.model import Model
-from fuel.datasets.hdf5 import H5PYDataset
from fuel.transformers import Batch
from fuel.streams import DataStream
-from fuel.schemes import ConstantScheme, SequentialExampleScheme, ShuffledExampleScheme
+from fuel.schemes import ConstantScheme, ShuffledExampleScheme
-from blocks.algorithms import GradientDescent, Scale, AdaDelta, Momentum
+from blocks.algorithms import GradientDescent, AdaDelta, Momentum
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.extensions import Printing, FinishAfter
from blocks.extensions.saveload import Dump, LoadFromDump, Checkpoint
from blocks.extensions.monitoring import DataStreamMonitoring
-import data
-import transformers
+from data import transformers
+from data.hdf5 import TaxiDataset, TaxiStream
import apply_model
if __name__ == "__main__":
@@ -38,18 +28,18 @@ if __name__ == "__main__":
print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
sys.exit(1)
model_name = sys.argv[1]
- config = importlib.import_module(model_name)
+ config = importlib.import_module('.%s' % model_name, 'config')
+def compile_valid_trip_ids():
+ valid = TaxiDataset(config.valid_set, 'valid.hdf5', sources=('trip_id',))
+ ids = valid.get_data(None, slice(0, valid.num_examples))
+ return set(ids[0])
-def setup_train_stream(req_vars):
- # Load the training and test data
- train = H5PYDataset(data.H5DATA_PATH,
- which_set='train',
- subset=slice(0, data.dataset_size),
- load_in_memory=True)
- train = DataStream(train, iteration_scheme=ShuffledExampleScheme(data.dataset_size))
+def setup_train_stream(req_vars, valid_trips_ids):
+ train = TaxiDataset('train')
+ train = DataStream(train, iteration_scheme=ShuffledExampleScheme(train.num_examples))
- train = transformers.TaxiExcludeTrips(data.valid_trips, train)
+ train = transformers.TaxiExcludeTrips(valid_trips_ids, train)
train = transformers.TaxiGenerateSplits(train, max_splits=100)
train = transformers.TaxiAddDateTime(train)
@@ -62,7 +52,7 @@ def setup_train_stream(req_vars):
return train_stream
def setup_valid_stream(req_vars):
- valid = DataStream(data.valid_data)
+ valid = TaxiStream(config.valid_set, 'valid.hdf5')
valid = transformers.TaxiAddDateTime(valid)
valid = transformers.TaxiAddFirstK(config.n_begin_end_pts, valid)
@@ -74,7 +64,7 @@ def setup_valid_stream(req_vars):
return valid_stream
def setup_test_stream(req_vars):
- test = DataStream(data.test_data)
+ test = TaxiStream('test')
test = transformers.TaxiAddDateTime(test)
test = transformers.TaxiAddFirstK(config.n_begin_end_pts, test)
@@ -95,12 +85,13 @@ def main():
req_vars = model.require_inputs + model.pred_vars
req_vars_test = model.require_inputs + [ 'trip_id' ]
- train_stream = setup_train_stream(req_vars)
+ valid_trips_ids = compile_valid_trip_ids()
+ train_stream = setup_train_stream(req_vars, valid_trips_ids)
valid_stream = setup_valid_stream(req_vars)
# Training
cg = ComputationGraph(cost)
- params = cg.parameters # VariableFilter(bricks=[Linear])(cg.parameters)
+ params = cg.parameters
algorithm = GradientDescent(
cost=cost,
# step_rule=AdaDelta(decay_rate=0.5),