aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-04-24 13:44:08 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-04-24 13:44:08 -0400
commit7b60ef60e03d01757dc5ea8ab3b87862f50e4ea4 (patch)
tree876d559dc09c57a32f938829e66090f1679f4ea1
parentf34ecb21c0ec362b560c2fe1c3afacc1e6dad998 (diff)
downloadtaxi-7b60ef60e03d01757dc5ea8ab3b87862f50e4ea4.tar.gz
taxi-7b60ef60e03d01757dc5ea8ab3b87862f50e4ea4.zip
Add data iterators
-rw-r--r--data.py100
1 files changed, 100 insertions, 0 deletions
diff --git a/data.py b/data.py
new file mode 100644
index 0000000..0493de7
--- /dev/null
+++ b/data.py
@@ -0,0 +1,100 @@
+import ast, csv
+import fuel
+from enum import Enum
+from fuel.datasets import Dataset
+from fuel.streams import DataStream
+from fuel.iterator import DataIterator
+
+PREFIX="/data/lisatmp3/auvolat/taxikaggle"
+
+class CallType(Enum):
+ CENTRAL = 0
+ STAND = 1
+ STREET = 2
+
+ @classmethod
+ def from_data(cls, val):
+ if val=='A':
+ return cls.CENTRAL
+ elif val=='B':
+ return cls.STAND
+ elif val=='C':
+ return cls.STREET
+
+ @classmethod
+ def to_data(cls, val):
+ if val==cls.CENTRAL:
+ return 'A'
+ elif val==cls.STAND:
+ return 'B'
+ elif val==cls.STREET:
+ return 'C'
+
+class DayType(Enum):
+ NORMAL = 0
+ HOLIDAY = 1
+ HOLIDAY_EVE = 2
+
+ @classmethod
+ def from_data(cls, val):
+ if val=='A':
+ return cls.NORMAL
+ elif val=='B':
+ return cls.HOLIDAY
+ elif val=='C':
+ return cls.HOLIDAY_EVE
+
+ @classmethod
+ def to_data(cls, val):
+ if val==cls.NORMAL:
+ return 'A'
+ elif val==cls.HOLIDAY:
+ return 'B'
+ elif val==cls.HOLIDAY_EVE:
+ return 'C'
+
+class TaxiData(Dataset):
+ provides_sources= ("trip_id","call_type","origin_call","origin_stand","taxi_id","timestamp","day_type","missing_data","polyline")
+
+ example_iteration_scheme=None
+
+ def __init__(self, path):
+ self.path=path
+ super(TaxiData, self).__init__()
+
+ def open(self):
+ file=open(self.path)
+ reader=csv.reader(file)
+ reader.next() # Skip header
+ return (file, reader)
+
+ def close(self, state):
+ state[0].close()
+
+ def reset(self, state):
+ state[0].seek(0)
+ state[1]=csv.reader(state[0])
+ return state
+
+ def get_data(self, state, request=None):
+ if request is not None:
+ raise ValueError
+ line=state[1].next()
+ line[1]=CallType.from_data(line[1]) # call_type
+ line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call
+ line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand
+ line[4]=int(line[4]) # taxi_id
+ line[5]=int(line[5]) # timestamp
+ line[6]=DayType.from_data(line[6]) # day_type
+ line[7]=line[7][0]=='T' # missing_data
+ line[8]=map(tuple, ast.literal_eval(line[8])) # polyline
+ return tuple(line)
+
+train_data=TaxiData(PREFIX+'/train.csv')
+test_data=TaxiData(PREFIX+'/test.csv')
+
+def train_it():
+ return DataIterator(DataStream(train_data))
+
+def test_it():
+ return DataIterator(DataStream(test_data))