diff options
author | Alex Auvolat <alex@adnab.me> | 2016-04-26 16:33:02 +0200 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2016-04-26 16:33:02 +0200 |
commit | b5584610a14578d0f3ebf9eea3067a0284f67288 (patch) | |
tree | cc23207cf82a6ee0d0bc121d1f0f6bd6e2a8e531 /model/lstm.py | |
parent | 760587b5d9771257160fac216dfcfff852de3ccc (diff) | |
download | text-rnn-b5584610a14578d0f3ebf9eea3067a0284f67288.tar.gz text-rnn-b5584610a14578d0f3ebf9eea3067a0284f67288.zip |
Xreg hyperparameter change
Diffstat (limited to 'model/lstm.py')
-rw-r--r-- | model/lstm.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/model/lstm.py b/model/lstm.py index d928c88..10b090c 100644 --- a/model/lstm.py +++ b/model/lstm.py @@ -57,14 +57,16 @@ class Model(): states.append((init_cell, new_cells[-1, :, :])) if 'xreg' in p and p['xreg'] is not None: - n, s, w1, w2, w3 = p['xreg'] + n, s, w1, w2, w3, w4 = p['xreg'] cost_x1 = w1 * ((new_hidden.mean(axis=2) - s)**2).mean() cost_x2 = w2 * ((new_hidden.mean(axis=(0,1)) - s)**2).mean() cost_x3 = -w3 * abs(new_hidden - s).mean() + cost_x4 = w4 * abs(new_hidden[:-1,:,:]-new_hidden[1:,:,:]).mean() cost_x1.name = 'cost_x1_%d'%i cost_x2.name = 'cost_x2_%d'%i cost_x3.name = 'cost_x3_%d'%i - costs_xreg += [cost_x1, cost_x2, cost_x3] + cost_x4.name = 'cost_x4_%d'%i + costs_xreg += [cost_x1, cost_x2, cost_x3, cost_x4] dims.append(p['dim']) hidden.append(new_hidden) |