aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-05 14:15:21 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-05 14:15:21 -0400
commit54613c1f9cf510ca7a71d6619418f2247515aec6 (patch)
treebed9a5a11ef5b7feecee44095a29400e32f76b05 /train.py
parent712035b88be1816d3fbd58ce69ae6464767c780e (diff)
downloadtaxi-54613c1f9cf510ca7a71d6619418f2247515aec6.tar.gz
taxi-54613c1f9cf510ca7a71d6619418f2247515aec6.zip
Add models for time predictioAdd models for time prediction
Diffstat (limited to 'train.py')
-rw-r--r--train.py27
1 files changed, 16 insertions, 11 deletions
diff --git a/train.py b/train.py
index 2c9522e..4cbd526 100644
--- a/train.py
+++ b/train.py
@@ -20,7 +20,7 @@ from blocks.model import Model
from fuel.datasets.hdf5 import H5PYDataset
from fuel.transformers import Batch
from fuel.streams import DataStream
-from fuel.schemes import ConstantScheme, SequentialExampleScheme
+from fuel.schemes import ConstantScheme, SequentialExampleScheme, ShuffledExampleScheme
from blocks.algorithms import GradientDescent, Scale, AdaDelta, Momentum
from blocks.graph import ComputationGraph
@@ -31,7 +31,6 @@ from blocks.extensions.monitoring import DataStreamMonitoring
import data
import transformers
-import hdist
import apply_model
if __name__ == "__main__":
@@ -48,7 +47,7 @@ def setup_train_stream(req_vars):
which_set='train',
subset=slice(0, data.dataset_size),
load_in_memory=True)
- train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid))
+ train = DataStream(train, iteration_scheme=ShuffledExampleScheme(data.dataset_size))
train = transformers.TaxiExcludeTrips(data.valid_trips, train)
train = transformers.TaxiGenerateSplits(train, max_splits=100)
@@ -91,10 +90,9 @@ def main():
model = config.model.Model(config)
cost = model.cost
- hcost = model.hcost
outputs = model.outputs
- req_vars = model.require_inputs + [ 'destination_latitude', 'destination_longitude' ]
+ req_vars = model.require_inputs + model.pred_vars
req_vars_test = model.require_inputs + [ 'trip_id' ]
train_stream = setup_train_stream(req_vars)
@@ -109,7 +107,7 @@ def main():
step_rule=Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
params=params)
- extensions=[DataStreamMonitoring([cost, hcost], valid_stream,
+ extensions=[DataStreamMonitoring(model.monitor, valid_stream,
prefix='valid',
every_n_batches=1000),
Printing(every_n_batches=1000),
@@ -132,11 +130,18 @@ def main():
outfile = open("output/test-output-%s.csv" % model_name, "w")
outcsv = csv.writer(outfile)
- outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
- for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']):
- dest = out['outputs']
- for i, trip in enumerate(out['trip_id']):
- outcsv.writerow([trip, repr(dest[i, 0]), repr(dest[i, 1])])
+ if model.pred_vars == ['time']:
+ outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"])
+ for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']):
+ time = out['outputs']
+ for i, trip in enumerate(out['trip_id']):
+ outcsv.writerow([trip, int(time[i, 0])])
+ else:
+ outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
+ for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']):
+ dest = out['outputs']
+ for i, trip in enumerate(out['trip_id']):
+ outcsv.writerow([trip, repr(dest[i, 0]), repr(dest[i, 1])])
outfile.close()