ALSModel — Model for Predictions

ALSModel is a model fitted by the ALS algorithm.

Note
A Model in Spark MLlib is a Transformer that comes with a custom transform method.

When making predictions (i.e. when executed), ALSModel…FIXME

ALSModel is created when:

  • ALS fits an ALSModel

  • ALSModel copies an ALSModel

  • ALSModelReader loads an ALSModel from persistent storage

ALSModel is an MLWritable.

// The following spark-shell session shows how ALSModel works under the covers.
// ALSModel's constructor is private[ml], so the model is created from a class
// defined in the org.apache.spark.ml package.

// Use raw paste mode to enter the package-qualified code:
// :paste -raw (or its shorter version :pa -raw)
// BEGIN :pa -raw
package org.apache.spark.ml

import org.apache.spark.sql._
import org.apache.spark.ml.recommendation._

class MyALS(spark: SparkSession) {
  import spark.implicits._
  // Factors are single-precision feature vectors of length rank
  val userFactors = Seq((0, Seq(0.3f, 0.2f))).toDF("id", "features")
  val itemFactors = Seq((0, Seq(0.3f, 0.2f))).toDF("id", "features")
  val alsModel = new ALSModel(
    uid = "uid", rank = 2, userFactors = userFactors, itemFactors = itemFactors)
}
// END :pa -raw

// Copy the following to spark-shell directly
import org.apache.spark.ml._
val model = new MyALS(spark).
  alsModel.
  setUserCol("user").
  setItemCol("item")

import org.apache.spark.sql.types._
val mySchema = new StructType().
  add($"user".float).
  add($"item".float)

val transformedSchema = model.transformSchema(mySchema)
scala> transformedSchema.printTreeString
root
 |-- user: float (nullable = true)
 |-- item: float (nullable = true)
 |-- prediction: float (nullable = false)

Making Predictions — transform Method

transform(dataset: Dataset[_]): DataFrame
Note
transform is part of the Transformer Contract.

Internally, transform validates the schema of the dataset (using transformSchema).

transform left-joins the dataset with the userFactors dataset (using the userCol column of the dataset and the id column of userFactors).

Note

A left join takes two datasets and returns all the rows from the left side of the join, each combined with the matching row from the right side if one is available, or with nulls otherwise.

val rows0 = spark.range(0)
val rows5 = spark.range(5)
scala> rows0.join(rows5, Seq("id"), "left").show
+---+
| id|
+---+
+---+

scala> rows5.join(rows0, Seq("id"), "left").count
res3: Long = 5

scala> spark.range(0, 55).join(spark.range(56, 200), Seq("id"), "left").count
res4: Long = 55

val rows02 = spark.range(0, 2)
val rows39 = spark.range(3, 9)
scala> rows02.join(rows39, Seq("id"), "left").show
+---+
| id|
+---+
|  0|
|  1|
+---+

val names = Seq((3, "three"), (4, "four")).toDF("id", "name")
scala> rows02.join(names, Seq("id"), "left").show
+---+----+
| id|name|
+---+----+
|  0|null|
|  1|null|
+---+----+

transform left-joins the dataset with the itemFactors dataset (using the itemCol column of the dataset and the id column of itemFactors).

For every row in the left-joined dataset, transform makes a prediction using the features columns of the userFactors and itemFactors datasets.

transform selects all the columns of the dataset plus the predictionCol column with the predictions.

Finally, if coldStartStrategy is drop, transform drops the rows with null or NaN values in the predictionCol column.

Note
The default value of coldStartStrategy is nan, which keeps the rows with missing (NaN) predictions.
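The steps of transform above can be modeled in plain Scala, without Spark, to show how the two left joins, the prediction and coldStartStrategy fit together. This is a sketch under stated assumptions, not the ALSModel implementation: the TransformSketch object, the Map-based factor tables and the (user, item) pairs standing in for rows are all hypothetical, and a missing factor is assumed to produce a NaN prediction.

```scala
// Plain-Scala model of ALSModel.transform (no Spark required).
// Assumptions (hypothetical, not Spark API): factor tables are
// Map[Int, Array[Float]] keyed by id, a row is a (user, item) pair,
// and a missing factor on either side yields a NaN prediction.
object TransformSketch {
  type Factors = Map[Int, Array[Float]]

  // Dot product of two factor vectors (the job predict hands to BLAS sdot)
  def dot(a: Array[Float], b: Array[Float]): Float =
    a.zip(b).map { case (x, y) => x * y }.sum

  def transform(
      rows: Seq[(Int, Int)],
      userFactors: Factors,
      itemFactors: Factors,
      coldStartStrategy: String = "nan"): Seq[(Int, Int, Float)] = {
    // Left-join semantics: every input row survives; a missing match is None
    val predictions = rows.map { case (user, item) =>
      val prediction = (userFactors.get(user), itemFactors.get(item)) match {
        case (Some(u), Some(i)) => dot(u, i)
        case _                  => Float.NaN // cold start: no factors on one side
      }
      (user, item, prediction)
    }
    // "drop" removes rows with a NaN prediction; "nan" (the default) keeps them
    coldStartStrategy match {
      case "drop" => predictions.filterNot(_._3.isNaN)
      case "nan"  => predictions
      case other  =>
        throw new IllegalArgumentException(s"Unsupported coldStartStrategy: $other")
    }
  }
}
```

With userFactors holding only user 0, a row for user 1 gets a NaN prediction: it survives with the default nan strategy and is removed with drop.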

transformSchema Method

transformSchema(schema: StructType): StructType
Note
transformSchema is part of the Transformer Contract.

Internally, transformSchema validates that the userCol and itemCol columns are of a numeric type and appends the predictionCol column of type float.
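A plain-Scala sketch of the same schema check, assuming a schema can be modeled as a sequence of (column name, type name) pairs. SchemaSketch and its numericTypes set are hypothetical stand-ins for StructType and Spark's numeric types, not Spark API.

```scala
// Plain-Scala model of ALSModel.transformSchema (not the Spark API).
// Assumption: a schema is a sequence of (column name, type name) pairs.
object SchemaSketch {
  type Schema = Seq[(String, String)]

  // Hypothetical list of type names treated as numeric
  val numericTypes = Set("byte", "short", "int", "long", "float", "double", "decimal")

  def transformSchema(
      schema: Schema,
      userCol: String = "user",
      itemCol: String = "item",
      predictionCol: String = "prediction"): Schema = {
    // userCol and itemCol must exist and be of a numeric type
    for (col <- Seq(userCol, itemCol)) {
      val tpe = schema.find(_._1 == col).map(_._2)
      require(tpe.exists(numericTypes), s"Column $col must be of a numeric type but was $tpe")
    }
    // The output schema is the input schema plus a float prediction column
    schema :+ (predictionCol -> "float")
  }
}
```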

Creating ALSModel Instance

ALSModel takes the following when created:

  • Unique ID

  • Rank

  • DataFrame of user factors

  • DataFrame of item factors

ALSModel initializes the internal registries and counters.

Requesting sdot from BLAS — predict Internal Property

predict: UserDefinedFunction

predict is a user-defined function (UDF) that takes two collections of float numbers (the user and item features) and computes their dot product by requesting BLAS for sdot.

Caution
FIXME Read about com.github.fommil.netlib.BLAS.getInstance.sdot.
Note
predict is a thin wrapper around com.github.fommil.netlib.BLAS.
Note
predict is used exclusively when ALSModel is requested to transform.
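What predict computes can be sketched without netlib-java. The SdotSketch object below is a hypothetical plain-Scala stand-in for BLAS sdot(n, sx, incx, sy, incy), restricted to positive increments. It assumes predict calls sdot over rank elements with unit increments and returns NaN when either side has no features (modeled here with Option instead of null).

```scala
// Plain-Scala stand-in for BLAS sdot (single-precision dot product).
// Not netlib-java: it only mirrors the call shape sdot(n, sx, incx, sy, incy)
// for positive increments, to show what predict computes.
object SdotSketch {
  def sdot(n: Int, sx: Array[Float], incx: Int, sy: Array[Float], incy: Int): Float = {
    require(incx > 0 && incy > 0, "this sketch only supports positive increments")
    var acc = 0.0f
    var i = 0
    while (i < n) {
      acc += sx(i * incx) * sy(i * incy) // i-th element of each strided vector
      i += 1
    }
    acc
  }

  // Hypothetical predict: Option models the null features a left join leaves behind
  def predict(rank: Int, userFeatures: Option[Seq[Float]], itemFeatures: Option[Seq[Float]]): Float =
    (userFeatures, itemFeatures) match {
      case (Some(u), Some(v)) => sdot(rank, u.toArray, 1, v.toArray, 1)
      case _                  => Float.NaN
    }
}
```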

Creating ALSModel with Extra Parameters — copy Method

copy(extra: ParamMap): ALSModel
Note
copy is part of the Model Contract.

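copy creates a new ALSModel with the same uid, rank, userFactors and itemFactors.

copy then copies the parameter values of the current model, together with the extra parameters, to the new ALSModel and sets its parent to the current model's parent.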
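The copy semantics can be modeled in plain Scala. This is a hypothetical sketch, not the Spark API: CopySketch.Model keeps param values in a Map[String, String], leaves out the userFactors and itemFactors DataFrames, and represents the parent by an optional uid.

```scala
// Plain-Scala model of Model.copy(extra) semantics (not the Spark API).
// Assumption: param values are a Map[String, String]; the factor DataFrames
// that the real ALSModel passes along are left out for brevity.
object CopySketch {
  final class Model(
      val uid: String,
      val rank: Int,
      val params: Map[String, String],
      val parent: Option[String]) {

    // New instance with the same uid and rank; extra values override
    // the current ones, and the parent carries over
    def copy(extra: Map[String, String]): Model =
      new Model(uid, rank, params ++ extra, parent)
  }
}
```

The original model is left untouched, so copying with a different coldStartStrategy yields a second model that shares the same factors but drops cold-start rows.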
