MVP implementation of encoder/decoder

This commit is contained in:
Maximilian Friedersdorff 2020-07-18 22:31:41 +01:00
parent 79054dec11
commit 53a825a780

165
mnemonic_key.py Normal file → Executable file
View file

@ -1,41 +1,148 @@
#!/usr/bin/env python3
import hashlib
import math
import sys
import argparse
def access_bit(data, num):
base = int(num // 8)
shift = int(num % 8)
return (data[base] & (1 << shift)) >> shift
BITS_PER_WORD = 11
with open("./input.gpg", "rb") as f:
bs = bytearray(f.read())
with open("./english.txt", "r") as wordlist:
words = [word.strip() for word in wordlist.readlines() if word.strip()]
digest = hashlib.sha256(bs).digest()
def parse_args():
parser = argparse.ArgumentParser("Encode and decode files as a mnemonic")
parser.add_argument("wordlist", type=str, help="The wordlist to use")
parser.add_argument("--decode", action="store_true")
parser.add_argument(
"--input", type=str, help="The input file when encoding"
)
parser.add_argument(
"--output", type=str, help="The file to write to when decoding"
)
parser.add_argument(
"--length", type=int, help="Length in bytes of the decoded output"
)
return parser.parse_args()
bits = [access_bit(bs, i) for i in range(len(bs) * 8)]
checksum_bits = [access_bit(digest, i) for i in range(len(digest) * 8)]
def bits_to_int(bits):
"""Convert passed bits into int
n_bits = len(bits)
nearest_mulitple_of_11 = math.floor((n_bits/11) + 1) * 11
bits_missing = nearest_mulitple_of_11 - n_bits
bits += checksum_bits[0:bits_missing]
Least significant bit first
"""
b = 0
for i, bit in enumerate(bits):
b += bit << i
return b
mnemonic = []
for i in range(0, len(bits), 11):
word_bits = bits[i:i+11]
word_int = 0
for j, bit in enumerate(word_bits):
word_int += bit << j
word = words[word_int] + " "
mnemonic.append(word[0:10])
def byte_to_bits(byte):
"""Convert byte into bit array
mnemonic += [""] * 5
mnemonics = [mnemonic[i:i+5] for i in range(0, len(mnemonic), 5)]
for m in mnemonics:
print("".join(m))
Least significant bit first
"""
return [byte >> i & 1 for i in range(8)]
def create_mnemonic(bites, words, bits_per_word=BITS_PER_WORD):
"""Create mnemonic from bytes
Create mnemonic phrase from an input byte array. Each byte
is convert into a bit array (least significant bit first) and
all such bit arrays are concatenated in the order of the input
bytes. BITS_PER_WORD many bits are consumed from the beginning
of the array and converted into an integer (least significant
bit first) which is used as an index to look up a word in the
given wordlist. A list of so looked up words is returned.
If necessary, the concatenated bit array is padded with the
beginning bits of the sha256 hash of the input byte array
to get to the next multiple of the word size.
:param bites: The bytes to convert.
:param words: The word list to use, must have 2**n many words
:param bits_per_word: The number of bits to consume per word. The
word list should be 2**bits_per_word long
:retrun: Mnemonic phrase
"""
digest = hashlib.sha256(bites).digest()
bits = []
for b in bites:
bits += byte_to_bits(b)
checksum_bits = []
for b in digest:
checksum_bits += byte_to_bits(b)
n_bits = len(bits)
smallest_n_bits = math.floor((n_bits/BITS_PER_WORD) + 1) * BITS_PER_WORD
bits_missing = smallest_n_bits - n_bits
bits += checksum_bits[0:bits_missing]
mnemonic = []
for i in range(0, len(bits), 11):
word_int = bits_to_int(bits[i:i+11])
mnemonic.append(words[word_int])
return mnemonic
def parse_mnemonic(mnemonic, words):
"""Parse mnemonic into bytearray using wordlist
For each word in the mnemonic, find it's 0 indexed position
in the wordlist, convert the the position into a bit array
(lest significant bit first) and concatenate all such bit
arrays. Pad it with 7 * [0] to ensure the last bits fit into the
last byte. Convert the bit array into a byte array (least
significant bit first)
:param mnemonic: A list of words from the mnemonic
:param words: The (ordered) word list
:return: Decoded bytes
"""
bits = []
for word in mnemonic:
i = words.index(word)
bits += [i >> j & 1 for j in range(11)]
n_bits = len(bits)
# Add padding bits to ensure the last chunck has 8 bits
bits += [0] * 7
bites = []
for i in range(0, n_bits, 8):
bites.append(bits_to_int(bits[i:i+8]))
return bytearray(bites)
def run(word_file, encode, in_file, out_file, length):
with open(word_file, "r") as wordlist:
words = [word.strip() for word in wordlist.readlines() if word.strip()]
if encode:
with open(in_file, "rb") as in_file:
bites = bytearray(in_file.read())
mnemonic = create_mnemonic(bites, words)
print("\n".join(mnemonic))
else:
mnemonic = []
for line in sys.stdin.readlines():
mnemonic += line.split()
bites = parse_mnemonic(mnemonic, words)
with open(out_file, "wb") as out_file:
out_file.write(bites[:length])
if __name__ == "__main__":
args = parse_args()
run(
args.wordlist,
not args.decode,
args.input,
args.output,
args.length if args.decode else None,
)