about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- data.py 19
1 files changed, 12 insertions, 7 deletions
diff --git a/data.py b/data.py
index e6b7cbf..5e3b409 100644
--- a/data.py
+++ b/data.py
@@ -1,13 +1,17 @@
import ast, csv
+import socket
import fuel
from enum import Enum
from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.iterator import DataIterator
-PREFIX="/data/lisatmp3/auvolat/taxikaggle"
+if socket.gethostname() == "adeb.laptop":
+ DATA_PATH = "/Users/adeb/data/taxi"
+else:
+ DATA_PATH = PREFIX = "/data/lisatmp3/auvolat/taxikaggle"  # DATA_PATH must be bound on every host (all paths below use it); PREFIX kept as an alias
-client_ids = {int(x): y+1 for y, x in enumerate(open(PREFIX+"/client_ids.txt"))}
+client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))}
class CallType(Enum):
CENTRAL = 0
@@ -97,7 +101,8 @@ class TaxiData(Dataset):
raise ValueError
try:
line=state.reader.next()
- except ValueError:
+ except StopIteration:
+ print state.index
state.file.close()
state.index+=1
if state.index>=len(self.pathes):
@@ -106,7 +111,7 @@ class TaxiData(Dataset):
state.reader=csv.reader(state.file)
if self.has_header:
state.reader.next()
- return get_data(self, state)
+ return self.get_data(state)
line[1]=CallType.from_data(line[1]) # call_type
line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call
@@ -118,8 +123,8 @@ class TaxiData(Dataset):
line[8]=map(tuple, ast.literal_eval(line[8])) # polyline
return tuple(line)
-train_files=["%s/split/train-%02d.csv" % (PREFIX, i) for i in range(100)]
-valid_files=["%s/split/valid.csv" % (PREFIX,)]
+train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
+valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
train_data=TaxiData(train_files)
valid_data=TaxiData(valid_files)
@@ -127,4 +132,4 @@ def train_it():
return DataIterator(DataStream(train_data))
def test_it():
- return DataIterator(DataStream(test_data))
+ return DataIterator(DataStream(valid_data))