diff options
Diffstat (limited to 'train.py')
-rwxr-xr-x[-rw-r--r--] | train.py | 51 |
1 files changed, 21 insertions, 30 deletions
@@ -1,36 +1,26 @@ -import logging -import os +#!/usr/bin/env python + import sys +import logging import importlib -from argparse import ArgumentParser import csv -import numpy - -import theano -from theano import printing -from theano import tensor -from theano.ifelse import ifelse - -from blocks.filter import VariableFilter - from blocks.model import Model -from fuel.datasets.hdf5 import H5PYDataset from fuel.transformers import Batch from fuel.streams import DataStream -from fuel.schemes import ConstantScheme, SequentialExampleScheme, ShuffledExampleScheme +from fuel.schemes import ConstantScheme, ShuffledExampleScheme -from blocks.algorithms import GradientDescent, Scale, AdaDelta, Momentum +from blocks.algorithms import GradientDescent, AdaDelta, Momentum from blocks.graph import ComputationGraph from blocks.main_loop import MainLoop from blocks.extensions import Printing, FinishAfter from blocks.extensions.saveload import Dump, LoadFromDump, Checkpoint from blocks.extensions.monitoring import DataStreamMonitoring -import data -import transformers +from data import transformers +from data.hdf5 import TaxiDataset, TaxiStream import apply_model if __name__ == "__main__": @@ -38,18 +28,18 @@ if __name__ == "__main__": print >> sys.stderr, 'Usage: %s config' % sys.argv[0] sys.exit(1) model_name = sys.argv[1] - config = importlib.import_module(model_name) + config = importlib.import_module('.%s' % model_name, 'config') +def compile_valid_trip_ids(): + valid = TaxiDataset(config.valid_set, 'valid.hdf5', sources=('trip_id',)) + ids = valid.get_data(None, slice(0, valid.num_examples)) + return set(ids[0]) -def setup_train_stream(req_vars): - # Load the training and test data - train = H5PYDataset(data.H5DATA_PATH, - which_set='train', - subset=slice(0, data.dataset_size), - load_in_memory=True) - train = DataStream(train, iteration_scheme=ShuffledExampleScheme(data.dataset_size)) +def setup_train_stream(req_vars, valid_trips_ids): + train = TaxiDataset('train') + train = DataStream(train, iteration_scheme=ShuffledExampleScheme(train.num_examples)) - train = transformers.TaxiExcludeTrips(data.valid_trips, train) + train = transformers.TaxiExcludeTrips(valid_trips_ids, train) train = transformers.TaxiGenerateSplits(train, max_splits=100) train = transformers.TaxiAddDateTime(train) @@ -62,7 +52,7 @@ def setup_train_stream(req_vars): return train_stream def setup_valid_stream(req_vars): - valid = DataStream(data.valid_data) + valid = TaxiStream(config.valid_set, 'valid.hdf5') valid = transformers.TaxiAddDateTime(valid) valid = transformers.TaxiAddFirstK(config.n_begin_end_pts, valid) @@ -74,7 +64,7 @@ def setup_valid_stream(req_vars): return valid_stream def setup_test_stream(req_vars): - test = DataStream(data.test_data) + test = TaxiStream('test') test = transformers.TaxiAddDateTime(test) test = transformers.TaxiAddFirstK(config.n_begin_end_pts, test) @@ -95,12 +85,13 @@ def main(): req_vars = model.require_inputs + model.pred_vars req_vars_test = model.require_inputs + [ 'trip_id' ] - train_stream = setup_train_stream(req_vars) + valid_trips_ids = compile_valid_trip_ids() + train_stream = setup_train_stream(req_vars, valid_trips_ids) valid_stream = setup_valid_stream(req_vars) # Training cg = ComputationGraph(cost) - params = cg.parameters # VariableFilter(bricks=[Linear])(cg.parameters) + params = cg.parameters algorithm = GradientDescent( cost=cost, # step_rule=AdaDelta(decay_rate=0.5), |