"""CSV-backed Fuel datasets for the taxi trajectory data.

Defines ``TaxiData``, a Fuel ``Dataset`` that reads rows sequentially from
one or more CSV files, the column specifications used to turn a parsed CSV
row into source values, and module-level dataset instances / iterators for
the train, validation and test files located under ``data.path``.
"""
import ast
import csv
import numpy
import os

from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.iterator import DataIterator

import data
from data.hdf5 import origin_call_normalize, taxi_id_normalize


class TaxiData(Dataset):
    """Fuel dataset that reads examples row by row from CSV files.

    Parameters
    ----------
    pathes : str or list of str
        Path (or non-empty list of paths) of the CSV files. The files are
        read one after the other, in the given order.
    columns : list of (name, constructor) pairs
        ``name`` becomes a source name; ``constructor`` is called with the
        parsed CSV row (a list of strings) and returns the source's value.
    has_header : bool
        When true, the first row of every file is skipped.
    """
    example_iteration_scheme = None

    class State:
        """Reading position: open file, index into ``pathes``, CSV reader."""
        __slots__ = ('file', 'index', 'reader')

    def __init__(self, pathes, columns, has_header=False):
        if not isinstance(pathes, list):
            pathes = [pathes]
        assert len(pathes) > 0
        self.columns = columns
        # Must be set before Dataset.__init__ runs (kept in original order).
        self.provides_sources = tuple(column[0] for column in columns)
        self.pathes = pathes
        self.has_header = has_header
        super(TaxiData, self).__init__()

    def _attach_reader(self, state):
        """Create a fresh CSV reader on ``state.file`` and skip the header."""
        state.reader = csv.reader(state.file)
        if self.has_header:
            # The default keeps an empty file from raising StopIteration
            # here; such a file simply yields no rows.
            next(state.reader, None)

    def _open_path(self, state, index):
        """Point ``state`` at the start of the file ``self.pathes[index]``."""
        state.index = index
        state.file = open(self.pathes[index])
        self._attach_reader(state)

    def open(self):
        """Return a new state positioned on the first row of the first file."""
        state = self.State()
        self._open_path(state, 0)
        return state

    def close(self, state):
        """Close the file currently held by ``state``."""
        state.file.close()

    def reset(self, state):
        """Rewind ``state`` to the first data row of the first file.

        The header row (if any) is skipped again, so the first example
        served after a reset is the same as after ``open``.
        """
        if state.index == 0:
            # Still on the first file: rewind instead of reopening.
            state.file.seek(0)
            self._attach_reader(state)
        else:
            state.file.close()
            self._open_path(state, 0)
        return state

    def get_data(self, state, request=None):
        """Return the next example as a tuple ordered like ``columns``.

        Moves on to the next file when the current one is exhausted.

        Raises
        ------
        ValueError
            If ``request`` is not None (only sequential access is supported).
        StopIteration
            When every file has been read to the end.
        """
        if request is not None:
            raise ValueError
        while True:
            try:
                line = next(state.reader)
                break
            except (ValueError, StopIteration):
                # StopIteration: current file exhausted. ValueError is
                # presumably the "closed file" error seen when this method
                # is called again after the last file was closed.
                state.file.close()
                state.index += 1
                if state.index >= len(self.pathes):
                    raise StopIteration
                self._open_path(state, state.index)
        return tuple([constructor(line) for _, constructor in self.columns])


# Each entry is (source name, constructor); the constructor receives the
# parsed CSV row and reads its field by position.
taxi_columns = [
    ("trip_id", lambda l: l[0]),
    ("call_type", lambda l: ord(l[1]) - ord('A')),
    ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA'
        else origin_call_normalize(int(l[2]))),
    ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA'
        else int(l[3])),
    ("taxi_id", lambda l: taxi_id_normalize(int(l[4]))),
    ("timestamp", lambda l: int(l[5])),
    ("day_type", lambda l: ord(l[6]) - ord('A')),
    ("missing_data", lambda l: l[7][0] == 'T'),
    # Field 8 holds a Python-literal list of points; list comprehensions
    # (rather than map) give real lists on both Python 2 and Python 3.
    ("polyline", lambda l: [tuple(p) for p in ast.literal_eval(l[8])]),
    ("longitude", lambda l: [p[0] for p in ast.literal_eval(l[8])]),
    ("latitude", lambda l: [p[1] for p in ast.literal_eval(l[8])]),
]

# The validation file carries three extra fields after the common ones.
taxi_columns_valid = taxi_columns + [
    ("destination_longitude", lambda l: numpy.float32(float(l[9]))),
    ("destination_latitude", lambda l: numpy.float32(float(l[10]))),
    ("time", lambda l: int(l[11])),
]

train_file = os.path.join(data.path, 'train.csv')
valid_file = os.path.join(data.path, 'valid2-cut.csv')
test_file = os.path.join(data.path, 'test.csv')

train_data = TaxiData(train_file, taxi_columns, has_header=True)
valid_data = TaxiData(valid_file, taxi_columns_valid)
test_data = TaxiData(test_file, taxi_columns, has_header=True)

with open(os.path.join(data.path, 'valid2-cut-ids.txt')) as f:
    # One entry per line of the file; line terminators are kept.
    valid_trips = list(f)


def train_it():
    """Return a Fuel iterator over a stream of ``train_data``."""
    return DataIterator(DataStream(train_data))


def test_it():
    """Return a Fuel iterator over a stream of ``valid_data``."""
    # NOTE(review): this streams valid_data, not test_data -- confirm that
    # evaluating on the validation file is intended before changing it.
    return DataIterator(DataStream(valid_data))