diff options
author | AdeB <adbrebs@gmail.com> | 2015-04-25 10:06:40 -0400 |
---|---|---|
committer | AdeB <adbrebs@gmail.com> | 2015-04-25 10:06:40 -0400 |
commit | 676af1086b141a7803626b040e7da03526b95406 (patch) | |
tree | 473b137ab2ff2ec1ae533d24a5ff5d1daba6eed7 | |
parent | 0527e6e696fa1832d599473099429295dea31650 (diff) | |
download | taxi-676af1086b141a7803626b040e7da03526b95406.tar.gz taxi-676af1086b141a7803626b040e7da03526b95406.zip |
Correct a few typos.
-rw-r--r-- | data.py | 19 |
1 files changed, 12 insertions, 7 deletions
@@ -1,13 +1,17 @@ import ast, csv +import socket import fuel from enum import Enum from fuel.datasets import Dataset from fuel.streams import DataStream from fuel.iterator import DataIterator -PREFIX="/data/lisatmp3/auvolat/taxikaggle" +if socket.gethostname() == "adeb.laptop": + DATA_PATH = "/Users/adeb/data/taxi" +else: + PREFIX="/data/lisatmp3/auvolat/taxikaggle" -client_ids = {int(x): y+1 for y, x in enumerate(open(PREFIX+"/client_ids.txt"))} +client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))} class CallType(Enum): CENTRAL = 0 @@ -97,7 +101,8 @@ class TaxiData(Dataset): raise ValueError try: line=state.reader.next() - except ValueError: + except StopIteration: + print state.index state.file.close() state.index+=1 if state.index>=len(self.pathes): @@ -106,7 +111,7 @@ class TaxiData(Dataset): state.reader=csv.reader(state.file) if self.has_header: state.reader.next() - return get_data(self, state) + return self.get_data(state) line[1]=CallType.from_data(line[1]) # call_type line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call @@ -118,8 +123,8 @@ class TaxiData(Dataset): line[8]=map(tuple, ast.literal_eval(line[8])) # polyline return tuple(line) -train_files=["%s/split/train-%02d.csv" % (PREFIX, i) for i in range(100)] -valid_files=["%s/split/valid.csv" % (PREFIX,)] +train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)] +valid_files=["%s/split/valid.csv" % (DATA_PATH,)] train_data=TaxiData(train_files) valid_data=TaxiData(valid_files) @@ -127,4 +132,4 @@ def train_it(): return DataIterator(DataStream(train_data)) def test_it(): - return DataIterator(DataStream(test_data)) + return DataIterator(DataStream(valid_data)) |