RSA encryption in Haskell
It feels good to be back in the swing of learning. Pretty much my entire summer was spent in Germany, doing an internship at Technische Universität Ilmenau. It was fun, but after three months of gluing together OpenCV and Android using JNI, I'm quite happy to be studying again, in an environment where I can choose my own tools and languages (most of the time).
Let me get one thing straight before I go any further: I luuuurve functional programming. My first foray into FP was about 6 years ago, when I was helping my girlfriend at the time with her first-year university coursework. At Edinburgh University, the first language they teach their undergrads is Haskell, and in my opinion there is a strong case for other universities to follow suit. Learning Haskell was an eye-opening experience for me as a young programmer. It was the first time I had tried my hand at any non-imperative paradigm; I had never so much as glanced at a Java tutorial back then, and had only used Python and C up to that point. Nevertheless, from my first toy factorial program I was hooked, and though my love affair with Haskell has been a turbulent one, it has outlasted the romantic relationship that spawned it.
I think a lot of green first-year CS undergrads are in a similar position to the one I was in when I first learned Haskell. I often take prospective students on tours of St Andrews University's CS department, and I always make a point of asking them what languages they've learned. Most of them learned Visual Basic at school, and there are always a couple who know Java or Python, but in the last 2 years I've spoken to only one prospective student who had even heard of Haskell. So if new CS students have one thing in common, it is that they have never used Haskell. Whatever their imperative, object-oriented, or event-driven experience, the majority of new CS students have never programmed functionally, or even declaratively, before (though this is changing with the rise in popularity of JSON). Teaching Haskell as a first language sets a common denominator that brings almost everyone to the same level at the beginning of a course. Hopefully, it will also instil some good habits in the way they write programs, such as thinking about the problem rather than the code. I could write accounts of some of my contemporaries in 3rd year CS to illustrate why this would be useful: the person who has all the skills to solve problems algorithmically, but cannot translate this into the simplest Java code; the person who creates unnecessary variables for everything; the person who writes one big messy chunk of spaghetti code. I'm sure anyone reading this, whether a student or working in the industry, will recognise these descriptions. I would argue that coding everything in a pure FP language such as Haskell for half a year can eliminate these habits in young programmers.
I've gone off on a bit of a tangent from what I was originally going to talk about, namely a project I had to do as part of the coursework for the Data Encoding course I'm taking this semester: RSA encryption (http://en.wikipedia.org/wiki/RSA_(algorithm)). After a tentative "yes" from the lecturer when I asked whether I could use Haskell for the coursework, I set about getting back into the FP mindset. It had been about a year since I last used Haskell, and my memory for syntax is terrible, so I had my trusty copy of The Craft of Functional Programming (http://www.amazon.co.uk/Haskell-Functional-Programming-International-Computer/dp/0201342758) open by my side. It's a really nice book, but I'll admit it's the only Haskell book I've read; other options are Real World Haskell (http://book.realworldhaskell.org/read/) and Learn You a Haskell (http://learnyouahaskell.com/chapters), both freely available to read online.
Incidentally, I came across a good paper the other day about RSA in Haskell, which likely explains it much better than I could. Have a gander at it here: http://www.citidel.org/bitstream/10117/120/13/paper.pdf. Instead of an in-depth explanation of RSA and how I implemented it, I would like to use my code to illustrate some of the things I like (and don't like) about Haskell. The first is something I love: the ease with which you can go from thinking algorithmically or mathematically to having working code. It's often said that writing in a language like Python or Ruby is like writing executable pseudocode, and I would argue that Haskell displays this quality as well. Take, for example, the Extended Euclidean Algorithm (http://en.wikipedia.org/wiki/Extended_Euclidean_algorithm), one of the most frequently implemented algorithms in computer science, and one that is necessary for RSA (it is how the private exponent is computed, as a modular inverse of the public one). From the Wikipedia article I just linked, the pseudocode for a recursive implementation is:
function extended_gcd(a, b)
    if b = 0
        return (1, 0)
    else
        (q, r) := divide (a, b)
        (s, t) := extended_gcd(b, r)
        return (t, s - q * t)
As you can probably see, this fits Haskell's programming model so well that the code practically writes itself! Here is the code for this algorithm from my RSA implementation:
-- Extended Euclidean algorithm using the recursive method, returning (gcd, x, y)
eea :: (Integral a) => a -> a -> (a, a, a)
eea a b
  | b == 0 = (a, 1, 0)
  | otherwise = (d, t, s - q * t)
  where (q, r) = a `divMod` b
        (d, s, t) = eea b r
In fact, the translation between algorithm and executable code is so straightforward that I used the code shown here to explain the algorithm to a fellow student who had never seen a single line of Haskell. The only syntax I had to explain was that the '|' character means 'if'.
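To see eea doing real RSA work, here is a small worked sketch (not part of the coursework code) using the usual toy numbers p = 61, q = 53 and public exponent 17, which are far too small to be secure. The helper inverseMod and the toy* names are hypothetical, made up for this example; inverseMod does the same job as a modular-inverse function built on eea, but returns Maybe instead of calling error.

```haskell
-- Extended Euclidean algorithm, as in the post: returns (g, x, y) with a*x + b*y == g
eea :: Integral a => a -> a -> (a, a, a)
eea a b
  | b == 0    = (a, 1, 0)
  | otherwise = (d, t, s - q * t)
  where
    (q, r)    = a `divMod` b
    (d, s, t) = eea b r

-- Hypothetical helper: inverse of a modulo m (for a >= 0, m > 0),
-- or Nothing when gcd a m /= 1
inverseMod :: Integer -> Integer -> Maybe Integer
inverseMod a m = case eea a m of
  (1, x, _) -> Just (x `mod` m)
  _         -> Nothing

-- Toy RSA parameters: far too small to be secure, illustration only
toyP, toyQ, toyE :: Integer
toyP = 61
toyQ = 53
toyE = 17

toyN, toyPhi :: Integer
toyN   = toyP * toyQ              -- 3233
toyPhi = (toyP - 1) * (toyQ - 1)  -- 3120

-- Private exponent: the inverse of toyE modulo toyPhi
toyD :: Maybe Integer
toyD = inverseMod toyE toyPhi     -- Just 2753
```

Loaded into GHCi, eea 17 3120 evaluates to (1,-367,2): 17 * (-367) + 3120 * 2 = 1, so the private exponent is -367 mod 3120 = 2753.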
So we've seen one example of the good; let's move on to the bad. One of Haskell's strengths may also be considered a weakness: a short mental distance between algorithm and code makes it easy to implement something badly. I guess this isn't really a failing of the language (bad code is bad code, no matter what it's written in), but the sheer number of ways of implementing a function in Haskell often leads to bad code through laziness (in a bad "I can't be bothered to take the dog for a walk" way, rather than a good "lazy evaluation" way ;) ). Another example from my code, this time the part that handles modular exponentiation:
-- Modular exponentiation by squaring (using Montgomery's ladder). Return x^n (mod m)
mexp :: Integer -> Integer -> Integer -> Integer
mexp x n m
  | n == 0 = 1 `mod` m
  | otherwise = fst (foldl (mexp' m) (x `mod` m, (x ^ 2) `mod` m) [ testBit n (k - b - 2) | b <- [0 .. (k - 2)] ])
  where k = length (takeWhile (> 0) (iterate (`shiftR` 1) n))  -- number of bits in n, in exact integer arithmetic

mexp' :: Integer -> (Integer, Integer) -> Bool -> (Integer, Integer)
mexp' m (x1, x2) b
  | not b = ((x1 ^ 2) `mod` m, (x1 * x2) `mod` m)
  | otherwise = ((x1 * x2) `mod` m, (x2 ^ 2) `mod` m)
I look at this and I cringe. With 20/20 hindsight, I can see much simpler ways of implementing this, with judicious use of (mod 2) arithmetic. Nevertheless, it illustrates my point perfectly: folding across a list comprehension that yields the binary representation of a number, in order to do one operation for a '1' and another for a '0', for every bit, is clunky. It gets the job done, but it was literally the first idea that came to mind on reading a description of Montgomery's ladder (http://en.wikipedia.org/wiki/Exponentiation_by_squaring#Montgomery.27s_ladder_technique). If/when I decide to clean this code up a bit, these two functions will be first in line for an overhaul.
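In the spirit of that (mod 2) remark, here is one possible tidier shape, offered as a sketch rather than the coursework code: the same Montgomery's ladder, but recursing on n `div` 2 and branching on whether n is even, so there is no bit list, no fold and no floating-point logarithm. The names ladder and modExp are hypothetical, and n is assumed to be non-negative.

```haskell
-- Montgomery's ladder, written recursively with (mod 2) arithmetic.
-- Returns (x^n mod m, x^(n+1) mod m) for n >= 0 and m >= 1.
ladder :: Integer -> Integer -> Integer -> (Integer, Integer)
ladder x 0 m = (1 `mod` m, x `mod` m)
ladder x n m
  | even n    = (r0 * r0 `mod` m, r0 * r1 `mod` m)  -- n = 2h:   (x^(2h), x^(2h+1))
  | otherwise = (r0 * r1 `mod` m, r1 * r1 `mod` m)  -- n = 2h+1: (x^(2h+1), x^(2h+2))
  where
    -- (x^h, x^(h+1)) mod m, where h = n `div` 2 drops the lowest bit of n
    (r0, r1) = ladder x (n `div` 2) m

-- Hypothetical wrapper: x^n mod m
modExp :: Integer -> Integer -> Integer -> Integer
modExp x n m = fst (ladder x n m)
```

With the textbook toy key n = 3233 (= 61 * 53), e = 17, d = 2753, modExp 65 17 3233 gives 2790 and modExp 2790 2753 3233 gives back 65. Every level of the recursion does exactly one multiply and one square whatever the bit is, which is the property that makes the ladder attractive in the first place.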
After such tomfoolery, let's end on an upbeat note. There are a few things I was thinking of including here, but I'll save them for another post. Testing in Haskell is great fun. As with interpreted languages, calling individual functions from an interactive interpreter such as GHCi is a doddle, but one advantage (as I hinted at earlier) is the tendency of Haskell functions to be short, simple, and specific. Having a chunky "do-everything" function just isn't a sensible option. This makes testing easy, as functionality stays split into discrete functions. At St Andrews University, a lot of the lecturers encourage a test-driven approach. A unit testing library such as the fantastic HUnit (http://hunit.sourceforge.net/) is great for TDD, and I can honestly say that the satisfaction of writing a handful of one-liner test functions beats the hell out of tedious method testing in Java.
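HUnit is a separate package, so here is a dependency-free sketch of the same one-liner style instead: each test is just a name paired with a Bool, and eea is checked against Prelude's gcd and against Bézout's identity a*x + b*y == g. With HUnit, each entry would become something like "name" ~: actual ~?= expected inside a TestList, run with runTestTT. The names tests, failures and small are made up for this example.

```haskell
-- Extended Euclidean algorithm from the post: returns (g, x, y) with a*x + b*y == g
eea :: Integral a => a -> a -> (a, a, a)
eea a b
  | b == 0    = (a, 1, 0)
  | otherwise = (d, t, s - q * t)
  where
    (q, r)    = a `divMod` b
    (d, s, t) = eea b r

-- Each test is a one-liner: a name and a Bool that should be True
tests :: [(String, Bool)]
tests =
  [ ("gcd of 240 and 46 is 2",  firstOf (eea 240 46) == 2)
  , ("known coefficients",      eea 240 46 == (2, -9, 47))
  , ("agrees with Prelude gcd", and [firstOf (eea a b) == gcd a b | a <- small, b <- small])
  , ("Bezout identity holds",   and [a * x + b * y == g | a <- small, b <- small, let (g, x, y) = eea a b])
  ]
  where
    small = [0 .. 60] :: [Integer]
    firstOf (g, _, _) = g

-- Names of failing tests; an empty list means everything passed
failures :: [String]
failures = [name | (name, ok) <- tests, not ok]
```

In GHCi, evaluating failures should give [] when everything passes; a failing check shows up by name.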
Anyhoos, that's all I really want to say about this for now. Maybe a follow-up will, er, follow up, but for now here's my code. And yes, before you say anything, I know that encoding letter-by-letter is crap. Pro-tip: don't leave coursework until the last minute ;).
import Data.Bits
import System.Random  -- randomRIO, from the random package
import Data.Char

-- Keys, in the form n, k (where k is i for the public key, and j for the private key)
data Key = Public Integer Integer | Private Integer Integer deriving (Eq, Ord, Show)

-- Extended Euclidean algorithm using the recursive method, returning (gcd, x, y)
eea :: (Integral a) => a -> a -> (a, a, a)
eea a b
  | b == 0 = (a, 1, 0)
  | otherwise = (d, t, s - q * t)
  where (q, r) = a `divMod` b
        (d, s, t) = eea b r

-- Modular multiplicative inverse for a (mod m)
mminv :: (Integral a) => a -> a -> a
mminv a m
  | g /= 1 = error "Number doesn't have a multiplicative inverse for this modulus!"
  | otherwise = x `mod` m
  where (g, x, _) = eea a m

-- Modular exponentiation by squaring (using Montgomery's ladder). Return x^n (mod m)
mexp :: Integer -> Integer -> Integer -> Integer
mexp x n m
  | n == 0 = 1 `mod` m
  | otherwise = fst (foldl (mexp' m) (x `mod` m, (x ^ 2) `mod` m) [ testBit n (k - b - 2) | b <- [0 .. (k - 2)] ])
  where k = length (takeWhile (> 0) (iterate (`shiftR` 1) n))  -- number of bits in n, in exact integer arithmetic

mexp' :: Integer -> (Integer, Integer) -> Bool -> (Integer, Integer)
mexp' m (x1, x2) b
  | not b = ((x1 ^ 2) `mod` m, (x1 * x2) `mod` m)
  | otherwise = ((x1 * x2) `mod` m, (x2 ^ 2) `mod` m)

-- Generate public and private keys using the multiplicative inverse of i, mod phi
generateKeys :: Integer -> Integer -> Integer -> (Key, Key)
generateKeys p q i
  | g /= 1 = error "Public exponent i is not coprime with phi!"
  | otherwise = (Public n i, Private n j)
  where n = p * q
        phi = (p - 1) * (q - 1)
        (g, _, _) = eea i phi
        j = mminv i phi

-- Code or decode an integer, given a public/private key
rsaCoder :: Key -> Integer -> Integer
rsaCoder (Public n k) x = mexp x k n
rsaCoder (Private n k) x = mexp x k n

-- Primality tester from http://www.haskell.org/haskellwiki/Testing_primality, but using my own mexp function instead of theirs
-- BEGIN --

-- (eq. to) find2km (2^k * n) = (k,n)
find2km :: Integral a => a -> (a,a)
find2km n = f 0 n
  where
    f k m
      | r == 1 = (k,m)
      | otherwise = f (k+1) q
      where (q,r) = quotRem m 2

-- n is the number to test; a is the (presumably randomly chosen) witness
millerRabinPrimality :: Integer -> Integer -> Bool
millerRabinPrimality n a
  | a <= 1 || a >= n-1 = error $ "millerRabinPrimality: a out of range (" ++ show a ++ " for " ++ show n ++ ")"
  | n < 2 = False
  | even n = False
  | b0 == 1 || b0 == n' = True
  | otherwise = iter (tail b)
  where
    n' = n-1
    (k,m) = find2km n'
    b0 = mexp a m n  -- modified this line
    b = take (fromIntegral k) $ iterate (squareMod n) b0
    iter [] = False
    iter (x:xs)
      | x == 1 = False
      | x == n' = True
      | otherwise = iter xs

squareMod :: Integral a => a -> a -> a
squareMod a b = (b * b) `rem` a

-- END --

-- Use the Miller-Rabin method of primality testing, with the single fixed witness 100.
-- Caution: this is one round, not 100 rounds, so it does not give the non-prime probability of 2^(-100)
-- cited from http://snippets.dzone.com/posts/show/4200. Each round with a random witness lets a composite
-- through with probability at most 1/4, so that bound needs about 50 independent random-witness rounds.
primeTest :: Integer -> Bool
primeTest x = millerRabinPrimality x 100

-- Generate a random prime in [2^n, 2^(n+1) - 1], i.e. an (n+1)-bit prime
getPrime :: Integer -> IO Integer
getPrime n = do
  r <- randomRIO (2 ^ n, (2 ^ (n + 1)) - 1)
  if primeTest r then return r else getPrime n

-- Encode a string byte-wise as a list of RSA-encrypted integers (this is not a good way of doing it, as frequency
-- analysis can easily be performed for frequent characters)
encode :: String -> Key -> [Integer]
encode s k = [rsaCoder k (toInteger $ ord i) | i <- s]

-- Decode a list of RSA-encrypted integers byte-wise to a string
decode :: [Integer] -> Key -> String
decode s k = [chr $ fromInteger $ rsaCoder k i | i <- s]

main :: IO ()
main = do
  p <- getPrime 256
  q <- getPrime 256
  i <- getPrime 256
  putStrLn $ "p: " ++ show p ++ " " ++ show (primeTest p)
  putStrLn $ "q: " ++ show q ++ " " ++ show (primeTest q)
  putStrLn $ "i: " ++ show i ++ " " ++ show (primeTest i)
  let (pub, priv) = generateKeys p q i
  print pub
  print priv
  putStrLn "Type the text to encode:"
  plaintext <- getLine
  putStrLn ""
  let encrypted = encode plaintext pub
  putStrLn $ "Encrypted:\n" ++ show encrypted
  let decrypted = decode encrypted priv
  putStrLn $ "Decrypted:\n" ++ show decrypted
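One thing worth flagging in the listing: primeTest calls millerRabinPrimality exactly once, with the fixed witness 100, which is a single round rather than 100 rounds. The usual approach runs several rounds with independent random witnesses; each round lets a composite through with probability at most 1/4, so 50 rounds bound the error by 2^-100. Below is a minimal, dependency-free sketch that takes the witness list as an argument; millerRabin and powMod are hypothetical names, not functions from the listing. The number 2047 = 23 * 89 shows why one witness is not enough: it passes for witness 2 and is only caught by witness 3.

```haskell
-- Modular exponentiation by repeated squaring (right-to-left binary method)
powMod :: Integer -> Integer -> Integer -> Integer
powMod x n m = go (x `mod` m) n 1
  where
    go _ 0 acc = acc `mod` m
    go b e acc
      | odd e     = go (b * b `mod` m) (e `div` 2) (acc * b `mod` m)
      | otherwise = go (b * b `mod` m) (e `div` 2) acc

-- Miller-Rabin: n passes only if it is a strong probable prime to every witness in ws.
-- Witnesses congruent to 0, 1 or n-1 (mod n) carry no information and are skipped.
millerRabin :: Integer -> [Integer] -> Bool
millerRabin n ws
  | n < 2     = False
  | n < 4     = True                  -- 2 and 3 are prime
  | even n    = False
  | otherwise = all passes ws
  where
    n' = n - 1
    (s, d) = split 0 n'               -- n - 1 = 2^s * d, with d odd
    split :: Int -> Integer -> (Int, Integer)
    split k m
      | even m    = split (k + 1) (m `div` 2)
      | otherwise = (k, m)
    passes a
      | r <= 1 || r == n' = True      -- trivial witness, proves nothing
      | otherwise         = x == 1 || x == n' || n' `elem` squarings
      where
        r = a `mod` n
        x = powMod r d n
        -- x^2, x^4, ..., x^(2^(s-1)), all mod n
        squarings = take (s - 1) (tail (iterate (\y -> y * y `mod` n) x))
```

In IO, the witnesses could be drawn with randomRIO (2, n - 2) from System.Random, as getPrime already does for its candidates; for instance ws <- replicateM 50 (randomRIO (2, n - 2)), with replicateM from Control.Monad.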