aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2023-11-15 19:02:11 +0100
committerAlex Auvolat <alex@adnab.me>2023-11-15 19:02:11 +0100
commit56654ce07fb1319a21e6d5f2bdcc0d024c4db398 (patch)
tree2e58acc5cfd91577a4f05aa85275dc3bb1fc4b53 /src
parent643f65d17c264ead3726ef7efd4f820d78e4d5b3 (diff)
downloaddatagengo-56654ce07fb1319a21e6d5f2bdcc0d024c4db398.tar.gz
datagengo-56654ce07fb1319a21e6d5f2bdcc0d024c4db398.zip
Regenerate examples with hopefully more variety
Diffstat (limited to 'src')
-rw-r--r--src/charset.rs3
-rw-r--r--src/main.rs90
2 files changed, 61 insertions, 32 deletions
diff --git a/src/charset.rs b/src/charset.rs
index 71e1e84..b322324 100644
--- a/src/charset.rs
+++ b/src/charset.rs
@@ -19,6 +19,9 @@ impl Charset {
chars.dedup();
Self(chars)
}
+ pub fn iter(&self) -> impl Iterator<Item = char> + '_ {
+ self.0.iter().copied()
+ }
pub fn intersects(&self, other: &Self) -> bool {
let mut it1 = self.0.iter().peekable();
let mut it2 = other.0.iter().peekable();
diff --git a/src/main.rs b/src/main.rs
index 70a7fa3..f78bf81 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1018,51 +1018,65 @@ fn add_extra_examples(all_batches: &mut [Batch], examples: &[Example]) {
batch.extra_examples.clear();
let mut batch_char_seen_count: HashMap<char, usize> = HashMap::new();
- let mut in_batch = Charset::from_iter(
- batch
- .examples
- .iter()
- .map(|x| x.chars.chars().iter().copied())
- .flatten(),
- );
+ let mut in_batch =
+ Charset::from_iter(batch.examples.iter().map(|x| x.chars.iter()).flatten());
let mut in_batch_extra = Charset::default();
+
while batch.extra_examples.len() < 40 {
+ let batch_min_seen = batch
+ .chars
+ .iter()
+ .map(|x| batch_char_seen_count.get(&x).copied().unwrap_or(0))
+ .min()
+ .unwrap();
+ // Target chars: chars of the batch that have the fewest examples
+ let c0 =
+ Charset::from_iter(batch.chars.iter().filter(|x| {
+ batch_char_seen_count.get(x).copied().unwrap_or(0) == batch_min_seen
+ }));
+ // Target chars: chars that have been seen at most cnt times
+ let fc = |cnt| {
+ Charset::from_iter(
+ chars
+ .iter()
+ .filter(|x| char_seen_count.get(x).copied().unwrap_or(0) <= cnt),
+ )
+ };
+ let c1 = fc(1);
+ let c2 = fc(2);
+ let c4 = fc(4);
+ let c7 = fc(7);
+
let best = candidates
.iter()
.enumerate()
.map(|(i, ex)| {
- (
- i,
- ex,
- ex.chars
- .inter(&batch.chars)
- .chars()
- .iter()
- .map(|x| batch_char_seen_count.get(x).copied().unwrap_or_default())
- .min()
- .unwrap_or_default(),
- ex.chars.diff(&in_batch).len(),
- ex.chars
- .chars()
- .iter()
- .map(|x| char_seen_count.get(x).copied().unwrap_or_default())
- .sum::<usize>() as f32
- / ex.chars.len() as f32,
- )
+ let weight = (
+ ex.chars.inter_len(&c0),
+ ex.chars.inter_len(&c1),
+ ex.chars.inter_len(&c2),
+ ex.chars.inter_len(&c4),
+ ex.chars.inter_len(&c7),
+ ex.chars.diff(&in_batch_extra).len(),
+ );
+ (i, ex, weight)
})
- .max_by_key(|(_, _, w1, w2, w3)| (-(*w1 as i64), *w2, -(*w3 * 100_000f32) as i64));
- if let Some((i, ex, w1, w2, w3)) = best {
+ .max_by_key(|(_, _, w)| *w);
+ if let Some((i, ex, w)) = best {
if ex.chars.diff(&in_batch_extra).len() > 0 || batch.extra_examples.len() < 20 {
- println!("{}\t{}\t{:.2}\t{} - {}", w1, w2, w3, ex.ja, ex.en);
+ println!("{:?}\t{} - {}", w, ex.ja, ex.en);
+
batch.extra_examples.push(ex.clone());
in_batch = in_batch.union(&ex.chars);
in_batch_extra = in_batch_extra.union(&ex.chars);
- for c in ex.chars.chars().iter() {
- *char_seen_count.entry(*c).or_default() += 1;
- if batch.chars.chars().contains(c) {
- *batch_char_seen_count.entry(*c).or_default() += 1;
+
+ for c in ex.chars.iter() {
+ *char_seen_count.entry(c).or_default() += 1;
+ if batch.chars.contains(c) {
+ *batch_char_seen_count.entry(c).or_default() += 1;
}
}
+
candidates.remove(i);
continue;
}
@@ -1073,6 +1087,18 @@ fn add_extra_examples(all_batches: &mut [Batch], examples: &[Example]) {
batch
.extra_examples
.sort_by_key(|ex| fasthash::metro::hash64(ex.ja.as_bytes()));
+
+ for i in 1..10 {
+ println!(
+ "Seen {:02}: {}",
+ i,
+ char_seen_count.iter().filter(|(_, v)| **v == i).count()
+ );
+ }
+ println!(
+ "Seen more: {}",
+ char_seen_count.iter().filter(|(_, v)| **v >= 10).count()
+ );
}
}