Scala code to generate Markov chains from a text file

(Technical note: the following code is not quite correct, as it gives 0 probability to any word that never followed the current word in the training text, which isn’t right… the correct approach is to have a uniform prior (i.e. add 1 to every count). I’ll update it soon. Learning is fun!)

As a follow-up to my last post on Markov chains, I thought I’d post some code I wrote to generate those sorts of word-chains given a training .txt file.

I will be putting this example along with some of my helper libraries and other introductory machine learning examples up on my GitHub in the near future. I’ve been self-studying machine learning for a while now, and writing little models and snippets of code here and there, but I thought it would be a good exercise to have a bit more discipline about it and try to make public a set of worked examples, factored into a mini-library of sorts.

I’ve been using Project Gutenberg as a nice free source of training text, and the code below is configured to use a local copy of War and Peace as the training data. Just set the path to point to wherever you download the text. If you can’t get this stuff to work, be on the lookout for some posts about setting up Scala + IntelliJ that I’ve been putting off writing…

Here’s some sample output:

the mob pointing to act and earned him princess mary listened to 
thedisordered battalion three and wait said something to be a small
officer said his coat over the village of countbezukhov and yellow
legs went onall the strange voice this a memorandum is so on their
mastersand he moved on the country seat of the dorogomilov gates of
oldenburgand there to commitmurder he had saved by the young
armenian family had taken from the meeting and the morning
alpatychdonned a quiet gentleirony because it offended by its all
that a third of moscow no farther end innothinghe was such

Nice.
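
For a feel of what the generator actually learns from the training text, here is a toy sketch that counts which word follows which in a single sentence. It is not part of the real code: the object `ToyBigrams`, the sentence in `toyText`, and the helper `followers` are all made up for illustration.

```scala
object ToyBigrams {
  // Hypothetical toy "training text", made up for this illustration.
  val toyText = "the cat sat on the mat and the cat slept"

  val words: Seq[String] = toyText.split("\\s+").toSeq

  // Every word paired with the word that comes right after it.
  val pairs: Seq[(String, String)] = words.zip(words.drop(1))

  // How often each (word, next word) pair occurs.
  val counts: Map[(String, String), Int] =
    pairs.groupBy(identity).map { case (pair, occurrences) => pair -> occurrences.length }

  // All observed followers of one word, with their counts.
  def followers(word: String): Map[String, Int] =
    counts.collect { case ((w, next), n) if w == word => next -> n }
}
```

Here `ToyBigrams.followers("the")` comes out as `Map("cat" -> 2, "mat" -> 1)`, so a chain sitting on "the" should move to "cat" about two times out of three and to "mat" the rest of the time.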

Anyhow, here’s the code (apologies for the funky formatting, I know I should just use Gists or figure out a blog plugin or something, but it’s haaaaard. Wah.)

object StandaloneMarkovChainGenerator {

  // replace this with a path containing your training text
  val localPath = "C:\\Users\\Luke\\Desktop\\Files For Machine Learning\\"

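  def readLocalTextFile(path: String): String = {
    val source = scala.io.Source.fromFile(localPath + path)
    // close the file handle once the whole text has been read
    try source.mkString finally source.close()
  }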

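  // isWhitespace (unlike isSpaceChar) also matches line breaks; with isSpaceChar
  // they were filtered out and words fused across line ends ("thedisordered")
  def getWordSequenceFromString(str: String): Seq[String] =
    str.toLowerCase.filter { c => c.isLetter || c.isWhitespace }
      .trim.split("\\s+").toSeq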
  
  def getWordSequenceFromLocalFile(fileName: String): Seq[String] =
    getWordSequenceFromString(readLocalTextFile(fileName))

  def getConsecutivePairs[T](items: Seq[T]): Seq[(T, T)] =
    items zip (items drop 1)

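  // a strict map rather than mapValues, which only builds a lazy view
  // (and no longer returns a Map as of Scala 2.13)
  def getCountsDouble[T](items: Seq[T]): Map[T, Double] =
    items.groupBy(identity).map { case (item, group) => item -> group.length.toDouble }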

  type Dist[+T] = () => T

  def getWeightedCasesDistribution[T](
    weightedCases: Seq[(T, Double)]): Dist[T] = {

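    val cases = weightedCases map { _._1 }
    val weights = weightedCases map { _._2 }
    // running totals of the weights, e.g. (1, 2, 3) becomes (1, 3, 6)
    val summedWeights = weights.scan(0d)(_ + _).drop(1)
    val sumOfWeights = summedWeights.last
    // dividing by the grand total gives cumulative probabilities ending at 1.0
    val probs = summedWeights map { _ / sumOfWeights }
    val casesAndProbs = cases zip probs

    { () =>
      // roll is in [0, 1); pick the first case whose cumulative probability exceeds it
      val roll = scala.util.Random.nextDouble()
      casesAndProbs find {
        case (_, prob) => prob > roll
      } map { _._1 } getOrElse (sys.error("Impossible!"))
    }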
  }

  def generateMarkovChainFromFile(
    fileName: String, startWord: String = "the"): Stream[String] = {

    val words = getWordSequenceFromLocalFile(fileName)
    val pairs = getConsecutivePairs(words)
    val pairsWithCounts = getCountsDouble(pairs)

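    // strict maps rather than mapValues, so the regrouping and sorting happen
    // once here instead of again on every lookup
    val wordsToFollowingWordsAndCounts =
      pairsWithCounts.toSeq
        .map { case ((a, b), num) => (a, (b, num)) }
        .groupBy { case (a, _) => a }
        .map { case (a, group) => a -> group.map { case (_, (b, num)) => (b, num) } }

    val sortedByFrequency =
      wordsToFollowingWordsAndCounts map {
        case (word, followers) => word -> followers.sortBy(_._2)
      }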

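    var wordsToNextWordDists = Map.empty[String, Dist[String]]

    def pickNext(word: String): String = {
      // a word that only occurs at the very end of the text has no followers,
      // so look it up safely instead of letting the map throw
      val possibleNextWordsAndWeights = sortedByFrequency.getOrElse(word, Seq.empty)
      if (possibleNextWordsAndWeights.isEmpty) { return startWord }
      // build each word's distribution once and reuse it on later visits
      if (!wordsToNextWordDists.contains(word)) {
        wordsToNextWordDists = wordsToNextWordDists updated
          (word, getWeightedCasesDistribution(possibleNextWordsAndWeights))
      }
      wordsToNextWordDists(word)()
    }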

    Stream.iterate(startWord)(pickNext)
  }

  def main(args: Array[String]): Unit = {
    // you can replace this with a different book if you'd like
    val filePath = "MarkovChain\\WarAndPeace.txt"
    val chain = generateMarkovChainFromFile(filePath)

    println(chain.take(100).mkString(" "))
  }
}
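
As for the add-one fix mentioned in the technical note at the top, here is a rough sketch of how it might look. I haven't wired it into the generator above, and everything in it (the object `AddOneSmoothing`, the helpers `smoothedWeights` and `normalize`, and the idea of passing the vocabulary in as a `Set`) is an assumption of mine rather than tested library code.

```scala
object AddOneSmoothing {
  // Add-one (Laplace) smoothing: every word in the vocabulary gets a
  // pseudo-count of 1 as a possible follower, on top of its observed count.
  def smoothedWeights(
    observed: Map[String, Double],
    vocabulary: Set[String]): Seq[(String, Double)] =
    vocabulary.toSeq.sorted.map { w => (w, observed.getOrElse(w, 0d) + 1d) }

  // Turn the weights into probabilities that sum to 1.
  def normalize(weights: Seq[(String, Double)]): Seq[(String, Double)] = {
    val total = weights.map(_._2).sum
    weights.map { case (w, weight) => (w, weight / total) }
  }
}
```

With observed counts `Map("cat" -> 2d, "mat" -> 1d)` and a vocabulary of "cat", "mat" and "sat", the weights become 3, 2 and 1, so "sat" ends up with probability 1/6 instead of 0. The output of `smoothedWeights` has the same shape as the `weightedCases` argument of `getWeightedCasesDistribution`, so in principle it could be passed straight in. The catch is that every word's distribution then covers the whole vocabulary, which for a book the size of War and Peace is a lot of entries per word.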