aboutsummaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
authorAdeB <adbrebs@gmail.com>2015-05-05 22:15:29 -0400
committerAdeB <adbrebs@gmail.com>2015-05-05 22:15:29 -0400
commitc29a0d3f22134a8d1f5d557b325f6779c5961546 (patch)
tree6fa431d9b3595b5d2d11089920aa07ea43172d90 /data
parentf4d3ee6449217535bdbe19ac9c5fdd825d71b0d3 (diff)
parent1f2ff96e6480a62089fcac35154a956c218ed678 (diff)
downloadtaxi-c29a0d3f22134a8d1f5d557b325f6779c5961546.tar.gz
taxi-c29a0d3f22134a8d1f5d557b325f6779c5961546.zip
Merge branch 'master' of github.com:adbrebs/taxi
Diffstat (limited to 'data')
-rw-r--r--data/__init__.py31
-rw-r--r--data/csv.py107
-rwxr-xr-xdata/csv_to_hdf5.py127
-rw-r--r--data/cuts/__init__.py0
-rw-r--r--data/cuts/test_times_0.py8
-rw-r--r--data/hdf5.py61
-rwxr-xr-xdata/init_valid.py61
-rwxr-xr-xdata/make_valid_cut.py72
-rw-r--r--data/transformers.py127
9 files changed, 594 insertions, 0 deletions
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..1278e0b
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,31 @@
+import os
+
+import h5py
+import numpy
+import theano
+
+
+path = os.environ.get('TAXI_PATH', '/data/lisatmp3/auvolat/taxikaggle')
+Polyline = h5py.special_dtype(vlen=theano.config.floatX)
+
+
+# `wc -l test.csv` - 1 # Minus 1 to ignore the header
+test_size = 320
+
+# `wc -l train.csv` - 1
+train_size = 1710670
+
+# `wc -l metaData_taxistandsID_name_GPSlocation.csv`
+stands_size = 64 # include 0 ("no origin_stands")
+
+# `cut -d, -f 5 train.csv test.csv | sort -u | wc -l` - 1
+taxi_id_size = 448
+
+# `cut -d, -f 3 train.csv test.csv | sort -u | wc -l` - 2
+origin_call_size = 57125 # include 0 ("no origin_call")
+
+# As printed by csv_to_hdf5.py
+origin_call_train_size = 57106
+
+train_gps_mean = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX)
+train_gps_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX))
diff --git a/data/csv.py b/data/csv.py
new file mode 100644
index 0000000..b6fe5b1
--- /dev/null
+++ b/data/csv.py
@@ -0,0 +1,107 @@
+import ast
+import csv
+import numpy
+
+from fuel.datasets import Dataset
+from fuel.streams import DataStream
+from fuel.iterator import DataIterator
+
+import data
+from data.hdf5 import origin_call_normalize, taxi_id_normalize
+
+
+class TaxiData(Dataset):
+ example_iteration_scheme=None
+
+ class State:
+ __slots__ = ('file', 'index', 'reader')
+
+ def __init__(self, pathes, columns, has_header=False):
+ if not isinstance(pathes, list):
+ pathes=[pathes]
+ assert len(pathes)>0
+ self.columns=columns
+ self.provides_sources = tuple(map(lambda x: x[0], columns))
+ self.pathes=pathes
+ self.has_header=has_header
+ super(TaxiData, self).__init__()
+
+ def open(self):
+ state=self.State()
+ state.file=open(self.pathes[0])
+ state.index=0
+ state.reader=csv.reader(state.file)
+ if self.has_header:
+ state.reader.next()
+ return state
+
+ def close(self, state):
+ state.file.close()
+
+ def reset(self, state):
+ if state.index==0:
+ state.file.seek(0)
+ else:
+ state.index=0
+ state.file.close()
+ state.file=open(self.pathes[0])
+ state.reader=csv.reader(state.file)
+ return state
+
+ def get_data(self, state, request=None):
+ if request is not None:
+ raise ValueError
+ try:
+ line=state.reader.next()
+ except (ValueError, StopIteration):
+ # print state.index
+ state.file.close()
+ state.index+=1
+ if state.index>=len(self.pathes):
+ raise StopIteration
+ state.file=open(self.pathes[state.index])
+ state.reader=csv.reader(state.file)
+ if self.has_header:
+ state.reader.next()
+ return self.get_data(state)
+
+ values = []
+ for _, constructor in self.columns:
+ values.append(constructor(line))
+ return tuple(values)
+
+taxi_columns = [
+ ("trip_id", lambda l: l[0]),
+ ("call_type", lambda l: ord(l[1])-ord('A')),
+ ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else origin_call_normalize(int(l[2]))),
+ ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])),
+ ("taxi_id", lambda l: taxi_id_normalize(int(l[4]))),
+ ("timestamp", lambda l: int(l[5])),
+ ("day_type", lambda l: ord(l[6])-ord('A')),
+ ("missing_data", lambda l: l[7][0] == 'T'),
+ ("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))),
+ ("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))),
+ ("latitude", lambda l: map(lambda p: p[1], ast.literal_eval(l[8]))),
+]
+
+taxi_columns_valid = taxi_columns + [
+ ("destination_longitude", lambda l: numpy.float32(float(l[9]))),
+ ("destination_latitude", lambda l: numpy.float32(float(l[10]))),
+ ("time", lambda l: int(l[11])),
+]
+
+train_file="%s/train.csv" % data.path
+valid_file="%s/valid2-cut.csv" % data.path
+test_file="%s/test.csv" % data.path
+
+train_data=TaxiData(train_file, taxi_columns, has_header=True)
+valid_data = TaxiData(valid_file, taxi_columns_valid)
+test_data = TaxiData(test_file, taxi_columns, has_header=True)
+
+valid_trips = [l for l in open("%s/valid2-cut-ids.txt" % data.path)]
+
+def train_it():
+ return DataIterator(DataStream(train_data))
+
+def test_it():
+ return DataIterator(DataStream(valid_data))
diff --git a/data/csv_to_hdf5.py b/data/csv_to_hdf5.py
new file mode 100755
index 0000000..17217f3
--- /dev/null
+++ b/data/csv_to_hdf5.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python
+
+import ast
+import csv
+import os
+import sys
+
+import h5py
+import numpy
+import theano
+from fuel.converters.base import fill_hdf5_file
+
+import data
+
+
+taxi_id_dict = {}
+origin_call_dict = {0: 0}
+
+def get_unique_taxi_id(val):
+ if val in taxi_id_dict:
+ return taxi_id_dict[val]
+ else:
+ taxi_id_dict[val] = len(taxi_id_dict)
+ return len(taxi_id_dict) - 1
+
+def get_unique_origin_call(val):
+ if val in origin_call_dict:
+ return origin_call_dict[val]
+ else:
+ origin_call_dict[val] = len(origin_call_dict)
+ return len(origin_call_dict) - 1
+
+def read_stands(input_directory, h5file):
+ stands_name = numpy.empty(shape=(data.stands_size,), dtype=('a', 24))
+ stands_latitude = numpy.empty(shape=(data.stands_size,), dtype=theano.config.floatX)
+ stands_longitude = numpy.empty(shape=(data.stands_size,), dtype=theano.config.floatX)
+ stands_name[0] = 'None'
+ stands_latitude[0] = stands_longitude[0] = 0
+ with open(os.path.join(input_directory, 'metaData_taxistandsID_name_GPSlocation.csv'), 'r') as f:
+ reader = csv.reader(f)
+ reader.next() # header
+ for line in reader:
+ id = int(line[0])
+ stands_name[id] = line[1]
+ stands_latitude[id] = float(line[2])
+ stands_longitude[id] = float(line[3])
+ return (('stands', 'stands_name', stands_name),
+ ('stands', 'stands_latitude', stands_latitude),
+ ('stands', 'stands_longitude', stands_longitude))
+
+def read_taxis(input_directory, h5file, dataset):
+ print >> sys.stderr, 'read %s: begin' % dataset
+ size=getattr(data, '%s_size'%dataset)
+ trip_id = numpy.empty(shape=(size,), dtype='S19')
+ call_type = numpy.empty(shape=(size,), dtype=numpy.uint8)
+ origin_call = numpy.empty(shape=(size,), dtype=numpy.uint32)
+ origin_stand = numpy.empty(shape=(size,), dtype=numpy.uint8)
+ taxi_id = numpy.empty(shape=(size,), dtype=numpy.uint16)
+ timestamp = numpy.empty(shape=(size,), dtype=numpy.uint32)
+ day_type = numpy.empty(shape=(size,), dtype=numpy.uint8)
+ missing_data = numpy.empty(shape=(size,), dtype=numpy.bool)
+ latitude = numpy.empty(shape=(size,), dtype=data.Polyline)
+ longitude = numpy.empty(shape=(size,), dtype=data.Polyline)
+ with open(os.path.join(input_directory, '%s.csv'%dataset), 'r') as f:
+ reader = csv.reader(f)
+ reader.next() # header
+ id=0
+ for line in reader:
+ if id%10000==0 and id!=0:
+ print >> sys.stderr, 'read %s: %d done' % (dataset, id)
+ trip_id[id] = line[0]
+ call_type[id] = ord(line[1][0]) - ord('A')
+ origin_call[id] = 0 if line[2]=='NA' or line[2]=='' else get_unique_origin_call(int(line[2]))
+ origin_stand[id] = 0 if line[3]=='NA' or line[3]=='' else int(line[3])
+ taxi_id[id] = get_unique_taxi_id(int(line[4]))
+ timestamp[id] = int(line[5])
+ day_type[id] = ord(line[6][0]) - ord('A')
+ missing_data[id] = line[7][0] == 'T'
+ polyline = ast.literal_eval(line[8])
+ latitude[id] = numpy.array([point[1] for point in polyline], dtype=theano.config.floatX)
+ longitude[id] = numpy.array([point[0] for point in polyline], dtype=theano.config.floatX)
+ id+=1
+ splits = ()
+ print >> sys.stderr, 'read %s: writing' % dataset
+ for name in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']:
+ splits += ((dataset, name, locals()[name]),)
+ print >> sys.stderr, 'read %s: end' % dataset
+ return splits
+
+def unique(h5file):
+ unique_taxi_id = numpy.empty(shape=(data.taxi_id_size,), dtype=numpy.uint32)
+ assert len(taxi_id_dict) == data.taxi_id_size
+ for k, v in taxi_id_dict.items():
+ unique_taxi_id[v] = k
+
+ unique_origin_call = numpy.empty(shape=(data.origin_call_size,), dtype=numpy.uint32)
+ assert len(origin_call_dict) == data.origin_call_size
+ for k, v in origin_call_dict.items():
+ unique_origin_call[v] = k
+
+ return (('unique_taxi_id', 'unique_taxi_id', unique_taxi_id),
+ ('unique_origin_call', 'unique_origin_call', unique_origin_call))
+
+def convert(input_directory, save_path):
+ h5file = h5py.File(save_path, 'w')
+ split = ()
+ split += read_stands(input_directory, h5file)
+ split += read_taxis(input_directory, h5file, 'train')
+ print 'First origin_call not present in training set: ', len(origin_call_dict)
+ split += read_taxis(input_directory, h5file, 'test')
+ split += unique(h5file)
+
+ fill_hdf5_file(h5file, split)
+
+ for name in ['stands_name', 'stands_latitude', 'stands_longitude', 'unique_taxi_id', 'unique_origin_call']:
+ h5file[name].dims[0].label = 'index'
+ for name in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']:
+ h5file[name].dims[0].label = 'batch'
+
+ h5file.flush()
+ h5file.close()
+
+if __name__ == '__main__':
+ if len(sys.argv) != 3:
+ print >> sys.stderr, 'Usage: %s download_dir output_file' % sys.argv[0]
+ sys.exit(1)
+ convert(sys.argv[1], sys.argv[2])
diff --git a/data/cuts/__init__.py b/data/cuts/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/data/cuts/__init__.py
diff --git a/data/cuts/test_times_0.py b/data/cuts/test_times_0.py
new file mode 100644
index 0000000..b590072
--- /dev/null
+++ b/data/cuts/test_times_0.py
@@ -0,0 +1,8 @@
+# Cuts of the test set minus 1 year
+cuts = [
+ 1376503200, # 2013-08-14 18:00
+ 1380616200, # 2013-10-01 08:30
+ 1381167900, # 2013-10-07 17:45
+ 1383364800, # 2013-11-02 04:00
+ 1387722600 # 2013-12-22 14:30
+]
diff --git a/data/hdf5.py b/data/hdf5.py
new file mode 100644
index 0000000..d848023
--- /dev/null
+++ b/data/hdf5.py
@@ -0,0 +1,61 @@
+import os
+
+import h5py
+from fuel.datasets import H5PYDataset
+from fuel.iterator import DataIterator
+from fuel.schemes import SequentialExampleScheme
+from fuel.streams import DataStream
+
+import data
+
+
+class TaxiDataset(H5PYDataset):
+ def __init__(self, which_set, filename='data.hdf5', **kwargs):
+ self.filename = filename
+ kwargs.setdefault('load_in_memory', True)
+ super(TaxiDataset, self).__init__(self.data_path, which_set, **kwargs)
+
+ @property
+ def data_path(self):
+ return os.path.join(data.path, self.filename)
+
+class TaxiStream(DataStream):
+ def __init__(self, which_set, filename='data.hdf5', iteration_scheme=None, **kwargs):
+ dataset = TaxiDataset(which_set, filename, **kwargs)
+ if iteration_scheme is None:
+ iteration_scheme = SequentialExampleScheme(dataset.num_examples)
+ super(TaxiStream, self).__init__(dataset, iteration_scheme=iteration_scheme)
+
+_origin_calls = None
+_reverse_origin_calls = None
+
+def origin_call_unnormalize(x):
+ if _origin_calls is None:
+ _origin_calls = h5py.File(os.path.join(data.path, 'data.hdf5'), 'r')['unique_origin_call']
+ return _origin_calls[x]
+
+def origin_call_normalize(x):
+ if _reverse_origin_calls is None:
+ origin_call_unnormalize(0)
+ _reverse_origin_calls = { _origin_calls[i]: i for i in range(_origin_calls.shape[0]) }
+ return _reverse_origin_calls[x]
+
+_taxi_ids = None
+_reverse_taxi_ids = None
+
+def taxi_id_unnormalize(x):
+ if _taxi_ids is None:
+ _taxi_ids = h5py.File(os.path.join(data.path, 'data.hdf5'), 'r')['unique_taxi_id']
+ return _taxi_ids[x]
+
+def taxi_id_normalize(x):
+ if _reverse_taxi_ids is None:
+ taxi_id_unnormalize(0)
+ _reverse_taxi_ids = { _taxi_ids[i]: i for i in range(_taxi_ids.shape[0]) }
+ return _reverse_taxi_ids[x]
+
+def taxi_it(which_set, filename='data.hdf5', sub=None, as_dict=True):
+ dataset = TaxiDataset(which_set, filename)
+ if sub is None:
+ sub = xrange(dataset.num_examples)
+ return DataIterator(DataStream(dataset), iter(sub), as_dict)
diff --git a/data/init_valid.py b/data/init_valid.py
new file mode 100755
index 0000000..14a854c
--- /dev/null
+++ b/data/init_valid.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+# Initialize the valid hdf5
+
+import os
+import sys
+
+import h5py
+import numpy
+import theano
+
+import data
+
+
+_fields = {
+ 'trip_id': 'S19',
+ 'call_type': numpy.uint8,
+ 'origin_call': numpy.uint32,
+ 'origin_stand': numpy.uint8,
+ 'taxi_id': numpy.uint16,
+ 'timestamp': numpy.uint32,
+ 'day_type': numpy.uint8,
+ 'missing_data': numpy.bool,
+ 'latitude': data.Polyline,
+ 'longitude': data.Polyline,
+ 'destination_latitude': theano.config.floatX,
+ 'destination_longitude': theano.config.floatX,
+ 'travel_time': numpy.uint32,
+}
+
+
+def init_valid(path):
+ h5file = h5py.File(path, 'w')
+
+ for k, v in _fields.items():
+ h5file.create_dataset(k, (0,), dtype=v, maxshape=(None,))
+
+ split_array = numpy.empty(len(_fields), dtype=numpy.dtype([
+ ('split', 'a', 64),
+ ('source', 'a', 21),
+ ('start', numpy.int64, 1),
+ ('stop', numpy.int64, 1),
+ ('available', numpy.bool, 1),
+ ('comment', 'a', 1)]))
+
+ split_array[:]['split'] = 'dummy'.encode('utf8')
+ for (i, k) in enumerate(_fields.keys()):
+ split_array[i] = k.encode('utf8')
+ split_array[:]['start'] = 0
+ split_array[:]['stop'] = 0
+ split_array[:]['available'] = False
+ split_array[:]['comment'] = '.'.encode('utf8')
+ h5file.attrs['split'] = split_array
+
+ h5file.flush()
+ h5file.close()
+
+if __name__ == '__main__':
+ if len(sys.argv) > 2:
+ print >> sys.stderr, 'Usage: %s [file]' % sys.argv[0]
+ sys.exit(1)
+ init_valid(sys.argv[1] if len(sys.argv) == 2 else os.path.join(data.path, 'valid.hdf5'))
diff --git a/data/make_valid_cut.py b/data/make_valid_cut.py
new file mode 100755
index 0000000..d5be083
--- /dev/null
+++ b/data/make_valid_cut.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python
+# Make a valid dataset by cutting the training set at specified timestamps
+
+import os
+import sys
+import importlib
+
+import h5py
+import numpy
+
+import data
+from data.hdf5 import taxi_it
+
+
+_fields = ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude', 'destination_latitude', 'destination_longitude', 'travel_time']
+
+def make_valid(cutfile, outpath):
+ cuts = importlib.import_module('.%s' % cutfile, 'data.cuts').cuts
+
+ valid = []
+
+ for line in taxi_it('train'):
+ time = line['timestamp']
+ latitude = line['latitude']
+ longitude = line['longitude']
+
+ if len(latitude) == 0:
+ continue
+
+ for ts in cuts:
+ if time <= ts and time + 15 * (len(latitude) - 1) >= ts:
+ # keep it
+ n = (ts - time) / 15 + 1
+ line.update({
+ 'latitude': latitude[:n],
+ 'longitude': longitude[:n],
+ 'destination_latitude': latitude[-1],
+ 'destination_longitude': longitude[-1],
+ 'travel_time': 15 * (len(latitude)-1)
+ })
+ valid.append(line)
+
+ file = h5py.File(outpath, 'a')
+ clen = file['trip_id'].shape[0]
+ alen = len(valid)
+ for field in _fields:
+ dset = file[field]
+ dset.resize((clen + alen,))
+ for i in xrange(alen):
+ dset[clen + i] = valid[i][field]
+
+ splits = file.attrs['split']
+ slen = splits.shape[0]
+ splits = numpy.resize(splits, (slen+len(_fields),))
+ for (i, field) in enumerate(_fields):
+ splits[slen+i]['split'] = ('cuts/%s' % cutfile).encode('utf8')
+ splits[slen+i]['source'] = field.encode('utf8')
+ splits[slen+i]['start'] = clen
+ splits[slen+i]['stop'] = alen
+ splits[slen+i]['available'] = True
+ splits[slen+i]['comment'] = '.'
+ file.attrs['split'] = splits
+
+ file.flush()
+ file.close()
+
+if __name__ == '__main__':
+ if len(sys.argv) < 2 or len(sys.argv) > 3:
+ print >> sys.stderr, 'Usage: %s cutfile [outfile]' % sys.argv[0]
+ sys.exit(1)
+ outpath = os.path.join(data.path, 'valid.hdf5') if len(sys.argv) < 3 else sys.argv[2]
+ make_valid(sys.argv[1], outpath)
diff --git a/data/transformers.py b/data/transformers.py
new file mode 100644
index 0000000..1cc4834
--- /dev/null
+++ b/data/transformers.py
@@ -0,0 +1,127 @@
+import datetime
+import random
+
+import numpy
+import theano
+from fuel.transformers import Transformer
+
+import data
+
+
+def at_least_k(k, v, pad_at_begin, is_longitude):
+ if len(v) == 0:
+ v = numpy.array([data.porto_center[1 if is_longitude else 0]], dtype=theano.config.floatX)
+ if len(v) < k:
+ if pad_at_begin:
+ v = numpy.concatenate((numpy.full((k - len(v),), v[0]), v))
+ else:
+ v = numpy.concatenate((v, numpy.full((k - len(v),), v[-1])))
+ return v
+
+
+class Select(Transformer):
+ def __init__(self, data_stream, sources):
+ super(Select, self).__init__(data_stream)
+ self.ids = [data_stream.sources.index(source) for source in sources]
+ self.sources=sources
+
+ def get_data(self, request=None):
+ if request is not None:
+ raise ValueError
+ data=next(self.child_epoch_iterator)
+ return [data[id] for id in self.ids]
+
+class TaxiGenerateSplits(Transformer):
+ def __init__(self, data_stream, max_splits=-1):
+ super(TaxiGenerateSplits, self).__init__(data_stream)
+ self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude', 'time')
+ self.max_splits = max_splits
+ self.data = None
+ self.splits = []
+ self.isplit = 0
+ self.id_latitude = data_stream.sources.index('latitude')
+ self.id_longitude = data_stream.sources.index('longitude')
+
+ def get_data(self, request=None):
+ if request is not None:
+ raise ValueError
+ while self.isplit >= len(self.splits):
+ self.data = next(self.child_epoch_iterator)
+ self.splits = range(len(self.data[self.id_longitude]))
+ random.shuffle(self.splits)
+ if self.max_splits != -1 and len(self.splits) > self.max_splits:
+ self.splits = self.splits[:self.max_splits]
+ self.isplit = 0
+
+ i = self.isplit
+ self.isplit += 1
+ n = self.splits[i]+1
+
+ r = list(self.data)
+
+ r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX)
+ r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX)
+
+ dlat = numpy.float32(self.data[self.id_latitude][-1])
+ dlon = numpy.float32(self.data[self.id_longitude][-1])
+
+ return tuple(r + [dlat, dlon, 15 * (len(self.data[self.id_longitude]) - 1)])
+
+class TaxiAddFirstK(Transformer):
+ def __init__(self, k, stream):
+ super(TaxiAddFirstK, self).__init__(stream)
+ self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude')
+ self.id_latitude = stream.sources.index('latitude')
+ self.id_longitude = stream.sources.index('longitude')
+ self.k = k
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ data = next(self.child_epoch_iterator)
+ first_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], False, False)[:self.k],
+ dtype=theano.config.floatX),
+ numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k],
+ dtype=theano.config.floatX))
+ return data + first_k
+
+class TaxiAddLastK(Transformer):
+ def __init__(self, k, stream):
+ super(TaxiAddLastK, self).__init__(stream)
+ self.sources = stream.sources + ('last_k_latitude', 'last_k_longitude')
+ self.id_latitude = stream.sources.index('latitude')
+ self.id_longitude = stream.sources.index('longitude')
+ self.k = k
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ data = next(self.child_epoch_iterator)
+ last_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], True, False)[-self.k:],
+ dtype=theano.config.floatX),
+ numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:],
+ dtype=theano.config.floatX))
+ return data + last_k
+
+class TaxiAddDateTime(Transformer):
+ def __init__(self, stream):
+ super(TaxiAddDateTime, self).__init__(stream)
+ self.sources = stream.sources + ('week_of_year', 'day_of_week', 'qhour_of_day')
+ self.id_timestamp = stream.sources.index('timestamp')
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ data = next(self.child_epoch_iterator)
+ ts = data[self.id_timestamp]
+ date = datetime.datetime.utcfromtimestamp(ts)
+ yearweek = date.isocalendar()[1] - 1
+ info = ((51 if yearweek == 52 else yearweek), date.weekday(), date.hour * 4 + date.minute / 15)
+ return data + info
+
+class TaxiExcludeTrips(Transformer):
+ def __init__(self, exclude_list, stream):
+ super(TaxiExcludeTrips, self).__init__(stream)
+ self.id_trip_id = stream.sources.index('trip_id')
+ self.exclude = {v: True for v in exclude_list}
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ while True:
+ data = next(self.child_epoch_iterator)
+ if not data[self.id_trip_id] in self.exclude: break
+ return data
+