diff options
author | Alex Auvolat <alex@adnab.me> | 2023-10-18 21:53:04 +0200 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2023-10-18 21:53:04 +0200 |
commit | fab4731ad5a4ca26beb1a342ba85eec92014c04b (patch) | |
tree | f405915896f1e5193346d0eb13404b51d0669935 /src | |
parent | d9078f7674c637dd8498ece74ffe6cb7d1e179b9 (diff) | |
download | datagengo-fab4731ad5a4ca26beb1a342ba85eec92014c04b.tar.gz datagengo-fab4731ad5a4ca26beb1a342ba85eec92014c04b.zip |
regenerate extra examples with more diversity
Diffstat (limited to 'src')
-rw-r--r-- | src/main.rs | 70 |
1 files changed, 56 insertions, 14 deletions
diff --git a/src/main.rs b/src/main.rs index 66f1f51..04a9e80 100644 --- a/src/main.rs +++ b/src/main.rs @@ -986,14 +986,23 @@ fn add_vocab(all_batches: &mut [Batch], vocab: &[JlptVocab]) { fn add_extra_examples(all_batches: &mut [Batch], examples: &[Example]) { let mut chars = Charset::default(); + let mut char_seen_count: HashMap<char, usize> = HashMap::new(); + for (i, batch) in all_batches.iter_mut().enumerate() { + println!("---- BATCH #{:03} ----", i); chars = chars.union(&batch.chars); + // Take only examples that: + // - contain kanji of this batch + // - only contain kanji of this or previous batches + // - are not in the batch's main example sentences let candidates = examples .iter() .filter(|x| x.chars.inter_len(&batch.chars) > 0) .filter(|x| x.chars.diff(&chars).len() == 0) .filter(|x| batch.examples.iter().all(|y| y.ja != x.ja)); + + // Take only one candidate sentence for each possible set of represented kanji let mut cand_by_chars = HashMap::new(); for c in candidates { cand_by_chars.insert(c.chars.to_string(), c.clone()); @@ -1002,20 +1011,56 @@ fn add_extra_examples(all_batches: &mut [Batch], examples: &[Example]) { .into_iter() .map(|(_, ex)| ex) .collect::<Vec<_>>(); + + // Sorte candidates in a deterministic random order candidates.sort_by_key(|ex| fasthash::metro::hash64(ex.ja.as_bytes())); batch.extra_examples.clear(); - let mut in_batch = Charset::from_iter(batch.examples.iter().map(|x| x.chars.chars().iter().copied()).flatten()); - let mut in_extra = Charset::default(); + + let mut batch_char_seen_count: HashMap<char, usize> = HashMap::new(); + let mut in_batch = Charset::from_iter( + batch + .examples + .iter() + .map(|x| x.chars.chars().iter().copied()) + .flatten(), + ); while batch.extra_examples.len() < 40 { - let best = candidates.iter().enumerate() - .map(|(i, ex)| (i, ex, ex.chars.diff(&in_batch).len(), ex.chars.diff(&in_extra).len())) - .max_by_key(|(_, _, w1, w2)| (*w1, *w2)); - if let Some((i, ex, w1, w2)) = best { - if w1 > 0 || w2 > 0 || batch.extra_examples.len() < 20 { + let best = candidates + .iter() + .enumerate() + .map(|(i, ex)| { + ( + i, + ex, + ex.chars + .inter(&batch.chars) + .chars() + .iter() + .map(|x| batch_char_seen_count.get(x).copied().unwrap_or_default()) + .min() + .unwrap_or_default(), + ex.chars.diff(&in_batch).len(), + ex.chars + .chars() + .iter() + .map(|x| char_seen_count.get(x).copied().unwrap_or_default()) + .sum::<usize>() as f32 + / ex.chars.len() as f32, + ) + }) + .max_by_key(|(_, _, w1, w2, w3)| (-(*w1 as i64), *w2, -(*w3 * 100_000f32) as i64)); + if let Some((i, ex, w1, w2, w3)) = best { + if w2 > 0 || batch.extra_examples.len() < 20 { + println!("{}\t{}\t{:.2}\t{} - {}", w1, w2, w3, ex.ja, ex.en); batch.extra_examples.push(ex.clone()); in_batch = in_batch.union(&ex.chars); - in_extra = in_extra.union(&ex.chars); + for c in ex.chars.chars().iter() { + *char_seen_count.entry(*c).or_default() += 1; + if batch.chars.chars().contains(c) { + *batch_char_seen_count.entry(*c).or_default() += 1; + } + } candidates.remove(i); continue; } @@ -1023,12 +1068,9 @@ fn add_extra_examples(all_batches: &mut [Batch], examples: &[Example]) { break; } - batch.extra_examples.sort_by_key(|ex| fasthash::metro::hash64(ex.ja.as_bytes())); - - println!("---- BATCH #{:03} ----", i); - for ex in batch.extra_examples.iter() { - println!("{} - {}", ex.ja, ex.en); - } + batch + .extra_examples + .sort_by_key(|ex| fasthash::metro::hash64(ex.ja.as_bytes())); } } |