From b5584610a14578d0f3ebf9eea3067a0284f67288 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 26 Apr 2016 16:33:02 +0200 Subject: Xreg hyperparameter change --- model/lstm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'model/lstm.py') diff --git a/model/lstm.py b/model/lstm.py index d928c88..10b090c 100644 --- a/model/lstm.py +++ b/model/lstm.py @@ -57,14 +57,16 @@ class Model(): states.append((init_cell, new_cells[-1, :, :])) if 'xreg' in p and p['xreg'] is not None: - n, s, w1, w2, w3 = p['xreg'] + n, s, w1, w2, w3, w4 = p['xreg'] cost_x1 = w1 * ((new_hidden.mean(axis=2) - s)**2).mean() cost_x2 = w2 * ((new_hidden.mean(axis=(0,1)) - s)**2).mean() cost_x3 = -w3 * abs(new_hidden - s).mean() + cost_x4 = w4 * abs(new_hidden[:-1,:,:]-new_hidden[1:,:,:]).mean() cost_x1.name = 'cost_x1_%d'%i cost_x2.name = 'cost_x2_%d'%i cost_x3.name = 'cost_x3_%d'%i - costs_xreg += [cost_x1, cost_x2, cost_x3] + cost_x4.name = 'cost_x4_%d'%i + costs_xreg += [cost_x1, cost_x2, cost_x3, cost_x4] dims.append(p['dim']) hidden.append(new_hidden) -- cgit v1.2.3