about | summary | refs | log | tree | commit | diff
path: root/data.py
diff options
context:
space:
mode:
Diffstat (limited to 'data.py')
-rw-r--r-- data.py 10
1 files changed, 6 insertions, 4 deletions
diff --git a/data.py b/data.py
index 4590a7b..e6b7cbf 100644
--- a/data.py
+++ b/data.py
@@ -7,6 +7,8 @@ from fuel.iterator import DataIterator
PREFIX="/data/lisatmp3/auvolat/taxikaggle"
+client_ids = {int(x): y+1 for y, x in enumerate(open(PREFIX+"/client_ids.txt"))}
+
class CallType(Enum):
CENTRAL = 0
STAND = 1
@@ -87,7 +89,7 @@ class TaxiData(Dataset):
state.index=0
state.file.close()
state.file=open(self.pathes[0])
- state.reader=csv.reader(state[0])
+ state.reader=csv.reader(state.file)
return state
def get_data(self, state, request=None):
@@ -95,7 +97,7 @@ class TaxiData(Dataset):
raise ValueError
try:
line=state.reader.next()
- except StopIteration:
+ except ValueError:
state.file.close()
state.index+=1
if state.index>=len(self.pathes):
@@ -104,10 +106,10 @@ class TaxiData(Dataset):
state.reader=csv.reader(state.file)
if self.has_header:
state.reader.next()
- line=state.reader.next()
+ return get_data(self, state)
line[1]=CallType.from_data(line[1]) # call_type
- line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call
+ line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call
line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand
line[4]=int(line[4]) # taxi_id
line[5]=int(line[5]) # timestamp