diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-06-22 14:40:19 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-06-22 14:40:19 -0400 |
commit | bd08e452093bba68fe2d79b1e9da76488b203720 (patch) | |
tree | 48d82d8cf883ad642c483e2cad5cc707dd8c4694 | |
parent | ad5a03c6f60e5b2d543326bf8917b48e5b390b82 (diff) | |
download | taxi-bd08e452093bba68fe2d79b1e9da76488b203720.tar.gz taxi-bd08e452093bba68fe2d79b1e9da76488b203720.zip |
Update memory network
-rw-r--r-- | config/memory_network_1.py | 5 | ||||
-rw-r--r-- | data/cut.py | 5 | ||||
-rw-r--r-- | model/memory_network.py | 2 |
3 files changed, 7 insertions, 5 deletions
diff --git a/config/memory_network_1.py b/config/memory_network_1.py index 36d16ed..68e23bd 100644 --- a/config/memory_network_1.py +++ b/config/memory_network_1.py @@ -21,13 +21,13 @@ class MLPConfig(object): prefix_encoder = MLPConfig() prefix_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings) -prefix_encoder.dim_hidden = [50] +prefix_encoder.dim_hidden = [100, 100, 100] prefix_encoder.weights_init = IsotropicGaussian(0.01) prefix_encoder.biases_init = Constant(0.001) candidate_encoder = MLPConfig() candidate_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings) -candidate_encoder.dim_hidden = [50] +candidate_encoder.dim_hidden = [100, 100, 100] candidate_encoder.weights_init = IsotropicGaussian(0.01) candidate_encoder.biases_init = Constant(0.001) @@ -38,6 +38,7 @@ batch_size = 32 valid_set = 'cuts/test_times_0' max_splits = 1 +num_cuts = 1000 train_candidate_size = 1000 valid_candidate_size = 10000 diff --git a/data/cut.py b/data/cut.py index fc0b3f9..6e1e4e5 100644 --- a/data/cut.py +++ b/data/cut.py @@ -11,14 +11,15 @@ last_time = 1404172787 class TaxiTimeCutScheme(IterationScheme): - def __init__(self, dbfile=None, use_cuts=None): + def __init__(self, num_cuts=100, dbfile=None, use_cuts=None): + self.num_cuts = num_cuts self.dbfile = os.path.join(data.path, 'time_index.db') if dbfile == None else dbfile self.use_cuts = use_cuts def get_request_iterator(self): cuts = self.use_cuts if cuts == None: - cuts = [random.randrange(first_time, last_time) for _ in range(100)] + cuts = [random.randrange(first_time, last_time) for _ in range(self.num_cuts)] l = [] with sqlite3.connect(self.dbfile) as db: diff --git a/model/memory_network.py b/model/memory_network.py index 92e83e2..5acbfe3 100644 --- a/model/memory_network.py +++ b/model/memory_network.py @@ -88,7 +88,7 @@ class Stream(object): dataset = TaxiDataset('train') - prefix_stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme()) + prefix_stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme(self.config.num_cuts)) prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, valid_trips_ids) prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits) prefix_stream = transformers.taxi_add_datetime(prefix_stream) |