aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--data.py10
-rw-r--r--model.py28
-rw-r--r--transformers.py51
3 files changed, 73 insertions, 16 deletions
diff --git a/data.py b/data.py
index 92aa062..730a9ab 100644
--- a/data.py
+++ b/data.py
@@ -164,21 +164,25 @@ taxi_columns = [
]
taxi_columns_valid = taxi_columns + [
- ("destination_longitude", lambda l: float(l[9])),
- ("destination_latitude", lambda l: float(l[10])),
+ ("destination_longitude", lambda l: numpy.float32(float(l[9]))),
+ ("destination_latitude", lambda l: numpy.float32(float(l[10]))),
("time", lambda l: int(l[11])),
]
train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
-valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
+valid_files=["%s/split/valid2-cut.csv" % (DATA_PATH,)]
test_file="%s/test.csv" % (DATA_PATH,)
train_data=TaxiData(train_files, taxi_columns)
valid_data = TaxiData(valid_files, taxi_columns_valid)
test_data = TaxiData(test_file, taxi_columns, has_header=True)
+valid_trips = [l for l in open(DATA_PATH + "/split/valid2-cut-ids.txt")]
+
def train_it():
return DataIterator(DataStream(train_data))
def test_it():
return DataIterator(DataStream(valid_data))
+
+
diff --git a/model.py b/model.py
index aff9fd7..753fb01 100644
--- a/model.py
+++ b/model.py
@@ -44,35 +44,35 @@ if __name__ == "__main__":
sys.exit(1)
config = importlib.import_module(sys.argv[1])
-def setup_stream():
+
+def setup_train_stream():
# Load the training and test data
train = H5PYDataset(data.H5DATA_PATH,
which_set='train',
- subset=slice(0, data.dataset_size - config.n_valid),
+ subset=slice(0, data.dataset_size),
load_in_memory=True)
train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid))
+ train = transformers.filter_out_trips(data.valid_trips, train)
+ train = transformers.TaxiGenerateSplits(train)
train = transformers.add_first_k(config.n_begin_end_pts, train)
- train = transformers.add_random_k(config.n_begin_end_pts, train)
- train = transformers.add_destination(train)
+ train = transformers.add_last_k(config.n_begin_end_pts, train)
train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k_latitude',
'last_k_latitude', 'first_k_longitude', 'last_k_longitude',
'destination_latitude', 'destination_longitude'))
train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
- valid = H5PYDataset(data.H5DATA_PATH,
- which_set='train',
- subset=slice(data.dataset_size - config.n_valid, data.dataset_size),
- load_in_memory=True)
- valid = DataStream(valid, iteration_scheme=SequentialExampleScheme(config.n_valid))
+ return train_stream
+
+def setup_valid_stream():
+ valid = DataStream(data.valid_data)
valid = transformers.add_first_k(config.n_begin_end_pts, valid)
- valid = transformers.add_random_k(config.n_begin_end_pts, valid)
- valid = transformers.add_destination(valid)
+ valid = transformers.add_last_k(config.n_begin_end_pts, valid)
valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k_latitude',
'last_k_latitude', 'first_k_longitude', 'last_k_longitude',
'destination_latitude', 'destination_longitude'))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000))
- return (train_stream, valid_stream)
+ return valid_stream
def setup_test_stream():
test = data.test_data
@@ -86,6 +86,7 @@ def setup_test_stream():
return test_stream
+
def main():
# The input and the targets
x_firstk_latitude = (tensor.matrix('first_k_latitude') - data.porto_center[0]) / data.data_std[0]
@@ -146,7 +147,8 @@ def main():
stand_embed_table.initialize()
mlp.initialize()
- (train_stream, valid_stream) = setup_stream()
+ train_stream = setup_train_stream()
+ valid_stream = setup_valid_stream()
# Training
cg = ComputationGraph(cost)
diff --git a/transformers.py b/transformers.py
index 5ad9a87..79e8327 100644
--- a/transformers.py
+++ b/transformers.py
@@ -27,6 +27,42 @@ class Select(Transformer):
data=next(self.child_epoch_iterator)
return [data[id] for id in self.ids]
+class TaxiGenerateSplits(Transformer):
+ def __init__(self, data_stream, max_splits=-1):
+ super(TaxiGenerateSplits, self).__init__(data_stream)
+ self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude')
+ self.max_splits = max_splits
+ self.data = None
+ self.splits = []
+ self.isplit = 0
+ self.id_latitude = data_stream.sources.index('latitude')
+ self.id_longitude = data_stream.sources.index('longitude')
+
+ def get_data(self, request=None):
+ if request is not None:
+ raise ValueError
+ while self.isplit >= len(self.splits):
+ self.data = next(self.child_epoch_iterator)
+ self.splits = range(len(self.data[self.id_polyline]))
+ random.shuffle_array(self.splits)
+ if self.max_splits != -1 and len(self.splits) > self.max_splits:
+ self.splits = self.splits[:self.max_splits]
+ self.isplit = 0
+
+ i = self.isplit
+ self.isplit += 1
+ n = self.splits[i]+1
+
+ r = list(self.data)
+
+ r[self.id_latitude] = r[self.id_latitude][:n]
+ r[self.id_longitude] = r[self.id_longitude][:n]
+
+ dlat = self.data[self.id_latitude][-1]
+ dlon = self.data[self.id_longitude][-1]
+
+ return tuple(r + [dlat, dlon])
+
class first_k(object):
def __init__(self, k, id_latitude, id_longitude):
@@ -87,3 +123,18 @@ def add_destination(stream):
id_latitude = stream.sources.index('latitude')
id_longitude = stream.sources.index('longitude')
return Mapping(stream, destination(id_latitude, id_longitude), ('destination_latitude', 'destination_longitude'))
+
+
+class trip_filter(object):
+ def __init__(self, id_trip_id, exclude):
+ self.id_trip_id = id_trip_id
+ self.exclude = exclude
+ def __call__(self, data):
+ if data[self.id_trip_id] in self.exclude:
+ return False
+ else:
+ return True
+def filter_out_trips(exclude_trips, stream):
+ id_trip_id = stream.sources.index('trip_id')
+ return Filter(stream, trip_filter(id_trip_id, exclude_trips))
+