Redo the base64 encoder in a cleaner way
Use the new hex_encoder/decoder methods to create byte arrays and then encode those arrays into base64.
This commit is contained in:
parent
d0e6c54498
commit
9abe0827bb
1 changed files with 87 additions and 52 deletions
139
src/b64.rs
139
src/b64.rs
|
@ -1,96 +1,131 @@
|
||||||
use hex::{AsHexBytes, HexResult};
|
// base64.rs
|
||||||
|
// Eryn Wells <eryn@erynwells.me>
|
||||||
|
|
||||||
static B64: &'static str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+\\";
|
static B64: &'static str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+\\";
|
||||||
|
|
||||||
pub fn base64(hex: &str) -> Result<String, String> {
|
pub struct Base64Encoder<T> {
|
||||||
let mut out = String::from("");
|
/// Input iterator
|
||||||
let mut num_bits = 0;
|
input: T,
|
||||||
let mut acc: u32 = 0;
|
/// Accumulator. Bits are read into here from the input and shifted out in `next()`.
|
||||||
for (idx, c) in hex.hex_bytes().enumerate() {
|
acc: u32,
|
||||||
match c {
|
/// Number of bits to shift the accumulator for the next output byte.
|
||||||
HexResult::Byte(c) => {
|
shift_bits: i8,
|
||||||
// Accumulate bytes until we have 6 chunks of 4.
|
/// Number of padding characters to emit after the accumulator has been drained.
|
||||||
acc = (acc << 4) + (c as u32);
|
padding: i8,
|
||||||
num_bits += 4;
|
}
|
||||||
if idx % 6 != 5 {
|
|
||||||
continue;
|
impl<T> Base64Encoder<T> {
|
||||||
}
|
pub fn new(input: T) -> Base64Encoder<T> {
|
||||||
|
Base64Encoder {
|
||||||
// Read out 4 chunks of 6.
|
input: input,
|
||||||
for i in (0..4).rev() {
|
acc: 0,
|
||||||
let out_char_idx = ((acc >> (6 * i)) & 0x3F) as usize;
|
shift_bits: -1,
|
||||||
// TODO: I don't like this nth() call.
|
padding: 0,
|
||||||
if let Some(out_char) = B64.chars().nth(out_char_idx) {
|
|
||||||
out.push(out_char);
|
|
||||||
} else {
|
|
||||||
return Err(format!("Couldn't make output char from {}", out_char_idx));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
acc = 0;
|
|
||||||
num_bits = 0;
|
|
||||||
},
|
|
||||||
HexResult::Invalid(c) => {
|
|
||||||
return Err(format!("Invalid input: {}", c));
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if acc != 0 {
|
impl<'a, T> Iterator for Base64Encoder<T> where T: Iterator<Item=u8> + 'a {
|
||||||
// Pad the string if we have bits remaining.
|
type Item = char;
|
||||||
acc = acc << (24 - num_bits);
|
|
||||||
let padding = (24 - num_bits) / 6;
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
for i in (0..4).rev() {
|
if self.shift_bits < 0 && self.padding == 0 {
|
||||||
let out_char_idx = ((acc >> (6 * i)) & 0x3F) as usize;
|
self.get_input_bytes();
|
||||||
if i < padding {
|
}
|
||||||
out.push('=');
|
|
||||||
} else if let Some(out_char) = B64.chars().nth(out_char_idx) {
|
if self.shift_bits >= 0 {
|
||||||
out.push(out_char);
|
let char_index = ((self.acc >> self.shift_bits) & 0x3F) as usize;
|
||||||
|
let out = B64.chars().nth(char_index);
|
||||||
|
println!("out: acc:{:024b}, shift:{:2}, idx:{:08b}->{:?}", self.acc, self.shift_bits, char_index, out);
|
||||||
|
self.shift_bits -= 6;
|
||||||
|
out
|
||||||
|
} else if self.padding > 0 {
|
||||||
|
self.padding -= 1;
|
||||||
|
Some('=')
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Base64Encoder<T> where T: Iterator<Item=u8> + 'a {
|
||||||
|
fn get_input_bytes(&mut self) {
|
||||||
|
let input_bytes = self.take_from_input(3);
|
||||||
|
let num_bits: i8 = input_bytes.len() as i8 * 8;
|
||||||
|
|
||||||
|
if num_bits != 0 {
|
||||||
|
// Shift over a few more bits to make sure the accumulator is divisible by 6.
|
||||||
|
let makeup_shift = (24 - num_bits) % 6;
|
||||||
|
self.acc = input_bytes.into_iter().fold(0, |acc, nxt| (acc << 8) + (nxt as u32)) << makeup_shift;
|
||||||
|
self.shift_bits = (num_bits + makeup_shift) - 6;
|
||||||
|
self.padding = (24 - num_bits) / 6;
|
||||||
|
} else {
|
||||||
|
self.acc = 0;
|
||||||
|
self.shift_bits = -1;
|
||||||
|
self.padding = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("get: acc:{:024b}, shift:{:2}, padding:{}", self.acc, self.shift_bits, self.padding);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn take_from_input(&mut self, n: usize) -> Vec<T::Item> {
|
||||||
|
let mut input_bytes: Vec<T::Item> = Vec::with_capacity(n);
|
||||||
|
for _ in 0..n {
|
||||||
|
if let Some(x) = self.input.next() {
|
||||||
|
input_bytes.push(x);
|
||||||
} else {
|
} else {
|
||||||
return Err(format!("Couldn't make output char from {}", out_char_idx));
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
input_bytes
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(out)
|
pub trait Base64Encodable<T> {
|
||||||
|
fn base64_encoded(self) -> Base64Encoder<T>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Base64Encodable<T> for T where T: Iterator<Item=u8> + 'a {
|
||||||
|
fn base64_encoded(self) -> Base64Encoder<T> { Base64Encoder::new(self) }
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::base64;
|
use hex::*;
|
||||||
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn small_wikipedia_example() {
|
fn small_wikipedia_example() {
|
||||||
let input = "4d616e";
|
let input = "4d616e";
|
||||||
let ex_output = "TWFu";
|
let ex_output = "TWFu";
|
||||||
println!("");
|
let output: String = input.chars().hex_decoded().base64_encoded().collect();
|
||||||
let output = base64(input).expect("Error converting to base64");
|
|
||||||
assert_eq!(output, ex_output);
|
assert_eq!(output, ex_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn one_byte_padding() {
|
fn one_byte_padding() {
|
||||||
|
println!("");
|
||||||
let input = "6f6d";
|
let input = "6f6d";
|
||||||
let ex_output = "b20=";
|
let ex_output = "b20=";
|
||||||
println!("");
|
let output: String = input.chars().hex_decoded().base64_encoded().collect();
|
||||||
let output = base64(input).expect("Error converting to base64");
|
|
||||||
assert_eq!(output, ex_output);
|
assert_eq!(output, ex_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn two_byte_padding() {
|
fn two_byte_padding() {
|
||||||
|
println!("");
|
||||||
let input = "6f";
|
let input = "6f";
|
||||||
let ex_output = "bw==";
|
let ex_output = "bw==";
|
||||||
println!("");
|
let output: String = input.chars().hex_decoded().base64_encoded().collect();
|
||||||
let output = base64(input).expect("Error converting to base64");
|
|
||||||
assert_eq!(output, ex_output);
|
assert_eq!(output, ex_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cryptopals() {
|
fn cryptopals() {
|
||||||
|
println!("");
|
||||||
let input = "49276d206b696c6c696e6720796f757220627261696e206c696b65206120706f69736f6e6f7573206d757368726f6f6d";
|
let input = "49276d206b696c6c696e6720796f757220627261696e206c696b65206120706f69736f6e6f7573206d757368726f6f6d";
|
||||||
let ex_output = "SSdtIGtpbGxpbmcgeW91ciBicmFpbiBsaWtlIGEgcG9pc29ub3VzIG11c2hyb29t";
|
let ex_output = "SSdtIGtpbGxpbmcgeW91ciBicmFpbiBsaWtlIGEgcG9pc29ub3VzIG11c2hyb29t";
|
||||||
println!("");
|
let output: String = input.chars().hex_decoded().base64_encoded().collect();
|
||||||
let output = base64(input).expect("Error converting to base64");
|
|
||||||
assert_eq!(output, ex_output);
|
assert_eq!(output, ex_output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue