path: root/data/csv.py
"""Fuel Dataset that reads the taxi trip data directly from CSV files."""
import ast
import csv
import numpy

from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.iterator import DataIterator

import data
from data.hdf5 import origin_call_normalize, taxi_id_normalize


class TaxiData(Dataset):
    """Streams examples sequentially from one or more CSV files."""

    example_iteration_scheme = None

    class State:
        __slots__ = ('file', 'index', 'reader')

    def __init__(self, pathes, columns, has_header=False):
        if not isinstance(pathes, list):
            pathes = [pathes]
        assert len(pathes) > 0
        self.columns = columns
        # Source names are the first element of each (name, constructor) pair.
        self.provides_sources = tuple(name for name, _ in columns)
        self.pathes = pathes
        self.has_header = has_header
        super(TaxiData, self).__init__()

    def open(self):
        state = self.State()
        state.file = open(self.pathes[0])
        state.index = 0
        state.reader = csv.reader(state.file)
        if self.has_header:
            state.reader.next()  # skip the header row
        return state

    def close(self, state):
        state.file.close()

    def reset(self, state):
        if state.index == 0:
            state.file.seek(0)
        else:
            state.index = 0
            state.file.close()
            state.file = open(self.pathes[0])
        state.reader = csv.reader(state.file)
        if self.has_header:
            state.reader.next()  # skip the header row again after rewinding
        return state

    def get_data(self, state, request=None):
        if request is not None:
            raise ValueError("TaxiData does not support requests")
        try:
            line = state.reader.next()
        except (ValueError, StopIteration):
            # Current file is exhausted: advance to the next one, or stop.
            state.file.close()
            state.index += 1
            if state.index >= len(self.pathes):
                raise StopIteration
            state.file = open(self.pathes[state.index])
            state.reader = csv.reader(state.file)
            if self.has_header:
                state.reader.next()
            return self.get_data(state)

        # Apply each column constructor to the raw CSV row.
        return tuple(constructor(line) for _, constructor in self.columns)

# (source name, constructor) pairs: each constructor receives one raw CSV row
# (a list of strings) and returns the value for that source.
taxi_columns = [
    ("trip_id", lambda l: l[0]),
    ("call_type", lambda l: ord(l[1])-ord('A')),
    ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else origin_call_normalize(int(l[2]))),
    ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])),
    ("taxi_id", lambda l: taxi_id_normalize(int(l[4]))),
    ("timestamp", lambda l: int(l[5])),
    ("day_type", lambda l: ord(l[6])-ord('A')),
    ("missing_data", lambda l: l[7][0] == 'T'),
    ("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))),
    ("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))),
    ("latitude", lambda l: map(lambda p: p[1], ast.literal_eval(l[8]))),
]
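
# Illustrative only (hypothetical row, not taken from the real files): given
# r = ['T1', 'B', 'NA', '15', '20000001', '1372636858', 'A', 'False',
#      '[[-8.61, 41.14], [-8.62, 41.15]]'],
# the constructors above yield call_type == 1, origin_call == 0,
# origin_stand == 15 and polyline == [(-8.61, 41.14), (-8.62, 41.15)].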

# Validation files additionally carry the ground-truth destination and time.
taxi_columns_valid = taxi_columns + [
    ("destination_longitude", lambda l: numpy.float32(float(l[9]))),
    ("destination_latitude", lambda l: numpy.float32(float(l[10]))),
    ("time", lambda l: int(l[11])),
]

train_file="%s/train.csv" % data.path
valid_file="%s/valid2-cut.csv" % data.path
test_file="%s/test.csv" % data.path

train_data=TaxiData(train_file, taxi_columns, has_header=True)
valid_data = TaxiData(valid_file, taxi_columns_valid)
test_data = TaxiData(test_file, taxi_columns, has_header=True)

valid_trips = [l for l in open("%s/valid2-cut-ids.txt" % data.path)]

def train_it():
    return DataIterator(DataStream(train_data))


def test_it():
    # Note: this iterates the validation set (valid_data), not test_data.
    return DataIterator(DataStream(valid_data))
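

# A minimal smoke-test sketch, not part of the original module: it assumes the
# CSV files referenced via data.path exist locally and only exercises the
# Dataset protocol and the Fuel stream wrappers defined above.
if __name__ == '__main__':
    # Low-level protocol: open a state, pull one example tuple, close.
    state = train_data.open()
    example = train_data.get_data(state)
    train_data.close(state)
    for name, value in zip(train_data.provides_sources, example):
        print '%s: %s' % (name, value)

    # High-level: the first example again, through DataStream/DataIterator.
    print next(train_it())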