diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-05 14:15:21 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-05 14:15:21 -0400 |
commit | 54613c1f9cf510ca7a71d6619418f2247515aec6 (patch) | |
tree | bed9a5a11ef5b7feecee44095a29400e32f76b05 /train.py | |
parent | 712035b88be1816d3fbd58ce69ae6464767c780e (diff) | |
download | taxi-54613c1f9cf510ca7a71d6619418f2247515aec6.tar.gz taxi-54613c1f9cf510ca7a71d6619418f2247515aec6.zip |
Add models for time prediction
Diffstat (limited to 'train.py')
-rw-r--r-- | train.py | 27 |
1 file changed, 16 insertions, 11 deletions
@@ -20,7 +20,7 @@ from blocks.model import Model from fuel.datasets.hdf5 import H5PYDataset from fuel.transformers import Batch from fuel.streams import DataStream -from fuel.schemes import ConstantScheme, SequentialExampleScheme +from fuel.schemes import ConstantScheme, SequentialExampleScheme, ShuffledExampleScheme from blocks.algorithms import GradientDescent, Scale, AdaDelta, Momentum from blocks.graph import ComputationGraph @@ -31,7 +31,6 @@ from blocks.extensions.monitoring import DataStreamMonitoring import data import transformers -import hdist import apply_model if __name__ == "__main__": @@ -48,7 +47,7 @@ def setup_train_stream(req_vars): which_set='train', subset=slice(0, data.dataset_size), load_in_memory=True) - train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid)) + train = DataStream(train, iteration_scheme=ShuffledExampleScheme(data.dataset_size)) train = transformers.TaxiExcludeTrips(data.valid_trips, train) train = transformers.TaxiGenerateSplits(train, max_splits=100) @@ -91,10 +90,9 @@ def main(): model = config.model.Model(config) cost = model.cost - hcost = model.hcost outputs = model.outputs - req_vars = model.require_inputs + [ 'destination_latitude', 'destination_longitude' ] + req_vars = model.require_inputs + model.pred_vars req_vars_test = model.require_inputs + [ 'trip_id' ] train_stream = setup_train_stream(req_vars) @@ -109,7 +107,7 @@ def main(): step_rule=Momentum(learning_rate=config.learning_rate, momentum=config.momentum), params=params) - extensions=[DataStreamMonitoring([cost, hcost], valid_stream, + extensions=[DataStreamMonitoring(model.monitor, valid_stream, prefix='valid', every_n_batches=1000), Printing(every_n_batches=1000), @@ -132,11 +130,18 @@ def main(): outfile = open("output/test-output-%s.csv" % model_name, "w") outcsv = csv.writer(outfile) - outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) - for out in apply_model.Apply(outputs=outputs, stream=test_stream, 
return_vars=['trip_id', 'outputs']): - dest = out['outputs'] - for i, trip in enumerate(out['trip_id']): - outcsv.writerow([trip, repr(dest[i, 0]), repr(dest[i, 1])]) + if model.pred_vars == ['time']: + outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"]) + for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): + time = out['outputs'] + for i, trip in enumerate(out['trip_id']): + outcsv.writerow([trip, int(time[i, 0])]) + else: + outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) + for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): + dest = out['outputs'] + for i, trip in enumerate(out['trip_id']): + outcsv.writerow([trip, repr(dest[i, 0]), repr(dest[i, 1])]) outfile.close() |