diff options
author | Alex Auvolat <alex@adnab.me> | 2023-11-15 19:02:11 +0100 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2023-11-15 19:02:11 +0100 |
commit | 56654ce07fb1319a21e6d5f2bdcc0d024c4db398 (patch) | |
tree | 2e58acc5cfd91577a4f05aa85275dc3bb1fc4b53 /src | |
parent | 643f65d17c264ead3726ef7efd4f820d78e4d5b3 (diff) | |
download | datagengo-56654ce07fb1319a21e6d5f2bdcc0d024c4db398.tar.gz datagengo-56654ce07fb1319a21e6d5f2bdcc0d024c4db398.zip |
Regenerate examples with hopefully more variety
Diffstat (limited to 'src')
-rw-r--r-- | src/charset.rs | 3 | ||||
-rw-r--r-- | src/main.rs | 90 |
2 files changed, 61 insertions, 32 deletions
diff --git a/src/charset.rs b/src/charset.rs index 71e1e84..b322324 100644 --- a/src/charset.rs +++ b/src/charset.rs @@ -19,6 +19,9 @@ impl Charset { chars.dedup(); Self(chars) } + pub fn iter(&self) -> impl Iterator<Item = char> + '_ { + self.0.iter().copied() + } pub fn intersects(&self, other: &Self) -> bool { let mut it1 = self.0.iter().peekable(); let mut it2 = other.0.iter().peekable(); diff --git a/src/main.rs b/src/main.rs index 70a7fa3..f78bf81 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1018,51 +1018,65 @@ fn add_extra_examples(all_batches: &mut [Batch], examples: &[Example]) { batch.extra_examples.clear(); let mut batch_char_seen_count: HashMap<char, usize> = HashMap::new(); - let mut in_batch = Charset::from_iter( - batch - .examples - .iter() - .map(|x| x.chars.chars().iter().copied()) - .flatten(), - ); + let mut in_batch = + Charset::from_iter(batch.examples.iter().map(|x| x.chars.iter()).flatten()); let mut in_batch_extra = Charset::default(); + while batch.extra_examples.len() < 40 { + let batch_min_seen = batch + .chars + .iter() + .map(|x| batch_char_seen_count.get(&x).copied().unwrap_or(0)) + .min() + .unwrap(); + // Target chars: chars of the batch that have the less examples + let c0 = + Charset::from_iter(batch.chars.iter().filter(|x| { + batch_char_seen_count.get(x).copied().unwrap_or(0) == batch_min_seen + })); + // Target chars: chars that have been seen less than cnt times + let fc = |cnt| { + Charset::from_iter( + chars + .iter() + .filter(|x| char_seen_count.get(x).copied().unwrap_or(0) <= cnt), + ) + }; + let c1 = fc(1); + let c2 = fc(2); + let c4 = fc(4); + let c7 = fc(7); + let best = candidates .iter() .enumerate() .map(|(i, ex)| { - ( - i, - ex, - ex.chars - .inter(&batch.chars) - .chars() - .iter() - .map(|x| batch_char_seen_count.get(x).copied().unwrap_or_default()) - .min() - .unwrap_or_default(), - ex.chars.diff(&in_batch).len(), - ex.chars - .chars() - .iter() - .map(|x| char_seen_count.get(x).copied().unwrap_or_default()) - .sum::<usize>() as f32 - / ex.chars.len() as f32, - ) + let weight = ( + ex.chars.inter_len(&c0), + ex.chars.inter_len(&c1), + ex.chars.inter_len(&c2), + ex.chars.inter_len(&c4), + ex.chars.inter_len(&c7), + ex.chars.diff(&in_batch_extra).len(), + ); + (i, ex, weight) }) - .max_by_key(|(_, _, w1, w2, w3)| (-(*w1 as i64), *w2, -(*w3 * 100_000f32) as i64)); - if let Some((i, ex, w1, w2, w3)) = best { + .max_by_key(|(_, _, w)| *w); + if let Some((i, ex, w)) = best { if ex.chars.diff(&in_batch_extra).len() > 0 || batch.extra_examples.len() < 20 { - println!("{}\t{}\t{:.2}\t{} - {}", w1, w2, w3, ex.ja, ex.en); + println!("{:?}\t{} - {}", w, ex.ja, ex.en); + batch.extra_examples.push(ex.clone()); in_batch = in_batch.union(&ex.chars); in_batch_extra = in_batch_extra.union(&ex.chars); - for c in ex.chars.chars().iter() { - *char_seen_count.entry(*c).or_default() += 1; - if batch.chars.chars().contains(c) { - *batch_char_seen_count.entry(*c).or_default() += 1; + + for c in ex.chars.iter() { + *char_seen_count.entry(c).or_default() += 1; + if batch.chars.contains(c) { + *batch_char_seen_count.entry(c).or_default() += 1; } } + candidates.remove(i); continue; } @@ -1073,6 +1087,18 @@ fn add_extra_examples(all_batches: &mut [Batch], examples: &[Example]) { batch .extra_examples .sort_by_key(|ex| fasthash::metro::hash64(ex.ja.as_bytes())); + + for i in 1..10 { + println!( + "Seen {:02}: {}", + i, + char_seen_count.iter().filter(|(_, v)| **v == i).count() + ); + } + println!( + "Seen more: {}", + char_seen_count.iter().filter(|(_, v)| **v >= 10).count() + ); } } |