Tuesday, March 31, 2015

Moneyball Predictions using Linear Regression in Python


I am taking the course The Analytics Edge from edX. In terms of coverage of techniques, it's fairly elementary. However, it has a somewhat more practical, problem-solving orientation that I am finding very interesting. Each week, the course teaches a particular machine learning technique through multiple case studies in R. I don't particularly enjoy the R part (although I am gradually getting used to its quirkiness), but I do enjoy the depth of the analysis as the instructors work through each case study. As an example, I describe their coverage of Linear Regression using the Moneyball example.

The Moneyball exercise seeks to find out what the Oakland A's needed to do to get to the playoffs in 2002. The exercise attempts to confirm the estimates made by Paul DePodesta, the analytics brain behind the Oakland A's, using Linear Regression. I will do the exercise in Python, since I feel it reads as well as, if not better than, the R version from the course.

The data is provided by the course, and comes from Baseball-Reference.com. Our first step is to read it into a Pandas DataFrame.

from __future__ import division
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

baseball = pd.read_csv("../../../data/baseball.csv")

# Keep only seasons before 2002, since 2002 is the year we want to "predict"
moneyball = baseball[baseball["Year"] < 2002]

The data has 1,232 rows and represents season statistics for around 40 teams from 1962 to 2012. Since we are looking to confirm DePodesta's estimates, we only look at the data prior to 2002, 902 rows, and store this in the DataFrame moneyball. Here is a sample of the moneyball DataFrame (the _TeamID and _RDiff columns are computed later in the post).

Team League Year RS RA W OBP SLG BA Playoffs RankSeason RankPlayoffs G OOBP OSLG _TeamID _RDiff
330 ANA AL 2001 691 730 75 0.327 0.405 0.261 0 NaN NaN 162 0.331 0.412 0 -39
331 ARI NL 2001 818 677 92 0.341 0.442 0.267 1 5 1 162 0.311 0.404 1 141
332 ATL NL 2001 729 643 88 0.324 0.412 0.260 1 7 3 162 0.314 0.384 2 86
333 BAL AL 2001 687 829 63 0.319 0.380 0.248 0 NaN NaN 162 0.337 0.439 3 -142
334 BOS AL 2001 772 745 82 0.334 0.439 0.266 0 NaN NaN 161 0.329 0.393 4 27
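
As a quick sanity check on the row counts quoted above, the DataFrame shapes (assuming the same CSV) should come out as follows:

print(baseball.shape)   # expect (1232, 15) for the full dataset
print(moneyball.shape)  # expect (902, 15) for seasons before 2002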

Using this data, we plot the number of wins for each team, color-coding the cases where the team went to the playoffs (red) versus not (black). The following code does this.

# Assign each team a numeric ID so we can use it as the y-axis
team_idx = {v: k for (k, v) in enumerate(moneyball["Team"].unique())}
moneyball["_TeamID"] = [team_idx[x] for x in moneyball["Team"]]

moneyball_playoff = moneyball[moneyball["Playoffs"] == 1]
plt.scatter(moneyball_playoff["W"], moneyball_playoff["_TeamID"], color="r")
moneyball_nonplayoff = moneyball[moneyball["Playoffs"] == 0]
plt.scatter(moneyball_nonplayoff["W"], moneyball_nonplayoff["_TeamID"], color="k")
# DePodesta's estimate: 95 wins needed to make the playoffs
plt.vlines(95, 0, len(team_idx), colors="b", linestyles="solid")
plt.yticks([])

This produces the following scatter plot. As you can see, the vertical blue line (at 95 wins, DePodesta's estimate) nicely separates the teams that ended up in the playoffs from those that did not.


DePodesta also estimated that a team would need to score around 135 more runs than its opponents over the season to make the 95 wins. We can verify that using Linear Regression. There are two variables in the data, RS (Runs Scored, the number of runs scored by the team) and RA (Runs Allowed, the number of runs scored by the opposing teams). Their difference is the team's run differential for the season.

moneyball["_RDiff"] = moneyball["RS"] - moneyball["RA"]
plt.scatter(moneyball["_RDiff"], moneyball["W"])
plt.xlabel("Run Difference")
plt.ylabel("Wins")

win_model = LinearRegression()
win_model.fit(np.matrix(moneyball["_RDiff"]).T, moneyball["W"])
xs = [np.min(moneyball["_RDiff"]), np.max(moneyball["_RDiff"])]
ys = [win_model.predict(x) for x in xs]
plt.plot(xs, ys, 'r', linewidth=2.5)

run_diff_to_win = (95 - win_model.intercept_) / win_model.coef_[0]
print("Run difference: %.2f" % (run_diff_to_win))

The scatterplot below shows a very strong correlation between the run difference and the number of wins. We build a Linear Regression model using the Run Difference to predict the number of Wins. The red line is the regression line predicted by the model. Our model predicts a run difference of 133.49 to get 95 wins, very close to DePodesta's estimate.
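
To make the last step explicit: the fitted line has the form W = b0 + b1 * RD, so the run difference needed for 95 wins is just the line inverted at W = 95:

    RD = (95 - b0) / b1

which is exactly what the run_diff_to_win computation above does, and which evaluates to 133.49 with the fitted intercept b0 and slope b1.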


The run difference is just the difference between RS and RA. Both of these variables can themselves be predicted from other variables available in the data. In fact, one of the assertions of the Moneyball team was that such statistics could outperform the judgment of human scouts.

To predict RS, we initially use 3 variables: OBP (On Base Percentage), SLG (Slugging Percentage) and BA (Batting Average). Batting Average used to be the popular measure of offensive production, but the Moneyball team found that OBP and SLG had better predictive power. Our Linear Regression looks like this.

y = np.array(moneyball["RS"])
X = np.vstack((np.array(moneyball["OBP"]), 
               np.array(moneyball["SLG"]),
               np.array(moneyball["BA"]))).T
rs_model = LinearRegression()
rs_model.fit(X, y)
print("R-squared[RS_1]: %.4f" % (rs_model.score(X, y)))
print(rs_model.intercept_, rs_model.coef_)

This model gives us an R2 score of 0.9302. However, the coefficient of BA is negative, implying that the higher a team's batting average, the fewer runs it scores. This is clearly counter-intuitive, and comes about because of multi-collinearity (i.e., predictor variables that are strongly correlated with each other). Removing BA results in a model with only a slightly lower R2 of 0.9296.
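
A quick way to see the multi-collinearity is to look at the pairwise correlations among the three predictors (the exact values depend on the data, but BA should come out strongly correlated with both OBP and SLG):

# High off-diagonal values in this matrix indicate multi-collinearity
print(moneyball[["OBP", "SLG", "BA"]].corr())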

# Model 2: drop BA, keeping only OBP and SLG
X = np.vstack((np.array(moneyball["OBP"]), 
               np.array(moneyball["SLG"]))).T
rs_model = LinearRegression()
rs_model.fit(X, y)
print("R-squared[RS_2]: %.4f" % (rs_model.score(X, y)))
print(rs_model.intercept_, rs_model.coef_)

Similarly, Runs Allowed is modeled as a Linear Regression over OOBP (Opponent On Base Percentage) and OSLG (Opponent Slugging Percentage). The data is a bit dirty here (it contains missing values), so we had to remove rows with missing values, which dropped the 902 rows down to 90. Even so, the R2 of the model is 0.9073.

# Model for Runs Allowed: keep only rows where OOBP and OSLG are present
moneyball_ra = pd.DataFrame({"OOBP": moneyball["OOBP"],
                             "OSLG": moneyball["OSLG"],
                             "RA": moneyball["RA"]})
moneyball_ra = moneyball_ra.dropna(axis=0)
y = np.array(moneyball_ra["RA"])
X = np.vstack((np.array(moneyball_ra["OOBP"]), 
               np.array(moneyball_ra["OSLG"]))).T
ra_model = LinearRegression()
ra_model.fit(X, y)
print("R-squared[RA]: %.4f" % (ra_model.score(X, y)))
print(ra_model.intercept_, ra_model.coef_)

Finally, we use our models to predict the number of runs the A's will score and allow in 2002, and from those, their chances of making the playoffs. The inputs to the predict methods of the RS and RA models are the team's OBP, SLG, OOBP and OSLG values from 2001. We assume that, on average, the team will be of the same quality in 2002 as it was in 2001.

# Predict 2002 runs scored and allowed from the A's 2001 team statistics
pred_rs = rs_model.predict(np.array([[0.339, 0.430]]))[0]
pred_ra = ra_model.predict(np.array([[0.307, 0.373]]))[0]
pred_rd = pred_rs - pred_ra
pred_wins = win_model.predict(np.array([[pred_rd]]))[0]
print("Predicted Runs Scored in 2002: %.2f" % (pred_rs))
print("Predicted Runs Allowed in 2002: %.2f" % (pred_ra))
print("Predicted Wins in 2002: %.2f" % (pred_wins))

The results from our models are 804 runs scored (DePodesta estimated 800-820, and the A's actually scored 800), 621.93 runs allowed (estimated 650-670, actual 653), and 100.24 wins (estimated 93-97, actual 103).
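
Summarized in a table:

Metric         Our model   DePodesta's estimate   Actual 2002
Runs Scored    804         800-820                800
Runs Allowed   621.93      650-670                653
Wins           100.24      93-97                  103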

What impressed me most about this coverage was how it took an actual problem, deconstructed it into its components, then combined the results from those components into the final solution. I have tried to do it justice here, but have probably failed. To get an idea of how good this course is, you should probably take it - it's still running, although already halfway through. But this is (I think) the 3rd time they are offering the course, and I am sure there will be more runs to come.

Saturday, March 14, 2015

Exploring Solr GeoSearch capabilities


At my previous job at Healthline, my use of Lucene and Solr was focused on their text analysis and search capabilities. While there were various initiatives involving Solr's Geosearch (Spatial) capabilities, primarily in applications for finding medical providers (doctors, hospitals, etc), these were invariably built by others in the group. As a result, I have remained largely ignorant of what is possible with Solr's Spatial functionality. That ignorance came back to bite me at least once recently, when I had to commit to a level-of-effort estimate on a project that had a Geosearch component.

I had some time last week, so I decided to explore Solr's Spatial capabilities. For data, I used the free 500 US addresses dataset from Brian Dunning's website, available in comma-separated (CSV) format with quoted string fields. I didn't want to use a full-blown CSV reader, so I opened it with OpenOffice Calc and converted it to tab-separated (TSV) format (if you would rather script this step, see the sketch below). I then annotated the addresses with latitude and longitude using Google's Geocoding API, populated a Solr index with the annotated data, and ran some Spatial queries to understand the possibilities.
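
For completeness, here is a minimal Python sketch of the CSV-to-TSV conversion (the file names are made up; the built-in csv module takes care of the quoted fields, and we assume no field contains an embedded tab):

import csv

# Convert the quoted CSV into tab-separated format
with open("us-500.csv", "r") as fin, open("us-500.tsv", "w") as fout:
    for row in csv.reader(fin):
        fout.write("\t".join(row) + "\n")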

Here is the code to read the converted TSV file and, for each record, call Google's Geocoding API to annotate the record with the latitude and longitude. One thing to realize is that the annotation is only as good as your data. The Geocoding API is basically doing an address search - if you look at the output, you will see that it breaks the address down into various components, tries to do a best match against its internal fields, and comes up with a best guess for the latitude-longitude pair. So in some cases, the latitude-longitude pair returned is not that of the address itself but of the closest matching point in the Geocoding API's database. Also, some addresses cannot be mapped to a LatLon pair at all; these are given a LatLon of (0,0).

// Source: src/main/scala/com/mycompany/solr4extras/geo/LatLonAnnotator.scala
package com.mycompany.solr4extras.geo

import java.io.File
import java.io.FileWriter
import java.io.PrintWriter
import java.net.URLEncoder

import scala.io.Source

import org.codehaus.jackson.map.ObjectMapper

class LatLonAnnotator {

  val googleApiKey = "secret"
  val geocodeServer = "https://maps.googleapis.com/maps/api/geocode/json"

  val objectMapper = new ObjectMapper()
  
  // Calls the Geocoding API for a single address and returns (lat, lon),
  // or (0,0) if the address could not be geocoded.
  def annotate(addr: String): (Double,Double) = {
    val params = Map(
      ("address", URLEncoder.encode(addr, "UTF-8")),
      ("key", googleApiKey))
    val url = geocodeServer + "?" + 
      params.map(kv => kv._1 + "=" + kv._2)
            .mkString("&")
    val json = Source.fromURL(url).mkString
    val root = objectMapper.readTree(json)
    try {
      val location = root.path("results").get(0)
        .path("geometry").path("location")
      val lat = location.path("lat").asDouble
      val lon = location.path("lng").asDouble
      (lat, lon)
    } catch {
      case e: Exception => (0.0D, 0.0D)
    }
  }
  
  def batchAnnotate(infile: File, outfile: File): Unit = {
    val writer = new PrintWriter(new FileWriter(outfile), true)
    Source.fromFile(infile)
      .getLines()
      .filter(line => !line.startsWith("first_name")) // skip the header row
      .foreach(line => {
        val cols = line.split("\t")
        val fname = cols(0)
        val lname = cols(1)
        val company = cols(2)
        val address = cols(3)
        val city = cols(4)
        val state = cols(6)   // cols(5) is the county, which we don't need
        val zip = cols(7)
        val apiAddress = List(address, city, state)
          .mkString(", ") + " " + zip
        Console.println(apiAddress)
        val latlon = annotate(apiAddress)
        writer.println(List(fname, lname, company, address, 
          city, state, zip, latlon._1, latlon._2)
          .mkString("\t"))
        Thread.sleep(1000) // sleep 1s between calls to Google API
      })
    writer.flush()
    writer.close()
  }
}
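
For reference, the part of the Geocoding API response that annotate() navigates looks roughly like this (heavily trimmed, with made-up coordinates):

{
  "results": [
    {
      "geometry": {
        "location": {"lat": 38.89, "lng": -77.04}
      }
    }
  ],
  "status": "OK"
}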

You will need a Google API key to use the Geocoding API. The free service is quite generous - they give you 5,000 lookups per day, throttled at 5 requests/second. Because my dataset was so small, I was able to test my code and do two full runs within a single day's limit. Key generation is simple, but I found navigating the Google developer site somewhat non-intuitive - you can find information about generating and using your keys here.

I wanted to see the geographic coverage of these 500 US addresses, so I wrote a small R script to plot them (with lots of help from this page). Here is the R script, followed by its output. As you can see, there is a nice concentration of addresses around the New York/New Jersey/Washington DC area, which is what we will use for our testing.

# Source: viz.R
library(maps)

png("addr_dist.png")

map("state", interior=F)
map("state", boundary=F, col="gray", add=T)

# the annotated file is tab-separated despite the .csv extension
data = read.csv("us-cities-annotated.csv", sep="\t", 
                col.names=c("fname", "lname", "company", 
                            "address", "city", "state", 
                            "zip", "lat", "lon"))
# drop addresses that could not be geocoded (lat == lon == 0)
data.clean = data[!(data$lat==0 & data$lon==0), ]
points(data.clean$lon, data.clean$lat, pch=19, col="red", cex=0.5)

dev.off()


The indexing code is fairly straightforward: we leverage Solr's dynamic fields to map most of our address fields to either text or string fields. The latitude-longitude pair we retrieved from the Geocoding API goes into a special field type called LatLonType, as a comma-separated pair.
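
This works because the stock Solr 4 example schema already declares dynamic fields for these suffixes. From memory, the relevant declarations in schema.xml look roughly like this:

<dynamicField name="*_s" type="string" indexed="true" stored="true"/>
<dynamicField name="*_t" type="text_general" indexed="true" stored="true"/>
<dynamicField name="*_p" type="location" indexed="true" stored="true"/>

<fieldType name="location" class="solr.LatLonType" subFieldSuffix="_coordinate"/>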

// Source: src/main/scala/com/mycompany/solr4extras/geo/LatLonIndexer.scala
package com.mycompany.solr4extras.geo

import java.io.File
import java.util.concurrent.atomic.AtomicInteger

import scala.io.Source

import org.apache.solr.client.solrj.impl.HttpSolrServer
import org.apache.solr.common.SolrInputDocument

class LatLonIndexer {

  val solrUrl = "http://localhost:8983/solr/collection1"
  val infile = new File("src/main/resources/us-cities-annotated.csv")
  
  def buildIndex(): Unit = {
    val solr = new HttpSolrServer(solrUrl)
    val ctr = new AtomicInteger(0)
    Source.fromFile(infile).getLines()
      .foreach(line => {
        val doc = new SolrInputDocument()
        val cols = line.split("\t")
        doc.addField("id", ctr.addAndGet(1))
        doc.addField("firstname_s", cols(0))
        doc.addField("lastname_s", cols(1))
        doc.addField("company_t", cols(2))
        doc.addField("street_t", cols(3))
        doc.addField("city_s", cols(4))
        doc.addField("state_s", cols(5))
        doc.addField("zip_s", cols(6))
        doc.addField("latlon_p", 
          List(cols(7), cols(8)).mkString(","))
        solr.add(doc)
    })
    solr.commit()
    solr.shutdown()
  }
}

Now we can leverage Solr's Spatial capabilities, as described on its wiki page. Here is the code for the searcher.

// Source: src/main/scala/com/mycompany/solr4extras/geo/LatLonSearcher.scala
package com.mycompany.solr4extras.geo

import org.apache.solr.client.solrj.SolrQuery
import org.apache.solr.client.solrj.impl.HttpSolrServer
import org.apache.solr.common.SolrDocument
import scala.collection.JavaConversions._

class LatLonSearcher {

  val solr = new HttpSolrServer("http://localhost:8983/solr/collection1")
  
  def findWithinByGeofilt(p: Point, dkm: Double,
      sort: Boolean, nearestFirst: Boolean): List[LatLonDoc] = 
    findWithin("geofilt", p, dkm, sort, nearestFirst)
  
  def findWithinByBbox(p: Point, dkm: Double,
      sort: Boolean, nearestFirst: Boolean):List[LatLonDoc] =
    findWithin("bbox", p, dkm, sort, nearestFirst)
  
  def findWithin(method: String, p: Point, dkm: Double,
      sort: Boolean, nearestFirst: Boolean):
      List[LatLonDoc] = {
    val query = new SolrQuery()
    query.setQuery("*:*")
    query.setFields("*")
    query.setFilterQueries("{!%s}".format(method)) // geofilt or bbox
    query.set("pt", "%.2f,%.2f".format(p.x, p.y))  // center point, "lat,lon"
    query.set("d", dkm.toString)                   // distance in km
    query.set("sfield", "latlon_p")                // spatial field to filter on
    if (sort) {
      query.set("sort", "geodist() %s"
        .format(if (nearestFirst) "asc" else "desc"))
      query.setFields("*,_dist_:geodist()")
    }
    val resp = solr.query(query)
    resp.getResults()
        .map(doc => getLatLonDocument(doc))
        .toList
  }

  def getLatLonDocument(sdoc: SolrDocument): LatLonDoc = {
    val latlon = sdoc.getFieldValue("latlon_p")
                     .asInstanceOf[String]
                     .split(",")
                     .map(_.toDouble)
    val dist = if (sdoc.getFieldValue("_dist_") != null) 
      sdoc.getFieldValue("_dist_").asInstanceOf[Double]
      else 0.0D 
    LatLonDoc(sdoc.getFieldValue("firstname_s").asInstanceOf[String],
      sdoc.getFieldValue("lastname_s").asInstanceOf[String],
      sdoc.getFieldValue("company_t").asInstanceOf[String],
      sdoc.getFieldValue("street_t").asInstanceOf[String],
      sdoc.getFieldValue("city_s").asInstanceOf[String],
      sdoc.getFieldValue("state_s").asInstanceOf[String],
      sdoc.getFieldValue("zip_s").asInstanceOf[String],
      Point(latlon(0), latlon(1)), dist)
  }
}

case class Point(x: Double, y: Double)
case class LatLonDoc(fname: String, lname: String, 
                     company: String, street: String, 
                     city: String, state: String, 
                     zip: String, location: Point,
                     dist: Double)

The main method the searcher code above exposes is the findWithin() method. You can specify one of two methods for finding LatLon points within a certain distance d (in kilometers) from a Point p. The two methods are geofilt and bbox - geofilt finds points within a circle, while bbox finds points within a square (bounding box). The bbox method is slightly looser than the geofilt method (i.e., it may return points farther away than d km from p), but is cheaper performance-wise. The findWithin() method also supports sorting by distance, using the geodist() function. Internally, these are implemented as function queries. The parameters this builds up are shown below.
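
For reference, here is what findWithin() sends to Solr for the geofilt case with distance sorting turned on (using the dcLocation point from the tests further down):

q=*:*
fl=*,_dist_:geodist()
fq={!geofilt}
sfield=latlon_p
pt=38.89,-77.04
d=50.0
sort=geodist() asc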

Of course, provider search is generally about more than just searching by LatLon. As implemented in the code above, the query would most likely be triggered from an entered zipcode, with each zipcode mapped to a central LatLon point within it using data like this. In reality, provider search would also include searching by the provider's name, specialties, languages spoken, etc. The code above puts the LatLon search into the filter query (fq) portion, so we could add queries for the other parts of the lookup into either the main query (q) or additional fq clauses, as in the hypothetical example below. There are other use cases one could explore as well, such as faceting by distance (see the wiki page above).
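
For example, a hypothetical query for companies with "plumbing" in their name within 50 km of the query point could combine a regular query with the spatial filter, along these lines (the query term is made up; the fields are the ones we indexed above):

q=company_t:plumbing
fq={!geofilt}
sfield=latlon_p
pt=38.89,-77.04
d=50.0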

// Source: src/test/scala/com/mycompany/solr4extras/geo/LatLonSearcherTest.scala
package com.mycompany.solr4extras.geo

import org.junit.Test
import org.junit.Assert

class LatLonSearcherTest {

  val dcLocation = Point(38.89, -77.04)

  val searcher = new LatLonSearcher()
  
  @Test
  def testFindWithinByGeofilt(): Unit = {
    val results = searcher.findWithin(
      "geofilt", dcLocation, 50, false, false)
    val neighborStates = results.map(_.state).toSet
    Assert.assertEquals(9, results.size)
    Assert.assertEquals(2, neighborStates.size)
    Assert.assertTrue(neighborStates.contains("MD") &&
      neighborStates.contains("VA"))
  }

  @Test
  def testFindWithinByBbox(): Unit = {
    val results = searcher.findWithin(
      "bbox", dcLocation, 50, false, false)
    val neighborStates = results.map(_.state).toSet
    Assert.assertEquals(10, results.size)
    Assert.assertEquals(2, neighborStates.size)
    Assert.assertTrue(neighborStates.contains("MD") &&
      neighborStates.contains("VA"))
  }

  @Test
  def testSortByDistance(): Unit = {
    val results = searcher.findWithin(
      "bbox", dcLocation, 50, true, true)
    Assert.assertTrue(results.head.dist < results.last.dist)
    val formattedResults = results
      .map(result => "%s, %s %s %s (%.2f km)"
      .format(result.street, result.city, result.state, 
              result.zip, result.dist))
      .foreach(Console.println(_))
  }
}

The code above shows the test case we use to exercise our LatLonSearcher code. The first and second tests show how to call the findWithin() method with the "geofilt" and "bbox" methods; here the results are not sorted by distance from the query point (dcLocation, a point within the Washington DC metro area). The first test returns 9 results and the second returns 10, lending credence to the assertion that bbox is looser than geofilt. We also see that the results have addresses in Virginia and Maryland, two states that border Washington DC. The third test illustrates calling findWithin() with the bbox method and with results sorted by distance (nearest first). Here are the results of this test - as you can see, there are 10 results, the last of which lies outside the 50 km radius (a consequence of using bbox).

64 5th Ave #1153, Mc Lean VA 22102 (5.79 km)
9506 Edgemore Ave, Bladensburg MD 20710 (12.45 km)
5 Cabot Rd, Mc Lean VA 22102 (12.84 km)
94 Chase Rd, Hyattsville MD 20785 (14.28 km)
747 Leonis Blvd, Annandale VA 22003 (14.73 km)
47857 Coney Island Ave, Clinton MD 20735 (20.56 km)
48 Lenox St, Fairfax VA 22030 (21.70 km)
3387 Ryan Dr, Hanover MD 21076 (43.49 km)
2853 S Central Expy, Glen Burnie MD 21061 (46.93 km)
2 W Scyene Rd #3, Baltimore MD 21217 (57.92 km)

And this is all I have for today. For those of you who have worked with Solr Spatial before, this post is probably pretty basic, but for those of you who haven't, I hope it gives a quick high-level overview of what you can do with it. If you need it, the (Scala) code for this post is available in my project on GitHub.