Spark Machine Learning Example with Scala


In this Apache Spark Machine Learning example, Spark MLlib is introduced and the Scala source code is analyzed.  This post and accompanying screencast videos demonstrate a custom Spark MLlib driver application.  Then, the Spark MLlib Scala source code is examined.  Many topics are shown and explained, but first, let’s describe a few machine learning concepts.

Machine Learning Key Concepts

What is machine learning?

Machine learning is creating and using models that are learned from data.  You might also hear machine learning referred to as predictive modeling or data mining.  

What are three examples of machine learning?

  • spam prediction
  • fraudulent credit card transaction prediction
  • a product or advertisement recommendation engine

There are two types of machine learning models: supervised and unsupervised.  Supervised models are trained on a set of data labeled with the correct answers, while unsupervised models work with unlabeled data.

Examples of Supervised machine learning models

  • k-nearest neighbors: predict how a person might vote if you know how their neighbors are voting
  • naive Bayes: determine if an incoming email is spam
  • linear regression: try to determine if two variables are correlated
  • decision trees: use a structure to represent a number of possible decision paths and an outcome for each path

Examples of Unsupervised machine learning models

  • clustering – works with unlabeled data and attempts to “cluster” it.  For example, a data set showing where millionaires live has clusters in places like Beverly Hills and Manhattan
  • Latent Dirichlet Allocation (LDA) – a natural language processing technique commonly used to identify common topics in a text or a set of documents
  • neural networks: handwriting recognition and face image detection

When building models used to make predictions, we often train a model on an existing data set.  The model may be re-trained as more training data becomes available. For example, we would re-train a recommendation engine based on collaborative filtering as we learned more about the events which led to product sales or targeted engagement metrics.

Apache Spark Machine Learning Example

Let’s show a demo of an Apache Spark machine learning program.  In the following demo, we begin by training the k-means clustering model and then use this trained model to predict the language of an incoming text stream from Slack.

This example is built upon a previous Apache Spark Streaming tutorial which streams data from a Slack team site.  See Resources section below for links.

But, let’s move forward with the demo:

Spark Machine Learning Scala Source Code Review

Now that we have the demo in mind, let’s review the relevant Spark MLlib code.  Again, the links to source code may be found in the Resources section below.  Let’s start with the entry point into our Spark Machine Learning example, which is what was called via spark-submit in the demo, SlackMLApp:

object SlackMLApp {

  object Config {
    @Parameter(names = Array("-st", "--slackToken"))
    var slackToken: String = null
    @Parameter(names = Array("-nc", "--numClusters"))
    var numClusters: Int = 4
    @Parameter(names = Array("-po", "--predictOutput"))
    var predictOutput: String = null
    @Parameter(names = Array("-td", "--trainData"))
    var trainData: String = null
    @Parameter(names = Array("-ml", "--modelLocation"))
    var modelLocation: String = null
  }

  def main(args: Array[String]) {
    new JCommander(Config, args.toArray: _*)
    val conf = new SparkConf().setAppName("SlackStreamingWithML")
    val sparkContext = new SparkContext(conf)
    // obtain an existing model or train a new one
    val clusters: KMeansModel =
      if (Config.trainData != null) {
        KMeanTrainTask.train(sparkContext, Config.trainData, Config.numClusters, Config.modelLocation)
      } else {
        if (Config.modelLocation != null) {
          new KMeansModel(sparkContext.objectFile[Vector](Config.modelLocation).collect())
        } else {
          throw new IllegalArgumentException("Either modelLocation or trainData should be specified")
        }
      }

    if (Config.slackToken != null) {, Config.slackToken, clusters, Config.predictOutput)
    }
  }
}


The code above contains the main method and is called from spark-submit.  As you can see, we will either train a new model or use an existing model when running the SlackStreamingTask.  It depends on the incoming command line arguments such as trainData, modelLocation and slackToken.

In this Spark machine learning example source code analysis, next, we focus on 1) the code used to train the model in KMeanTrainTask and 2) using the model to make predictions in SlackStreamingTask.

First, let’s open the relevant portion of KMeanTrainTask:

  // numIterations was not shown in the original listing; 20 is an illustrative value
  val numIterations = 20

  def train(sparkContext: SparkContext, trainData: String, numClusters: Int, modelLocation: String): KMeansModel = {
    if (new File(modelLocation).exists) removePrevious(modelLocation)

    val trainRdd = sparkContext.textFile(trainData)

    val parsedData =
    // if we had a really large data set to train on, we'd want to call an action here to trigger the cache.

    val model = KMeans.train(parsedData, numClusters, numIterations)

    sparkContext.makeRDD(model.clusterCenters, numClusters).saveAsObjectFile(modelLocation)

    val example = trainRdd.sample(withReplacement = false, 0.1)
      .map(s => (s, model.predict(Utils.featurize(s)))).collect()
    println("Prediction examples:")

    model
  }


When calling train, we first attempt to remove any previously saved model in removePrevious. (removePrevious isn’t shown because it’s not relevant to our focus on machine learning with Apache Spark.)  Next, we set up a new RDD called trainRdd.  Since textFile accepts a String argument of a directory path, it will read all the files contained in the directory we passed in, which in this case was “input”.

Next, we must convert the elements (rows of text) in the RDD to a format suitable for KMeans.  We do this by calling Utils.featurize which looks like this:

object Utils {

  val NUM_DIMENSIONS: Int = 1000

  val tf = new HashingTF(NUM_DIMENSIONS)

  /**
   * Uses the hashing trick (HashingTF) to transform a
   * string into a vector of doubles, which is required for k-means.
   */
  def featurize(s: String): Vector = {
    // hash character bigrams into a fixed-size term-frequency vector
    tf.transform(s.sliding(2).toSeq)
  }
}


Now, if we go back to our KMeanTrainTask object, we’re in a position to train our model by calling the KMeans.train function with parsedData, numClusters, and numIterations.  Afterward, we save the model and send a few example clustering predictions to the console by iterating over example and sending each to println.

Now that we have a trained model, let’s see SlackStreamingTask:

object SlackStreamingTask {

  def run(sparkContext: SparkContext, slackToken: String, clusters: KMeansModel, predictOutput: String) {
    val ssc = new StreamingContext(sparkContext, Seconds(5))
    val dStream = ssc.receiverStream(new SlackReceiver(slackToken))

    val stream = dStream // create a stream of events from Slack... but filter and marshall to JSON stream data
      .filter(JSON.parseFull(_).get.asInstanceOf[Map[String, String]]("type") == "message") // get only message events
      .map(JSON.parseFull(_).get.asInstanceOf[Map[String, String]]("text")) // extract message text from the event

    val kmeanStream = kMean(stream, clusters) // apply the k-means model
    kmeanStream.print() // print k-means results as (k, m) pairs, where k is the message text and m is the cluster to which the message belongs
    if (predictOutput != null) {
      kmeanStream.saveAsTextFiles(predictOutput) // save the results to files, if a file name was specified
    }

    ssc.start() // run the Spark Streaming application
    ssc.awaitTermination() // wait for the end of the application
  }

  /**
   * Transform a stream of strings into a stream of (string, vector) pairs
   * and use this stream as input data for prediction.
   */
  def kMean(dStream: DStream[String], clusters: KMeansModel): DStream[(String, Int)] = { => (s, Utils.featurize(s))).map(p => (p._1, clusters.predict(p._2)))
  }
}



The Spark MLlib call which makes clustering predictions with the previously saved model is clusters.predict.  Before it is called, we map over the DStream and use featurize again so each message is in the vector format predict expects.  We return a DStream containing the original text received from Slack and the predicted cluster.

If the Spark driver program is called with the predictOutput input, the output is saved as text files.

Here’s another screencast in which I describe the code in more detail.


Resources

Source code

Spark ML tutorials

Spark tutorials in Scala

Background on Machine Learning

Spark MLlib: Spark ML and Spark MLlib documentation

Spark Machine Learning – Chapter 11 Machine Learning with MLlib


Spark Machine Learning is contained within Spark MLlib.  Spark MLlib is Spark’s library of machine learning (ML) functions, designed to run in parallel on clusters.  MLlib contains a variety of learning algorithms. The topic of machine learning itself could fill many books, so instead, this chapter summarizes ML in Apache Spark.

This post is an excerpt from our summary of Chapter 11, Machine Learning with MLlib, from the Apache Spark book Learning Spark.


MLlib invokes various algorithms on RDDs. As an example, MLlib could be used to identify spam through the following steps:

1) Create an RDD of strings representing emails.

2) Run one of MLlib’s feature extraction algorithms to convert text into an RDD of vectors.

3) Call a classification algorithm on the RDD of vectors to return a model object to classify new points.

4) Evaluate the model on a test dataset using one of MLlib’s evaluation functions.

Some classic ML algorithms are not included with Spark MLlib because they were not designed for parallel computation. MLlib also contains several recent research algorithms designed for clusters, such as distributed random forests, K-means||, and alternating least squares. MLlib is best suited for running machine learning algorithms on large datasets.

Machine Learning Basics

Machine learning algorithms try to predict or make decisions based on training data.  There are multiple types of learning problems, including classification, regression, and clustering, each of which has a different objective.

All learning algorithms require defining a set of features for each item, which is then fed into the learning function. For example, for an email, the features might include the server it comes from, the number of mentions of the word free, or the color of the text.

Pipelines often train multiple versions of a model and evaluate each one. To do this, separate the input data into “training” and “test” sets.

The book shows a very simple program for building a spam classifier in Python.  The code and data files are available in the book’s Git repository.
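The book’s program is in Python, but the same four steps from the spam example above can be sketched in Scala against the MLlib RDD API. This is a minimal sketch, not the book’s code: the file names (`spam.txt`, `ham.txt`), the feature count, and the iteration count are illustrative assumptions.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD

object SpamClassifierSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("SpamClassifier").setMaster("local[*]"))

    // 1) RDDs of strings representing emails (file names are assumed)
    val spam = sc.textFile("spam.txt")
    val ham = sc.textFile("ham.txt")

    // 2) feature extraction: map each email to a term-frequency vector
    val tf = new HashingTF(numFeatures = 10000)
    val spamFeatures = => tf.transform(email.split(" ")))
    val hamFeatures = => tf.transform(email.split(" ")))

    // label spam as 1.0 and ham as 0.0, then combine into one training set
    val trainingData = => LabeledPoint(1.0, f))
      .union( => LabeledPoint(0.0, f)))
      .cache()

    // 3) train a classifier; the returned model can classify new points
    val model = LogisticRegressionWithSGD.train(trainingData, numIterations = 100)

    // 4) evaluation sketch: featurize a new message the same way and predict
    val testVector = tf.transform("insert your prize claim code here".split(" "))
    println(s"Prediction for test message: ${model.predict(testVector)}")

    sc.stop()
  }
}
```

A proper evaluation would featurize a held-out test set and compare predictions against its labels rather than spot-checking a single message.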

Data Types

MLlib contains a few specific data types, including Vector, LabeledPoint, Rating, and various Model classes.

Working with Vectors

Vectors come in two flavors: dense and sparse. Dense vectors store all of their entries in an array of floating-point numbers.  Sparse vectors store only the nonzero values and their indices, which usually makes them preferable for large, mostly-zero feature vectors.
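For example, the same three-element vector can be built either way using MLlib’s Vectors factory (a minimal sketch):

```scala
import org.apache.spark.mllib.linalg.Vectors

// dense: every entry is stored, including the zeros
val dense = Vectors.dense(1.0, 0.0, 3.0)

// sparse: store the size (3), the indices of the nonzero entries,
// and the nonzero values themselves
val sparse = Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0))

// both represent the same mathematical vector
println(dense(2) == sparse(2)) // both entries are 3.0
```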


The key feature extraction techniques available in MLlib are TF-IDF, scaling, normalization, and Word2Vec.
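As a sketch of one of these, TF-IDF can be computed by combining HashingTF with IDF. The input file name and feature count here are illustrative assumptions:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.feature.{HashingTF, IDF}

val sc = new SparkContext(new SparkConf().setAppName("TfIdfSketch").setMaster("local[*]"))
val documents = sc.textFile("docs.txt") // one document per line (assumed)

// term frequencies via the hashing trick
val tf = new HashingTF(numFeatures = 1000)
val tfVectors = => tf.transform(doc.split(" "))).cache()

// IDF needs two passes: fit computes document frequencies,
// transform then rescales the TF vectors into TF-IDF vectors
val idf = new IDF().fit(tfVectors)
val tfIdfVectors = idf.transform(tfVectors)
```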


Various statistics functions, such as column summary statistics and correlations, are available through MLlib’s Statistics object.
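For instance, column summary statistics can be computed over an RDD of vectors in a single pass (a minimal sketch with toy data):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

val sc = new SparkContext(new SparkConf().setAppName("StatsSketch").setMaster("local[*]"))

val observations = sc.parallelize(Seq(
  Vectors.dense(1.0, 10.0),
  Vectors.dense(2.0, 20.0),
  Vectors.dense(3.0, 30.0)))

// per-column mean, variance, and nonzero counts
val summary = Statistics.colStats(observations)
println(summary.mean)     // column means: 2.0 and 20.0
println(summary.variance)
```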

Classification and Regression

Classification and regression are two common forms of supervised learning where the difference between the two is the type of variable predicted.

Linear regression is one of the most common methods for regression.  This method predicts the output variable as a linear combination of the features.
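As a sketch, MLlib’s LinearRegressionWithSGD trains on an RDD of LabeledPoint instances; the toy data and iteration count here are illustrative assumptions:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}

val sc = new SparkContext(new SparkConf().setAppName("LinRegSketch").setMaster("local[*]"))

// toy data: the label is roughly 2 * the single feature
val data = sc.parallelize(Seq(
  LabeledPoint(2.0, Vectors.dense(1.0)),
  LabeledPoint(4.0, Vectors.dense(2.0)),
  LabeledPoint(6.0, Vectors.dense(3.0))))

// train a linear model; the iteration count is a tuning choice
val model = LinearRegressionWithSGD.train(data, 100)

// predict the output for a new feature vector
println(model.predict(Vectors.dense(4.0)))
```

In practice features should be scaled before SGD-based training, which is one reason MLlib ships the scaling utilities mentioned above.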

Logistic regression is a binary classification method that identifies a linear separating plane between positive and negative examples. Support Vector Machines (SVMs), also available in MLlib, are another classification method based on a linear separating plane.

Naive Bayes is a multiclass classification algorithm that scores how well each point belongs in each class based on a linear function of the features.

Decision trees are a flexible model used for both classification and regression.


Clustering

Clustering is an unsupervised learning task involving the grouping of objects into clusters of high similarity.

MLlib includes the popular K-means algorithm for clustering, as well as a variant called K-means||.

Collaborative Filtering and Recommendation

Collaborative filtering is a technique for recommender systems.  Users’ ratings and interactions with various products are used to make recommendations.

MLlib includes Alternating Least Squares (ALS) which is a popular algorithm for collaborative filtering.
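A minimal ALS sketch: ratings are (user, product, rating) triples, and the trained model predicts unseen user-product ratings. The rating values, rank, and iteration count below are illustrative assumptions:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.{ALS, Rating}

val sc = new SparkContext(new SparkConf().setAppName("ALSSketch").setMaster("local[*]"))

// Rating(userId, productId, rating) triples
val ratings = sc.parallelize(Seq(
  Rating(1, 10, 5.0), Rating(1, 20, 1.0),
  Rating(2, 10, 4.0), Rating(2, 30, 2.0)))

// rank is the number of latent factors; rank and iterations are tuning choices
val model = ALS.train(ratings, 10, 10)

// predict user 1's rating for product 30
println(model.predict(1, 30))
```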

Dimensionality Reduction

The main technique for dimensionality reduction used by the machine learning community is principal component analysis (PCA), but MLlib also provides the lower-level singular value decomposition (SVD) primitive.
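As a sketch, PCA in MLlib is exposed through the distributed RowMatrix type; the toy rows below are illustrative:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

val sc = new SparkContext(new SparkConf().setAppName("PCASketch").setMaster("local[*]"))

val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0, 3.0),
  Vectors.dense(4.0, 5.0, 6.0),
  Vectors.dense(7.0, 8.0, 9.0)))

val mat = new RowMatrix(rows)

// compute the top 2 principal components, then project the data onto them
val pc = mat.computePrincipalComponents(2)
val projected = mat.multiply(pc)
```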

Model Evaluation

Many learning tasks may be addressed with different models, so it is important to evaluate each candidate model, ideally on a separate test dataset, and select the best one.

Tips and Performance Considerations

The following are suggested considerations: feature preparation, algorithm configuration, RDD caching, recognizing sparsity, and level of parallelism.

Pipeline API

Starting in Spark 1.2, MLlib is adding a new, higher-level API for machine learning based on the concept of pipelines. The pipeline API is similar to the one found in scikit-learn.


