Redo the base64 encoder in a cleaner way

Use the new hex_encoder/decoder methods to create byte arrays and then encode
those arrays into base64.
This commit is contained in:
Eryn Wells 2018-03-31 08:28:45 -07:00
parent d0e6c54498
commit 9abe0827bb

View file

@ -1,96 +1,131 @@
use hex::{AsHexBytes, HexResult}; // base64.rs
// Eryn Wells <eryn@erynwells.me>
static B64: &'static str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+\\"; static B64: &'static str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+\\";
pub fn base64(hex: &str) -> Result<String, String> { pub struct Base64Encoder<T> {
let mut out = String::from(""); /// Input iterator
let mut num_bits = 0; input: T,
let mut acc: u32 = 0; /// Accumulator. Bits are read into here from the input and shifted out in `next()`.
for (idx, c) in hex.hex_bytes().enumerate() { acc: u32,
match c { /// Number of bits to shift the accumulator for the next output byte.
HexResult::Byte(c) => { shift_bits: i8,
// Accumulate bytes until we have 6 chunks of 4. /// Number of padding characters to emit after the accumulator has been drained.
acc = (acc << 4) + (c as u32); padding: i8,
num_bits += 4; }
if idx % 6 != 5 {
continue; impl<T> Base64Encoder<T> {
} pub fn new(input: T) -> Base64Encoder<T> {
Base64Encoder {
// Read out 4 chunks of 6. input: input,
for i in (0..4).rev() { acc: 0,
let out_char_idx = ((acc >> (6 * i)) & 0x3F) as usize; shift_bits: -1,
// TODO: I don't like this nth() call. padding: 0,
if let Some(out_char) = B64.chars().nth(out_char_idx) {
out.push(out_char);
} else {
return Err(format!("Couldn't make output char from {}", out_char_idx));
}
}
acc = 0;
num_bits = 0;
},
HexResult::Invalid(c) => {
return Err(format!("Invalid input: {}", c));
},
} }
} }
}
if acc != 0 { impl<'a, T> Iterator for Base64Encoder<T> where T: Iterator<Item=u8> + 'a {
// Pad the string if we have bits remaining. type Item = char;
acc = acc << (24 - num_bits);
let padding = (24 - num_bits) / 6; fn next(&mut self) -> Option<Self::Item> {
for i in (0..4).rev() { if self.shift_bits < 0 && self.padding == 0 {
let out_char_idx = ((acc >> (6 * i)) & 0x3F) as usize; self.get_input_bytes();
if i < padding { }
out.push('=');
} else if let Some(out_char) = B64.chars().nth(out_char_idx) { if self.shift_bits >= 0 {
out.push(out_char); let char_index = ((self.acc >> self.shift_bits) & 0x3F) as usize;
let out = B64.chars().nth(char_index);
println!("out: acc:{:024b}, shift:{:2}, idx:{:08b}->{:?}", self.acc, self.shift_bits, char_index, out);
self.shift_bits -= 6;
out
} else if self.padding > 0 {
self.padding -= 1;
Some('=')
} else {
None
}
}
}
impl<'a, T> Base64Encoder<T> where T: Iterator<Item=u8> + 'a {
fn get_input_bytes(&mut self) {
let input_bytes = self.take_from_input(3);
let num_bits: i8 = input_bytes.len() as i8 * 8;
if num_bits != 0 {
// Shift over a few more bits to make sure the accumulator is divisible by 6.
let makeup_shift = (24 - num_bits) % 6;
self.acc = input_bytes.into_iter().fold(0, |acc, nxt| (acc << 8) + (nxt as u32)) << makeup_shift;
self.shift_bits = (num_bits + makeup_shift) - 6;
self.padding = (24 - num_bits) / 6;
} else {
self.acc = 0;
self.shift_bits = -1;
self.padding = -1;
}
println!("get: acc:{:024b}, shift:{:2}, padding:{}", self.acc, self.shift_bits, self.padding);
}
fn take_from_input(&mut self, n: usize) -> Vec<T::Item> {
let mut input_bytes: Vec<T::Item> = Vec::with_capacity(n);
for _ in 0..n {
if let Some(x) = self.input.next() {
input_bytes.push(x);
} else { } else {
return Err(format!("Couldn't make output char from {}", out_char_idx)); break;
} }
} }
input_bytes
} }
}
Ok(out) pub trait Base64Encodable<T> {
fn base64_encoded(self) -> Base64Encoder<T>;
}
impl<'a, T> Base64Encodable<T> for T where T: Iterator<Item=u8> + 'a {
fn base64_encoded(self) -> Base64Encoder<T> { Base64Encoder::new(self) }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::base64; use hex::*;
use super::*;
#[test] #[test]
fn small_wikipedia_example() { fn small_wikipedia_example() {
let input = "4d616e"; let input = "4d616e";
let ex_output = "TWFu"; let ex_output = "TWFu";
println!(""); let output: String = input.chars().hex_decoded().base64_encoded().collect();
let output = base64(input).expect("Error converting to base64");
assert_eq!(output, ex_output); assert_eq!(output, ex_output);
} }
#[test] #[test]
fn one_byte_padding() { fn one_byte_padding() {
println!("");
let input = "6f6d"; let input = "6f6d";
let ex_output = "b20="; let ex_output = "b20=";
println!(""); let output: String = input.chars().hex_decoded().base64_encoded().collect();
let output = base64(input).expect("Error converting to base64");
assert_eq!(output, ex_output); assert_eq!(output, ex_output);
} }
#[test] #[test]
fn two_byte_padding() { fn two_byte_padding() {
println!("");
let input = "6f"; let input = "6f";
let ex_output = "bw=="; let ex_output = "bw==";
println!(""); let output: String = input.chars().hex_decoded().base64_encoded().collect();
let output = base64(input).expect("Error converting to base64");
assert_eq!(output, ex_output); assert_eq!(output, ex_output);
} }
#[test] #[test]
fn cryptopals() { fn cryptopals() {
println!("");
let input = "49276d206b696c6c696e6720796f757220627261696e206c696b65206120706f69736f6e6f7573206d757368726f6f6d"; let input = "49276d206b696c6c696e6720796f757220627261696e206c696b65206120706f69736f6e6f7573206d757368726f6f6d";
let ex_output = "SSdtIGtpbGxpbmcgeW91ciBicmFpbiBsaWtlIGEgcG9pc29ub3VzIG11c2hyb29t"; let ex_output = "SSdtIGtpbGxpbmcgeW91ciBicmFpbiBsaWtlIGEgcG9pc29ub3VzIG11c2hyb29t";
println!(""); let output: String = input.chars().hex_decoded().base64_encoded().collect();
let output = base64(input).expect("Error converting to base64");
assert_eq!(output, ex_output); assert_eq!(output, ex_output);
} }
} }