Redo the base64 encoder in a cleaner way

Use the new hex_encoder/decoder methods to create byte arrays and then encode
those arrays into base64.
This commit is contained in:
Eryn Wells 2018-03-31 08:28:45 -07:00
parent d0e6c54498
commit 9abe0827bb

View file

@ -1,96 +1,131 @@
use hex::{AsHexBytes, HexResult};
// base64.rs
// Eryn Wells <eryn@erynwells.me>
static B64: &'static str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+\\";
pub fn base64(hex: &str) -> Result<String, String> {
let mut out = String::from("");
let mut num_bits = 0;
let mut acc: u32 = 0;
for (idx, c) in hex.hex_bytes().enumerate() {
match c {
HexResult::Byte(c) => {
// Accumulate bytes until we have 6 chunks of 4.
acc = (acc << 4) + (c as u32);
num_bits += 4;
if idx % 6 != 5 {
continue;
}
pub struct Base64Encoder<T> {
/// Input iterator
input: T,
/// Accumulator. Bits are read into here from the input and shifted out in `next()`.
acc: u32,
/// Number of bits to shift the accumulator for the next output byte.
shift_bits: i8,
/// Number of padding characters to emit after the accumulator has been drained.
padding: i8,
}
// Read out 4 chunks of 6.
for i in (0..4).rev() {
let out_char_idx = ((acc >> (6 * i)) & 0x3F) as usize;
// TODO: I don't like this nth() call.
if let Some(out_char) = B64.chars().nth(out_char_idx) {
out.push(out_char);
} else {
return Err(format!("Couldn't make output char from {}", out_char_idx));
}
}
acc = 0;
num_bits = 0;
},
HexResult::Invalid(c) => {
return Err(format!("Invalid input: {}", c));
},
impl<T> Base64Encoder<T> {
pub fn new(input: T) -> Base64Encoder<T> {
Base64Encoder {
input: input,
acc: 0,
shift_bits: -1,
padding: 0,
}
}
}
if acc != 0 {
// Pad the string if we have bits remaining.
acc = acc << (24 - num_bits);
let padding = (24 - num_bits) / 6;
for i in (0..4).rev() {
let out_char_idx = ((acc >> (6 * i)) & 0x3F) as usize;
if i < padding {
out.push('=');
} else if let Some(out_char) = B64.chars().nth(out_char_idx) {
out.push(out_char);
impl<'a, T> Iterator for Base64Encoder<T> where T: Iterator<Item=u8> + 'a {
type Item = char;
fn next(&mut self) -> Option<Self::Item> {
if self.shift_bits < 0 && self.padding == 0 {
self.get_input_bytes();
}
if self.shift_bits >= 0 {
let char_index = ((self.acc >> self.shift_bits) & 0x3F) as usize;
let out = B64.chars().nth(char_index);
println!("out: acc:{:024b}, shift:{:2}, idx:{:08b}->{:?}", self.acc, self.shift_bits, char_index, out);
self.shift_bits -= 6;
out
} else if self.padding > 0 {
self.padding -= 1;
Some('=')
} else {
None
}
}
}
impl<'a, T> Base64Encoder<T> where T: Iterator<Item=u8> + 'a {
fn get_input_bytes(&mut self) {
let input_bytes = self.take_from_input(3);
let num_bits: i8 = input_bytes.len() as i8 * 8;
if num_bits != 0 {
// Shift over a few more bits to make sure the accumulator is divisible by 6.
let makeup_shift = (24 - num_bits) % 6;
self.acc = input_bytes.into_iter().fold(0, |acc, nxt| (acc << 8) + (nxt as u32)) << makeup_shift;
self.shift_bits = (num_bits + makeup_shift) - 6;
self.padding = (24 - num_bits) / 6;
} else {
self.acc = 0;
self.shift_bits = -1;
self.padding = -1;
}
println!("get: acc:{:024b}, shift:{:2}, padding:{}", self.acc, self.shift_bits, self.padding);
}
fn take_from_input(&mut self, n: usize) -> Vec<T::Item> {
let mut input_bytes: Vec<T::Item> = Vec::with_capacity(n);
for _ in 0..n {
if let Some(x) = self.input.next() {
input_bytes.push(x);
} else {
return Err(format!("Couldn't make output char from {}", out_char_idx));
break;
}
}
input_bytes
}
}
Ok(out)
pub trait Base64Encodable<T> {
fn base64_encoded(self) -> Base64Encoder<T>;
}
impl<'a, T> Base64Encodable<T> for T where T: Iterator<Item=u8> + 'a {
fn base64_encoded(self) -> Base64Encoder<T> { Base64Encoder::new(self) }
}
#[cfg(test)]
mod tests {
use super::base64;
use hex::*;
use super::*;
#[test]
fn small_wikipedia_example() {
let input = "4d616e";
let ex_output = "TWFu";
println!("");
let output = base64(input).expect("Error converting to base64");
let output: String = input.chars().hex_decoded().base64_encoded().collect();
assert_eq!(output, ex_output);
}
#[test]
fn one_byte_padding() {
println!("");
let input = "6f6d";
let ex_output = "b20=";
println!("");
let output = base64(input).expect("Error converting to base64");
let output: String = input.chars().hex_decoded().base64_encoded().collect();
assert_eq!(output, ex_output);
}
#[test]
fn two_byte_padding() {
println!("");
let input = "6f";
let ex_output = "bw==";
println!("");
let output = base64(input).expect("Error converting to base64");
let output: String = input.chars().hex_decoded().base64_encoded().collect();
assert_eq!(output, ex_output);
}
#[test]
fn cryptopals() {
println!("");
let input = "49276d206b696c6c696e6720796f757220627261696e206c696b65206120706f69736f6e6f7573206d757368726f6f6d";
let ex_output = "SSdtIGtpbGxpbmcgeW91ciBicmFpbiBsaWtlIGEgcG9pc29ub3VzIG11c2hyb29t";
println!("");
let output = base64(input).expect("Error converting to base64");
let output: String = input.chars().hex_decoded().base64_encoded().collect();
assert_eq!(output, ex_output);
}
}