diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-05 10:56:21 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-05 10:56:21 -0400 |
commit | 5b496677ea1db59a6718e5c9b2958177c76cb25f (patch) | |
tree | 5eb66c6c12450edda762de94e46f5aeac805ac93 /train.py | |
parent | 95b565afb7e1c2a6eb23ca9f7c13cd6efaf55a39 (diff) | |
download | taxi-5b496677ea1db59a6718e5c9b2958177c76cb25f.tar.gz taxi-5b496677ea1db59a6718e5c9b2958177c76cb25f.zip |
Refactor architecture so that embedding sizes can be easily changed.
Diffstat (limited to 'train.py')
-rw-r--r-- | train.py | 38 |
1 file changed, 20 insertions, 18 deletions
@@ -42,7 +42,7 @@ if __name__ == "__main__": config = importlib.import_module(model_name) -def setup_train_stream(): +def setup_train_stream(req_vars): # Load the training and test data train = H5PYDataset(data.H5DATA_PATH, which_set='train', @@ -51,34 +51,33 @@ def setup_train_stream(): train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid)) train = transformers.filter_out_trips(data.valid_trips, train) train = transformers.TaxiGenerateSplits(train, max_splits=100) + train = transformers.add_first_k(config.n_begin_end_pts, train) train = transformers.add_last_k(config.n_begin_end_pts, train) - train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k_latitude', - 'last_k_latitude', 'first_k_longitude', 'last_k_longitude', - 'destination_latitude', 'destination_longitude')) + train = transformers.Select(train, tuple(req_vars)) + train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size)) return train_stream -def setup_valid_stream(): +def setup_valid_stream(req_vars): valid = DataStream(data.valid_data) + valid = transformers.add_first_k(config.n_begin_end_pts, valid) valid = transformers.add_last_k(config.n_begin_end_pts, valid) - valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k_latitude', - 'last_k_latitude', 'first_k_longitude', 'last_k_longitude', - 'destination_latitude', 'destination_longitude')) + valid = transformers.Select(valid, tuple(req_vars)) + valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000)) return valid_stream -def setup_test_stream(): - test = data.test_data +def setup_test_stream(req_vars): + test = DataStream(data.test_data) - test = DataStream(test) test = transformers.add_first_k(config.n_begin_end_pts, test) test = transformers.add_last_k(config.n_begin_end_pts, test) - test = transformers.Select(test, ('trip_id', 'origin_stand', 'origin_call', 'first_k_latitude', - 'last_k_latitude', 'first_k_longitude', 
'last_k_longitude')) + test = transformers.Select(test, tuple(req_vars)) + test_stream = Batch(test, iteration_scheme=ConstantScheme(1000)) return test_stream @@ -91,8 +90,11 @@ def main(): hcost = model.hcost outputs = model.outputs - train_stream = setup_train_stream() - valid_stream = setup_valid_stream() + req_vars = model.require_inputs + [ 'destination_latitude', 'destination_longitude' ] + req_vars_test = model.require_inputs + [ 'trip_id' ] + + train_stream = setup_train_stream(req_vars) + valid_stream = setup_valid_stream(req_vars) # Training cg = ComputationGraph(cost) @@ -110,7 +112,7 @@ def main(): # Checkpoint('model.pkl', every_n_batches=100), Dump('model_data/' + model_name, every_n_batches=1000), LoadFromDump('model_data/' + model_name), - FinishAfter(after_epoch=10), + FinishAfter(after_epoch=42), ] main_loop = MainLoop( @@ -122,9 +124,9 @@ def main(): main_loop.profile.report() # Produce an output on the test data - test_stream = setup_test_stream() + test_stream = setup_test_stream(req_vars_test) - outfile = open("test-output-%s.csv" % model_name, "w") + outfile = open("output/test-output-%s.csv" % model_name, "w") outcsv = csv.writer(outfile) outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): |