As discussed before, the Spark-Scala API is the most widely used Spark API and is a production-ready solution.


What is Scala?

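Scala is a programming language that supports functional programming, a paradigm built around immutable data and functions without side effects, which tends to make code more robust than a purely imperative or object-oriented style. Scala is not purely functional, though: it sits in a grey area, combining functional and object-oriented programming.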
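To make the "grey area" concrete, here is a minimal sketch in plain Scala (no Spark needed) that computes the same mean in an imperative and in a functional style, and then uses a small class in a functional way. The House case class and the price values are hypothetical and only serve the illustration; the snippet is meant to be pasted into a notebook cell or the Scala REPL.

```scala
// Hypothetical prices, used only for illustration.
val prices = List(1300000.0, 750000.0, 1530000.0)

// Imperative style: a mutable accumulator updated inside a loop.
var total = 0.0
for (p <- prices) total += p
val meanImperative = total / prices.length

// Functional style: no mutation, the result is a single expression.
val meanFunctional = prices.sum / prices.length

// Object-oriented side: a case class is an immutable class with value semantics.
final case class House(bedrooms: Int, price: Double)

val houses = List(House(2, 1300000.0), House(1, 750000.0))

// filter and map return new lists; `houses` itself is left untouched.
val bedroomsOfExpensive = houses.filter(_.price > 1000000.0).map(_.bedrooms)

println(s"imperative mean = $meanImperative, functional mean = $meanFunctional")
println(s"bedrooms of houses above 1M: $bedroomsOfExpensive")
```

Both styles are valid Scala; the functional version is usually preferred because there is no intermediate state to get wrong.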

Before diving into Scala for Machine Learning, I would recommend simply following the basic overview of Scala in the official documentation. It takes 10 minutes and introduces the basics of Scala and some syntax.

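The Spark-Scala API can be noticeably faster than PySpark, in some cases close to 10 times, since Spark itself is written in Scala and runs on the JVM. The gap is largest when Python UDFs or RDD operations force data to move between the JVM and Python, and much smaller for built-in dataframe operations. Scala is a language developed at EPFL that became really popular a few years ago. It runs on the JVM and is quite similar to Java in some parts.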

One of the main features of Scala is function composition, which in Spark code shows up as method chaining. You might be used to it in PySpark, and this is where it comes from. For example, you can apply a sequence of functions to a dataframe this way:

val newDF = df

In this example, every function (groupBy, dropDuplicates…) is applied to the input dataframe, and each function returns a new dataframe.
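A minimal sketch of such a chain, assuming a local Spark session and a hypothetical toy dataframe (the column names state, category and sales are made up for illustration and are not part of the house dataset used later):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Assumption: a local session, created only for this illustration.
val spark = SparkSession.builder.master("local[*]").appName("chaining-sketch").getOrCreate()
import spark.implicits._

// Hypothetical toy data; the second row is an exact duplicate of the first.
val df = Seq(
  ("CA", "books", 10.0),
  ("CA", "books", 10.0),
  ("CA", "games", 25.0),
  ("NY", "books", 7.5)
).toDF("state", "category", "sales")

// Each call returns a new DataFrame; `df` itself is never modified.
val newDF = df
  .dropDuplicates()                       // remove the duplicated row
  .filter(col("sales") > 5.0)             // keep rows with sales above 5
  .groupBy("state")                       // one group per state
  .agg(sum("sales").as("total_sales"))    // total sales per state
  .orderBy(desc("total_sales"))           // largest total first

newDF.show()
```

Because every step returns a new dataframe, any prefix of the chain can be assigned to its own val and inspected with show() while debugging.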

Scala for Machine Learning

Through Spark, Scala is a great ML tool that data scientists should master. It’s a great way to use simple high-level ML APIs and apply them at scale.

In the example below, we’ll try to predict the price of some houses given several features.

Basic Exploration

First, open a Jupyter Notebook with a Spylon Kernel (see my previous article) or whatever IDE you’d like, and import the following packages:

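import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.regression.GBTRegressor
import org.apache.spark.ml.evaluation.RegressionEvaluator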

Then, create your new Spark Session:

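// getOrCreate() reuses the notebook's session if one exists; outside a notebook, add .master("local[*]") before it
val spark = SparkSession.builder.getOrCreate()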

We then read the input CSV file:

val df = spark.read.option("header", "true").option("inferSchema", "true").csv("data/house.csv")
df.show(2)
|_c0|             address|                info|           z_address|bathrooms|bedrooms|finishedsqft|lastsolddate|lastsoldprice| latitude|  longitude|   neighborhood|totalrooms|    usecode|yearbuilt|zestimate|zindexvalue|zipcode|       zpid|
|  2|Address: 1160 Mis...| San FranciscoSal...|1160 Mission St U...|      2.0|     2.0|      1043.0|  02/17/2016|    1300000.0|37.778705|-122.412635|South of Market|       4.0|Condominium|   2007.0|1167508.0|    975,700|94103.0|8.3152781E7|
|  5|Address: 260 King...| San FranciscoSal...|260 King St UNIT 475|      1.0|     1.0|       903.0|  02/17/2016|     750000.0|37.777641|-122.393417|South of Market|       3.0|Condominium|   2004.0| 823719.0|    975,700|94107.0|6.9819817E7|

The printed layout of such a wide dataframe is not always easy to read. We can print just the column names using:

df.columns
Array(_c0, address, info, z_address, bathrooms, bedrooms, finishedsqft, lastsolddate, lastsoldprice, latitude, longitude, neighborhood, totalrooms, usecode, yearbuilt, zestimate, zindexvalue, zipcode, zpid)

SQL Exploration

If you are familiar with SQL, it is extremely simple to run SQL queries using Spark-Scala. You must first register the dataframe as a temporary view (here with df.createOrReplaceTempView("housing")), and then use spark.sql to run queries.
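The registration step can be sketched on a small stand-in dataframe. The three rows and the view name housing_toy below are hypothetical; for the article's dataframe the view would be named housing, since that is the table the queries read from.

```scala
import org.apache.spark.sql.SparkSession

// Assumption: a local session, created only for this illustration.
val spark = SparkSession.builder.master("local[*]").appName("tempview-sketch").getOrCreate()
import spark.implicits._

// Hypothetical stand-in for the house dataframe, with only two of its columns.
val toy = Seq(
  (94103.0, 1300000.0),
  (94107.0, 750000.0),
  (94103.0, 900000.0)
).toDF("zipcode", "lastsoldprice")

// A temporary view lives only as long as the Spark session and copies no data.
toy.createOrReplaceTempView("housing_toy")

// The view can now be queried by name from SQL.
val avgDF = spark.sql("SELECT AVG(lastsoldprice) AS avg_price FROM housing_toy")
avgDF.show()
```

spark.sql returns an ordinary dataframe, so SQL and dataframe operations can be mixed freely.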

Let’s print the average price of the houses in the database:

val sqlDF = spark.sql("SELECT AVG(lastsoldprice) FROM housing")

We can also display the most common ZIP codes in our dataset:

val sqlDF2 = spark.sql("SELECT zipcode, COUNT(*) AS `num` FROM housing GROUP BY zipcode ORDER BY num DESC")
sqlDF2.show(5)

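You can check that the column types were correctly identified by inferSchema at import time by printing the schema. Note that zindexvalue is read as a string, because its values contain a thousands separator (975,700).

df.printSchema()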

 |-- _c0: integer (nullable = true)
 |-- address: string (nullable = true)
 |-- info: string (nullable = true)
 |-- z_address: string (nullable = true)
 |-- bathrooms: double (nullable = true)
 |-- bedrooms: double (nullable = true)
 |-- finishedsqft: double (nullable = true)
 |-- lastsolddate: string (nullable = true)
 |-- lastsoldprice: double (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- neighborhood: string (nullable = true)
 |-- totalrooms: double (nullable = true)
 |-- usecode: string (nullable = true)
 |-- yearbuilt: double (nullable = true)
 |-- zestimate: double (nullable = true)
 |-- zindexvalue: string (nullable = true)
 |-- zipcode: double (nullable = true)
 |-- zpid: double (nullable = true)


There are many features which we won’t use. The address, for example, is a free-text string, and the aim here is not to dive into Natural Language Processing. The usecode feature, however, is categorical and should be transformed into a numerical one.

To do so, we define a string indexer, which maps each input string category to a numeric index.

val indexer = new StringIndexer().setInputCol("usecode").setOutputCol("usecode2")

val df2 = indexer.fit(df).transform(df)
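To see what the indexer produces, here is a small sketch on hypothetical toy data (the six usecode values are made up). With its default settings, StringIndexer orders categories by descending frequency, so the most frequent category receives index 0.0.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.StringIndexer

// Assumption: a local session, created only for this illustration.
val spark = SparkSession.builder.master("local[*]").appName("indexer-sketch").getOrCreate()
import spark.implicits._

// Hypothetical categories: 3 x Condominium, 2 x SingleFamily, 1 x Duplex.
val toy = Seq(
  "Condominium", "SingleFamily", "Condominium",
  "Duplex", "Condominium", "SingleFamily"
).toDF("usecode")

val toyIndexer = new StringIndexer()
  .setInputCol("usecode")
  .setOutputCol("usecode2")

// fit() learns the category-to-index mapping, transform() appends the index column.
val indexed = toyIndexer.fit(toy).transform(toy)

// One row per category, ordered by the assigned index.
indexed.distinct().orderBy("usecode2").show()
```

The index carries no real ordering between categories. Tree-based models such as the one used later can still split on it, while for linear models a one-hot encoding step is usually added on top.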

In order to use Spark ML, we should build a column which contains all the features we are going to use to make the prediction. Think of it as grouping all features in a list and creating a single column called “features”. This is done with a Vector Assembler.

val assembler = new VectorAssembler().
       setInputCols(df2.drop("lastsoldprice", "zindexvalue", "_c0", "address", "info", "z_address", "lastsolddate", "neighborhood", "usecode").columns).
       setOutputCol("features")

val df3 = assembler.transform(df2)
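A small sketch of what the assembler does, using two rows whose values are borrowed from the sample shown earlier and only two input columns for readability (the choice of columns here is an assumption made for the illustration):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vector

// Assumption: a local session, created only for this illustration.
val spark = SparkSession.builder.master("local[*]").appName("assembler-sketch").getOrCreate()
import spark.implicits._

// Two toy rows: bedrooms, finishedsqft and the label lastsoldprice.
val toy = Seq(
  (2.0, 1043.0, 1300000.0),
  (1.0, 903.0, 750000.0)
).toDF("bedrooms", "finishedsqft", "lastsoldprice")

// Only the predictors go into the vector; the label stays in its own column.
val toyAssembler = new VectorAssembler()
  .setInputCols(Array("bedrooms", "finishedsqft"))
  .setOutputCol("features")

val assembled = toyAssembler.transform(toy)

// truncate = false prints the full vector instead of cutting it at 20 characters.
assembled.show(false)
```

VectorAssembler is a plain transformer, so unlike the string indexer it needs no fit() step.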

We can now check the different columns of the dataframe:

df3.columns
Array[String] = Array(_c0, address, info, z_address, bathrooms, bedrooms, finishedsqft, lastsolddate, lastsoldprice, latitude, longitude, neighborhood, totalrooms, usecode, yearbuilt, zestimate, zindexvalue, zipcode, zpid, usecode2, features)

We do indeed have a new column! The last step before building our model is to split our data into train and test sets:

val Array(train, test) = df3.randomSplit(Array(0.8, 0.2), seed = 30)
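A short sketch of how randomSplit behaves, on a hypothetical 1000-row dataframe standing in for df3. The split is random, so the sizes are only approximately 80% and 20%, but the two parts never overlap, and the same seed on the same data gives the same split again.

```scala
import org.apache.spark.sql.SparkSession

// Assumption: a local session, created only for this illustration.
val spark = SparkSession.builder.master("local[*]").appName("split-sketch").getOrCreate()

// Hypothetical stand-in for df3: 1000 rows with a single id column.
val data = spark.range(0, 1000).toDF("id")

// Same weights and same seed, applied twice to the same dataframe.
val Array(trainA, testA) = data.randomSplit(Array(0.8, 0.2), seed = 30)
val Array(trainB, testB) = data.randomSplit(Array(0.8, 0.2), seed = 30)

println(s"first split:  train = ${trainA.count()}, test = ${testA.count()}")
println(s"second split: train = ${trainB.count()}, test = ${testB.count()}")
```

Fixing the seed is what makes an evaluation metric comparable between two runs of the same notebook.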

Build a model

We will use a simple Gradient Boosted Tree Regression model with default parameters and at most 10 iterations.

val gbt = new GBTRegressor().setLabelCol("lastsoldprice").setFeaturesCol("features").setMaxIter(10)

Then, fit the model:

val model = gbt.fit(train)
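The configure-then-fit pattern can be tried end to end on synthetic data. The sketch below assumes a made-up dataset in which price is exactly 1000 times the surface; the names sqft, price, toyGbt and toyModel are hypothetical and unrelated to the house dataset.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.GBTRegressor

// Assumption: a local session, created only for this illustration.
val spark = SparkSession.builder.master("local[*]").appName("gbt-sketch").getOrCreate()
import spark.implicits._

// Hypothetical synthetic data: 200 rows where price = 1000 * sqft.
val toy = (1 to 200).map(i => (i.toDouble, 1000.0 * i)).toDF("sqft", "price")

// Spark ML expects the predictors in a single vector column.
val toyFeatures = new VectorAssembler()
  .setInputCols(Array("sqft"))
  .setOutputCol("features")
  .transform(toy)

// The regressor is only a configuration object until fit() is called.
val toyGbt = new GBTRegressor()
  .setLabelCol("price")
  .setFeaturesCol("features")
  .setMaxIter(10)   // at most 10 boosting iterations, i.e. at most 10 trees

// fit() trains the ensemble and returns a GBTRegressionModel.
val toyModel = toyGbt.fit(toyFeatures)

println(s"Trees in the ensemble: ${toyModel.getNumTrees}")
toyModel.transform(toyFeatures).select("sqft", "price", "prediction").show(5)
```

The label and features column names have to be set explicitly whenever they differ from the defaults "label" and "features".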

Make predictions

Let’s now make some predictions on the test set and assess the performance of our model:

val predictions = model.transform(test)
predictions.select("prediction", "lastsoldprice").show(5)
|        prediction|lastsoldprice|
|1613879.2210632167|    1530000.0|
|1389284.2393296428|    1440000.0|
| 1369447.861598761|    1700000.0|
| 770113.3958960483|     700000.0|
|1062512.6163005617|    1525000.0|

To evaluate the performance of the model, we can simply use a regression evaluator. We can pick the metric of our choice, for example the Root Mean Squared Error:

val evaluator = new RegressionEvaluator().setLabelCol("lastsoldprice").setPredictionCol("prediction").setMetricName("rmse")

val rmse = evaluator.evaluate(predictions)
println(s"Root Mean Squared Error (RMSE) on test data = $rmse")
Root Mean Squared Error (RMSE) on test data = 695538.99
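To make the metric concrete, RMSE can be recomputed by hand in plain Scala as the square root of the mean squared difference between prediction and label. The sketch below uses only the five (prediction, lastsoldprice) pairs printed by show(5) above, so its result describes those five rows and is not expected to match the value obtained on the whole test set.

```scala
// (prediction, lastsoldprice) pairs copied from the five displayed rows.
val rows = Seq(
  (1613879.2210632167, 1530000.0),
  (1389284.2393296428, 1440000.0),
  (1369447.861598761, 1700000.0),
  (770113.3958960483, 700000.0),
  (1062512.6163005617, 1525000.0)
)

// Squared error of each row: (prediction - label)^2.
val squaredErrors = rows.map { case (prediction, label) => math.pow(prediction - label, 2) }

// RMSE = sqrt(mean of the squared errors).
val rmseFiveRows = math.sqrt(squaredErrors.sum / squaredErrors.size)

println(f"RMSE on the five displayed rows = $rmseFiveRows%.2f")
```

Because the errors are squared before averaging, a few badly predicted houses can dominate the metric, which may explain why the value on the full test set is larger than what these five rows alone suggest.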

Conclusion

I hope this first approach to Spark-Scala was clear and helpful. I’d be happy to answer any questions you might have in the comments section.