From 9abe0827bbe9627a5d0f6dbff47a08000f8b01a3 Mon Sep 17 00:00:00 2001 From: Eryn Wells Date: Sat, 31 Mar 2018 08:28:45 -0700 Subject: [PATCH] Redo the base64 encoder in a cleaner way Use the new hex_encoder/decoder methods to create byte arrays and then encode those arrays into base64. --- src/b64.rs | 139 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 87 insertions(+), 52 deletions(-) diff --git a/src/b64.rs b/src/b64.rs index 9059ecd..0d2544d 100644 --- a/src/b64.rs +++ b/src/b64.rs @@ -1,96 +1,131 @@ -use hex::{AsHexBytes, HexResult}; +// base64.rs +// Eryn Wells static B64: &'static str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+\\"; -pub fn base64(hex: &str) -> Result { - let mut out = String::from(""); - let mut num_bits = 0; - let mut acc: u32 = 0; - for (idx, c) in hex.hex_bytes().enumerate() { - match c { - HexResult::Byte(c) => { - // Accumulate bytes until we have 6 chunks of 4. - acc = (acc << 4) + (c as u32); - num_bits += 4; - if idx % 6 != 5 { - continue; - } - - // Read out 4 chunks of 6. - for i in (0..4).rev() { - let out_char_idx = ((acc >> (6 * i)) & 0x3F) as usize; - // TODO: I don't like this nth() call. - if let Some(out_char) = B64.chars().nth(out_char_idx) { - out.push(out_char); - } else { - return Err(format!("Couldn't make output char from {}", out_char_idx)); - } - } - acc = 0; - num_bits = 0; - }, - HexResult::Invalid(c) => { - return Err(format!("Invalid input: {}", c)); - }, +pub struct Base64Encoder { + /// Input iterator + input: T, + /// Accumulator. Bits are read into here from the input and shifted out in `next()`. + acc: u32, + /// Number of bits to shift the accumulator for the next output byte. + shift_bits: i8, + /// Number of padding characters to emit after the accumulator has been drained. + padding: i8, +} + +impl Base64Encoder { + pub fn new(input: T) -> Base64Encoder { + Base64Encoder { + input: input, + acc: 0, + shift_bits: -1, + padding: 0, } } +} - if acc != 0 { - // Pad the string if we have bits remaining. - acc = acc << (24 - num_bits); - let padding = (24 - num_bits) / 6; - for i in (0..4).rev() { - let out_char_idx = ((acc >> (6 * i)) & 0x3F) as usize; - if i < padding { - out.push('='); - } else if let Some(out_char) = B64.chars().nth(out_char_idx) { - out.push(out_char); +impl<'a, T> Iterator for Base64Encoder where T: Iterator + 'a { + type Item = char; + + fn next(&mut self) -> Option { + if self.shift_bits < 0 && self.padding == 0 { + self.get_input_bytes(); + } + + if self.shift_bits >= 0 { + let char_index = ((self.acc >> self.shift_bits) & 0x3F) as usize; + let out = B64.chars().nth(char_index); + println!("out: acc:{:024b}, shift:{:2}, idx:{:08b}->{:?}", self.acc, self.shift_bits, char_index, out); + self.shift_bits -= 6; + out + } else if self.padding > 0 { + self.padding -= 1; + Some('=') + } else { + None + } + } +} + +impl<'a, T> Base64Encoder where T: Iterator + 'a { + fn get_input_bytes(&mut self) { + let input_bytes = self.take_from_input(3); + let num_bits: i8 = input_bytes.len() as i8 * 8; + + if num_bits != 0 { + // Shift over a few more bits to make sure the accumulator is divisible by 6. + let makeup_shift = (24 - num_bits) % 6; + self.acc = input_bytes.into_iter().fold(0, |acc, nxt| (acc << 8) + (nxt as u32)) << makeup_shift; + self.shift_bits = (num_bits + makeup_shift) - 6; + self.padding = (24 - num_bits) / 6; + } else { + self.acc = 0; + self.shift_bits = -1; + self.padding = -1; + } + + println!("get: acc:{:024b}, shift:{:2}, padding:{}", self.acc, self.shift_bits, self.padding); + } + + fn take_from_input(&mut self, n: usize) -> Vec { + let mut input_bytes: Vec = Vec::with_capacity(n); + for _ in 0..n { + if let Some(x) = self.input.next() { + input_bytes.push(x); } else { - return Err(format!("Couldn't make output char from {}", out_char_idx)); + break; } } + input_bytes } +} - Ok(out) +pub trait Base64Encodable { + fn base64_encoded(self) -> Base64Encoder; +} + +impl<'a, T> Base64Encodable for T where T: Iterator + 'a { + fn base64_encoded(self) -> Base64Encoder { Base64Encoder::new(self) } } #[cfg(test)] mod tests { - use super::base64; + use hex::*; + use super::*; #[test] fn small_wikipedia_example() { let input = "4d616e"; let ex_output = "TWFu"; - println!(""); - let output = base64(input).expect("Error converting to base64"); + let output: String = input.chars().hex_decoded().base64_encoded().collect(); assert_eq!(output, ex_output); } #[test] fn one_byte_padding() { + println!(""); let input = "6f6d"; let ex_output = "b20="; - println!(""); - let output = base64(input).expect("Error converting to base64"); + let output: String = input.chars().hex_decoded().base64_encoded().collect(); assert_eq!(output, ex_output); } #[test] fn two_byte_padding() { + println!(""); let input = "6f"; let ex_output = "bw=="; - println!(""); - let output = base64(input).expect("Error converting to base64"); + let output: String = input.chars().hex_decoded().base64_encoded().collect(); assert_eq!(output, ex_output); } #[test] fn cryptopals() { + println!(""); let input = "49276d206b696c6c696e6720796f757220627261696e206c696b65206120706f69736f6e6f7573206d757368726f6f6d"; let ex_output = "SSdtIGtpbGxpbmcgeW91ciBicmFpbiBsaWtlIGEgcG9pc29ub3VzIG11c2hyb29t"; - println!(""); - let output = base64(input).expect("Error converting to base64"); + let output: String = input.chars().hex_decoded().base64_encoded().collect(); assert_eq!(output, ex_output); } }