aboutsummaryrefslogtreecommitdiff
path: root/transformers.py
diff options
context:
space:
mode:
Diffstat (limited to 'transformers.py')
-rw-r--r--transformers.py125
1 files changed, 0 insertions, 125 deletions
diff --git a/transformers.py b/transformers.py
deleted file mode 100644
index 73e3868..0000000
--- a/transformers.py
+++ /dev/null
@@ -1,125 +0,0 @@
-from fuel.transformers import Transformer, Filter, Mapping
-import numpy
-import theano
-import random
-import data
-
-import datetime
-
-def at_least_k(k, v, pad_at_begin, is_longitude):
- if len(v) == 0:
- v = numpy.array([data.porto_center[1 if is_longitude else 0]], dtype=theano.config.floatX)
- if len(v) < k:
- if pad_at_begin:
- v = numpy.concatenate((numpy.full((k - len(v),), v[0]), v))
- else:
- v = numpy.concatenate((v, numpy.full((k - len(v),), v[-1])))
- return v
-
-
-class Select(Transformer):
- def __init__(self, data_stream, sources):
- super(Select, self).__init__(data_stream)
- self.ids = [data_stream.sources.index(source) for source in sources]
- self.sources=sources
-
- def get_data(self, request=None):
- if request is not None:
- raise ValueError
- data=next(self.child_epoch_iterator)
- return [data[id] for id in self.ids]
-
-class TaxiGenerateSplits(Transformer):
- def __init__(self, data_stream, max_splits=-1):
- super(TaxiGenerateSplits, self).__init__(data_stream)
- self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude', 'time')
- self.max_splits = max_splits
- self.data = None
- self.splits = []
- self.isplit = 0
- self.id_latitude = data_stream.sources.index('latitude')
- self.id_longitude = data_stream.sources.index('longitude')
-
- def get_data(self, request=None):
- if request is not None:
- raise ValueError
- while self.isplit >= len(self.splits):
- self.data = next(self.child_epoch_iterator)
- self.splits = range(len(self.data[self.id_longitude]))
- random.shuffle(self.splits)
- if self.max_splits != -1 and len(self.splits) > self.max_splits:
- self.splits = self.splits[:self.max_splits]
- self.isplit = 0
-
- i = self.isplit
- self.isplit += 1
- n = self.splits[i]+1
-
- r = list(self.data)
-
- r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX)
- r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX)
-
- dlat = numpy.float32(self.data[self.id_latitude][-1])
- dlon = numpy.float32(self.data[self.id_longitude][-1])
-
- return tuple(r + [dlat, dlon, 15 * (len(self.data[self.id_longitude]) - 1)])
-
-class TaxiAddFirstK(Transformer):
- def __init__(self, k, stream):
- super(TaxiAddFirstK, self).__init__(stream)
- self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude')
- self.id_latitude = stream.sources.index('latitude')
- self.id_longitude = stream.sources.index('longitude')
- self.k = k
- def get_data(self, request=None):
- if request is not None: raise ValueError
- data = next(self.child_epoch_iterator)
- first_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], False, False)[:self.k],
- dtype=theano.config.floatX),
- numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k],
- dtype=theano.config.floatX))
- return data + first_k
-
-class TaxiAddLastK(Transformer):
- def __init__(self, k, stream):
- super(TaxiAddLastK, self).__init__(stream)
- self.sources = stream.sources + ('last_k_latitude', 'last_k_longitude')
- self.id_latitude = stream.sources.index('latitude')
- self.id_longitude = stream.sources.index('longitude')
- self.k = k
- def get_data(self, request=None):
- if request is not None: raise ValueError
- data = next(self.child_epoch_iterator)
- last_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], True, False)[-self.k:],
- dtype=theano.config.floatX),
- numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:],
- dtype=theano.config.floatX))
- return data + last_k
-
-class TaxiAddDateTime(Transformer):
- def __init__(self, stream):
- super(TaxiAddDateTime, self).__init__(stream)
- self.sources = stream.sources + ('week_of_year', 'day_of_week', 'qhour_of_day')
- self.id_timestamp = stream.sources.index('timestamp')
- def get_data(self, request=None):
- if request is not None: raise ValueError
- data = next(self.child_epoch_iterator)
- ts = data[self.id_timestamp]
- date = datetime.datetime.utcfromtimestamp(ts)
- yearweek = date.isocalendar()[1] - 1
- info = ((51 if yearweek == 52 else yearweek), date.weekday(), date.hour * 4 + date.minute / 15)
- return data + info
-
-class TaxiExcludeTrips(Transformer):
- def __init__(self, exclude_list, stream):
- super(TaxiExcludeTrips, self).__init__(stream)
- self.id_trip_id = stream.sources.index('trip_id')
- self.exclude = {v: True for v in exclude_list}
- def get_data(self, request=None):
- if request is not None: raise ValueError
- while True:
- data = next(self.child_epoch_iterator)
- if not data[self.id_trip_id] in self.exclude: break
- return data
-