diff options
author | Alex Auvolat <alex@adnab.me> | 2023-07-21 16:58:36 +0200 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2023-07-21 16:58:36 +0200 |
commit | e220b38123fcecbf4448826f3f0ca2098c89181f (patch) | |
tree | 2a91b607a82afe5b24d8d03a49f083f501e1aea5 | |
parent | 13997439f8f1440b56c1e7dd449e3444aad28197 (diff) | |
download | datagengo-e220b38123fcecbf4448826f3f0ca2098c89181f.tar.gz datagengo-e220b38123fcecbf4448826f3f0ca2098c89181f.zip |
first iteration of batch generation algo
-rw-r--r-- | src/main.rs | 191 |
1 files changed, 175 insertions, 16 deletions
diff --git a/src/main.rs b/src/main.rs index f99d236..d1efece 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ use std::fs; use std::cmp::Ordering; use std::io::{self, BufRead}; -use anyhow::Result; +use anyhow::{anyhow, Result}; use structopt::StructOpt; #[derive(Debug, StructOpt)] @@ -30,13 +30,20 @@ fn main() { } }, Cmd::New => { - let kanji_levels = read_kanji_levels().expect("error"); + let kanji_levels = read_kanji_levels().expect("read_kanji_levels"); let all_kanji = Charset::new(kanji_levels.iter() .map(|(_, x)| x.to_string()) .collect::<Vec<_>>() .join("")); - let ex = read_examples(&all_kanji).expect("error"); - println!("{:#?}", &ex[..10]); + let kanji_levels = kanji_levels.into_iter() + .map(|(l, x)| (l, Charset::new(x))) + .collect::<Vec<_>>(); + let ex = read_examples(&all_kanji).expect("read_examples"); + println!("{:#?}", ex.iter().take(10).collect::<Vec<_>>()); + let batch1 = gen_batch(&[], &kanji_levels, &ex).expect("gen_batch"); + println!("{:#?}", batch1); + let batch2 = gen_batch(&[batch1], &kanji_levels, &ex).expect("gen_batch"); + println!("{:#?}", batch2); } } } @@ -99,7 +106,7 @@ fn read_examples(all_kanji: &Charset) -> Result<Vec<Example>> { let mut ret = Vec::new(); let mut a = "".to_string(); - for line in io::BufReader::new(file).lines() { + for (i, line) in io::BufReader::new(file).lines().enumerate() { let line = line?; if line.starts_with("A:") { a = line; @@ -114,7 +121,7 @@ fn read_examples(all_kanji: &Charset) -> Result<Vec<Example>> { en: eng.to_string(), expl: b.to_string(), id: Some(id.to_string()), - chars: Charset::new(ja).inter_chars(all_kanji), + chars: Charset::new(ja).inter(all_kanji), }); } else { ret.push(Example { @@ -122,21 +129,76 @@ fn read_examples(all_kanji: &Charset) -> Result<Vec<Example>> { en: eng.to_string(), expl: b.to_string(), id: None, - chars: Charset::new(ja).inter_chars(all_kanji), + chars: Charset::new(ja).inter(all_kanji), }); } } } } - if ret.len() > 100 { - break; + if i % 10000 == 0 { + eprintln!("read examples: {}/300", i/1000); } } Ok(ret) } -#[derive(Debug)] +fn gen_batch(previous: &[Batch], kanji_levels: &[(String, Charset)], examples: &[Example]) -> Result<Batch> { + let prev_chars = Charset::from_iter(previous.iter() + .map(|x| x.chars.chars().iter().copied()) + .flatten()); + + let (mut target_i, target_level, mut target_chars) = kanji_levels.iter().enumerate() + .map(|(i, (l, c))| (i, l, c.diff(&prev_chars))) + .find(|(_, _, c)| !c.is_empty()) + .ok_or(anyhow!("no more batches to make!"))?; + + let chars_4 = kanji_levels[..target_i].iter().rev().next() + .map(|(_, c)| c.clone()).unwrap_or(Charset::new("")); + + let chars_2 = kanji_levels[..target_i].iter().rev().skip(1).next() + .map(|(_, c)| c.clone()).unwrap_or(Charset::new("")); + + let chars_bad = Charset::from_iter(kanji_levels.iter().skip(target_i+1) + .map(|(_, c)| c.chars().iter().copied()) + .flatten()); + + let mut batch = Batch { + level: target_level.to_string(), + chars: Charset::new(""), + examples: Vec::new(), + }; + + eprintln!("Target (val=10) : {}", target_chars.to_string()); + eprintln!("Prev1 (val=4) : {}", chars_4.to_string()); + eprintln!("Prev2 (val=2) : {}", chars_2.to_string()); + eprintln!("Bad (val=-10): {}", chars_bad.to_string()); + + let batch_len = 20; + while batch.chars.len() < batch_len && !target_chars.is_empty() { + if let Some((ex, _)) = examples.iter() + .map(|ex| (ex, ex.chars.inter_len(&target_chars))) + .filter(|(_, ex_tgt_inter)| *ex_tgt_inter <= 4) + .filter(|(_, ex_tgt_inter)| *ex_tgt_inter + batch.chars.len() <= batch_len + 1) + .max_by_key(|(ex, ex_tgt_inter)| + 10i32 * *ex_tgt_inter as i32 + + 4i32 * ex.chars.inter_len(&chars_4) as i32 + + 2i32 * ex.chars.inter_len(&chars_2) as i32 + - 40i32 * ex.chars.inter_len(&chars_bad) as i32) { + println!("add: {:?} (bad: {})", ex, ex.chars.inter(&chars_bad).to_string()); + batch.chars = batch.chars.union(&ex.chars.inter(&target_chars)); + target_chars = target_chars.diff(&ex.chars); + batch.examples.push(ex.clone()); + } else { + eprintln!("could not find sentence that doesn't add myriads too many characters, stopping batch now"); + break; + } + } + + Ok(batch) +} + +#[derive(Debug, Clone)] struct Example { ja: String, en: String, @@ -145,7 +207,14 @@ struct Example { chars: Charset, } -#[derive(Debug)] +#[derive(Debug, Clone)] +struct Batch { + level: String, + chars: Charset, + examples: Vec<Example>, +} + +#[derive(Debug, Eq, PartialEq, Hash, Clone)] struct Charset(Vec<char>); impl Charset { @@ -155,6 +224,12 @@ impl Charset { chars.dedup(); Self(chars) } + fn from_iter<S: IntoIterator<Item=char>>(s: S) -> Self { + let mut chars = s.into_iter().collect::<Vec<_>>(); + chars.sort(); + chars.dedup(); + Self(chars) + } fn intersects(&self, other: &Self) -> bool { let mut it1 = self.0.iter().peekable(); let mut it2 = other.0.iter().peekable(); @@ -167,7 +242,7 @@ impl Charset { } false } - fn count_inter(&self, other: &Self) -> usize { + fn inter_len(&self, other: &Self) -> usize { let mut it1 = self.0.iter().peekable(); let mut it2 = other.0.iter().peekable(); let mut ret = 0; @@ -188,7 +263,7 @@ impl Charset { } ret } - fn inter_chars(&self, other: &Self) -> Charset { + fn inter(&self, other: &Self) -> Charset { let mut it1 = self.0.iter().peekable(); let mut it2 = other.0.iter().peekable(); let mut ret = Vec::new(); @@ -209,9 +284,77 @@ impl Charset { } Self(ret) } + fn union(&self, other: &Self) -> Charset { + let mut it1 = self.0.iter().peekable(); + let mut it2 = other.0.iter().peekable(); + let mut ret = Vec::new(); + while let (Some(c1), Some(c2)) = (it1.peek(), it2.peek()) { + match c1.cmp(c2) { + Ordering::Equal => { + ret.push(**c1); + it1.next(); + it2.next(); + } + Ordering::Less => { + ret.push(**c1); + it1.next(); + } + Ordering::Greater => { + ret.push(**c2); + it2.next(); + } + }; + } + while let Some(c) = it1.peek() { + ret.push(**c); + it1.next(); + } + while let Some(c) = it2.peek() { + ret.push(**c); + it2.next(); + } + Self(ret) + } + fn diff(&self, other: &Self) -> Charset { + let mut it1 = self.0.iter().peekable(); + let mut it2 = other.0.iter().peekable(); + let mut ret = Vec::new(); + while let (Some(c1), Some(c2)) = (it1.peek(), it2.peek()) { + match c1.cmp(c2) { + Ordering::Equal => { + it1.next(); + it2.next(); + } + Ordering::Less => { + ret.push(**c1); + it1.next(); + } + Ordering::Greater => { + it2.next(); + } + }; + } + while let Some(c) = it1.peek() { + ret.push(**c); + it1.next(); + } + Self(ret) + } + fn len(&self) -> usize { + self.0.len() + } + fn is_empty(&self) -> bool { + self.0.is_empty() + } fn chars(&self) -> &[char] { &self.0 } + fn contains(&self, c: char) -> bool { + self.0.binary_search(&c).is_ok() + } + fn to_string(&self) -> String { + self.0.iter().collect::<String>() + } } #[cfg(test)] @@ -228,8 +371,24 @@ mod test { assert!(c1.intersects(&c3)); assert!(c2.intersects(&c3)); - assert_eq!(c1.count_inter(&c2), 0); - assert_eq!(c1.count_inter(&c3), 2); - assert_eq!(c2.count_inter(&c3), 2); + assert_eq!(c1.inter_len(&c2), 0); + assert_eq!(c1.inter_len(&c3), 2); + assert_eq!(c2.inter_len(&c3), 2); + + assert_eq!(c1.inter(&c2), Charset::new("")); + assert_eq!(c1.inter(&c3), Charset::new("er")); + assert_eq!(c2.inter(&c3), Charset::new("od")); + + assert_eq!(c1.union(&c2), Charset::new("azertyuiopqsdf")); + assert_eq!(c1.union(&c3), Charset::new("azertyhello, world")); + assert_eq!(c2.union(&c3), Charset::new("uiopqsdfhello, world")); + + assert_eq!(c1.diff(&c2), Charset::new("azerty")); + assert_eq!(c1.diff(&c3), Charset::new("azty")); + assert_eq!(c2.diff(&c3), Charset::new("uipqsf")); + + assert_eq!(c2.diff(&c1), Charset::new("uiopqsdf")); + assert_eq!(c3.diff(&c1), Charset::new("hllo, wold")); + assert_eq!(c3.diff(&c2), Charset::new("hell, wrl")); } } |