about | summary | refs | log | tree | commit | diff
path: root/data.py
diff options
context:
space:
mode:
author: Alex Auvolat <alex.auvolat@ens.fr> 2015-04-24 17:32:57 -0400
committer: Alex Auvolat <alex.auvolat@ens.fr> 2015-04-24 17:32:57 -0400
commit: 0527e6e696fa1832d599473099429295dea31650 (patch)
tree: 38d3e4946236e01746eb4216f83e00a3ba5ea14f /data.py
parent: bd2826df73554207c88c5918d86fd9707d9e3753 (diff)
download: taxi-0527e6e696fa1832d599473099429295dea31650.tar.gz
download: taxi-0527e6e696fa1832d599473099429295dea31650.zip
It kind of works (at least it does something now)
Diffstat (limited to 'data.py')
-rw-r--r-- data.py | 10
1 files changed, 6 insertions, 4 deletions
diff --git a/data.py b/data.py
index 4590a7b..e6b7cbf 100644
--- a/data.py
+++ b/data.py
@@ -7,6 +7,8 @@ from fuel.iterator import DataIterator
PREFIX="/data/lisatmp3/auvolat/taxikaggle"
+client_ids = {int(x): y+1 for y, x in enumerate(open(PREFIX+"/client_ids.txt"))}
+
class CallType(Enum):
CENTRAL = 0
STAND = 1
@@ -87,7 +89,7 @@ class TaxiData(Dataset):
state.index=0
state.file.close()
state.file=open(self.pathes[0])
- state.reader=csv.reader(state[0])
+ state.reader=csv.reader(state.file)
return state
def get_data(self, state, request=None):
@@ -95,7 +97,7 @@ class TaxiData(Dataset):
raise ValueError
try:
line=state.reader.next()
- except StopIteration:
+ except ValueError:
state.file.close()
state.index+=1
if state.index>=len(self.pathes):
@@ -104,10 +106,10 @@ class TaxiData(Dataset):
state.reader=csv.reader(state.file)
if self.has_header:
state.reader.next()
- line=state.reader.next()
+ return get_data(self, state)
line[1]=CallType.from_data(line[1]) # call_type
- line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call
+ line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call
line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand
line[4]=int(line[4]) # taxi_id
line[5]=int(line[5]) # timestamp