aboutsummaryrefslogtreecommitdiff
path: root/src/util/migrate.rs
blob: 1229fd9c0d17e83549656cce4a04521c48e05138 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
use serde::{Deserialize, Serialize};

/// Indicates that this type has an encoding that can be migrated from
/// a previous version upon upgrades of Garage.
///
/// Each version of a data type names the type it replaces through
/// [`Migrate::Previous`] and provides a conversion [`Migrate::migrate`].
/// [`Migrate::decode`] walks back through this chain of previous versions
/// until one of them can decode the bytes, then migrates the value forward
/// one step at a time.
pub trait Migrate: Serialize + for<'de> Deserialize<'de> + 'static {
	/// A sequence of bytes to add at the beginning of the serialized
	/// string, to identify that the data is of this version.
	///
	/// Defaults to the empty marker, which every input starts with.
	const VERSION_MARKER: &'static [u8] = b"";

	/// The previous version of this data type, from which items of this version
	/// can be migrated.
	type Previous: Migrate;

	/// The migration function that transforms a value decoded in the old format
	/// to an up-to-date value.
	fn migrate(previous: Self::Previous) -> Self;

	/// Decode an encoded version of this type, going through a migration if necessary.
	///
	/// If `bytes` starts with [`Self::VERSION_MARKER`] and the remainder decodes
	/// as MessagePack into `Self`, that value is returned directly. Otherwise
	/// decoding is retried with [`Self::Previous`], and a successful result is
	/// passed through [`Self::migrate`]. Returns `None` when no version in the
	/// chain can decode `bytes`.
	fn decode(bytes: &[u8]) -> Option<Self> {
		if bytes.starts_with(Self::VERSION_MARKER) {
			// In bounds: `starts_with` guarantees `bytes` is at least as long as the marker.
			let payload = &bytes[Self::VERSION_MARKER.len()..];
			// A failure here is not fatal: data in an older format may happen to
			// begin with the same bytes as this version's marker, so fall through
			// and let the previous version try.
			if let Ok(value) = rmp_serde::decode::from_read_ref::<_, Self>(payload) {
				return Some(value);
			}
		}

		Self::Previous::decode(bytes).map(Self::migrate)
	}

	/// Encode this type, prefixed with its (possibly empty) version marker.
	///
	/// The output is [`Self::VERSION_MARKER`] followed by the MessagePack
	/// serialization of `self`, with structs serialized as maps and enum
	/// variants serialized by name.
	///
	/// # Errors
	///
	/// Returns the error reported by `rmp_serde` if serialization fails.
	fn encode(&self) -> Result<Vec<u8>, rmp_serde::encode::Error> {
		let mut wr = Vec::with_capacity(128);
		wr.extend_from_slice(Self::VERSION_MARKER);
		let mut se = rmp_serde::Serializer::new(&mut wr)
			.with_struct_map()
			.with_string_variants();
		self.serialize(&mut se)?;
		Ok(wr)
	}
}

/// Marker trait for the very first encoding of a data type: there is no
/// older format that values of this type could be migrated from.
///
/// Every type implementing this trait also implements [`Migrate`] through a
/// blanket implementation.
pub trait InitialFormat: Serialize + for<'de> Deserialize<'de> + 'static {
	/// Bytes written in front of the serialized data to tag it as this version.
	/// Empty by default, in which case no tag is written.
	const VERSION_MARKER: &'static [u8] = &[];
}

/// Every [`InitialFormat`] type is a [`Migrate`] type whose previous version
/// is the uninhabited [`NoPrevious`], which terminates its migration chain.
impl<T: InitialFormat> Migrate for T {
	const VERSION_MARKER: &'static [u8] = <T as InitialFormat>::VERSION_MARKER;

	type Previous = NoPrevious;

	fn migrate(previous: Self::Previous) -> Self {
		// `NoPrevious` has no variants, so no value of it can exist. This empty
		// match is statically exhaustive and needs no runtime panic path.
		match previous {}
	}
}

/// Internal type used by InitialFormat, not meant for general use.
///
/// It has no variants, so no value of this type can ever be constructed. It
/// serves as the terminal `Previous` type of every migration chain.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum NoPrevious {}

/// End of every migration chain: nothing can be decoded as `NoPrevious`,
/// and since it has no values, `migrate` and `encode` can never be called.
impl Migrate for NoPrevious {
	type Previous = NoPrevious;

	fn migrate(previous: Self::Previous) -> Self {
		// Uninhabited type: the empty match is exhaustive at compile time.
		match previous {}
	}

	/// Always fails, which stops the recursive fallback in `Migrate::decode`.
	fn decode(_bytes: &[u8]) -> Option<Self> {
		None
	}

	fn encode(&self) -> Result<Vec<u8>, rmp_serde::encode::Error> {
		// A reference to an uninhabited type cannot exist at runtime either.
		match *self {}
	}
}

#[cfg(test)]
mod test {
	use super::*;

	#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
	struct V1 {
		a: usize,
		b: String,
	}
	impl InitialFormat for V1 {}

	#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
	struct V2 {
		a: usize,
		b: Vec<String>,
		c: String,
	}
	impl Migrate for V2 {
		const VERSION_MARKER: &'static [u8] = b"GtestV2";
		type Previous = V1;
		fn migrate(prev: V1) -> V2 {
			V2 {
				a: prev.a,
				b: vec![prev.b],
				c: String::new(),
			}
		}
	}

	#[test]
	fn test_v1() {
		let x = V1 {
			a: 12,
			b: "hello".into(),
		};
		let x_enc = x.encode().unwrap();
		let y = V1::decode(&x_enc).unwrap();
		assert_eq!(x, y);
	}

	#[test]
	fn test_v2() {
		let x = V2 {
			a: 12,
			b: vec!["hello".into(), "world".into()],
			c: "plop".into(),
		};
		let x_enc = x.encode().unwrap();
		assert_eq!(&x_enc[..V2::VERSION_MARKER.len()], V2::VERSION_MARKER);
		let y = V2::decode(&x_enc).unwrap();
		assert_eq!(x, y);
	}

	#[test]
	fn test_migrate() {
		let x = V1 {
			a: 12,
			b: "hello".into(),
		};
		let x_enc = x.encode().unwrap();

		let xx = V1::decode(&x_enc).unwrap();
		assert_eq!(x, xx);

		let y = V2::decode(&x_enc).unwrap();
		assert_eq!(
			y,
			V2 {
				a: 12,
				b: vec!["hello".into()],
				c: "".into(),
			}
		);

		let y_enc = y.encode().unwrap();
		assert_eq!(&y_enc[..V2::VERSION_MARKER.len()], V2::VERSION_MARKER);

		let z = V2::decode(&y_enc).unwrap();
		assert_eq!(y, z);
	}

	/// Empty input cannot be decoded by any version in the chain.
	#[test]
	fn test_decode_empty() {
		assert_eq!(V1::decode(b""), None);
		assert_eq!(V2::decode(b""), None);
	}

	/// A matching marker followed by garbage falls back to the previous
	/// version, which also fails, so the whole decode yields `None`.
	#[test]
	fn test_decode_corrupt_after_marker() {
		assert_eq!(V2::decode(b"GtestV2\xc1"), None);
	}

	/// Data written by a newer version must not be accepted by an older one.
	#[test]
	fn test_no_downgrade() {
		let x = V2 {
			a: 12,
			b: vec!["hello".into()],
			c: "plop".into(),
		};
		let x_enc = x.encode().unwrap();
		assert_eq!(V1::decode(&x_enc), None);
	}
}