Diffstat (limited to 'co-purchase-analysis/src/main/scala')
 co-purchase-analysis/src/main/scala/Main.scala | 38 ++++++++++++++++++--------------------
 1 file changed, 18 insertions(+), 20 deletions(-)
diff --git a/co-purchase-analysis/src/main/scala/Main.scala b/co-purchase-analysis/src/main/scala/Main.scala
index d3679b2..a711136 100644
--- a/co-purchase-analysis/src/main/scala/Main.scala
+++ b/co-purchase-analysis/src/main/scala/Main.scala
@@ -1,7 +1,6 @@
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.rdd.RDD
 import org.apache.spark.HashPartitioner
-import scala.util.Try
 
 /** A Spark application that analyzes co-purchased products.
   *
@@ -15,8 +14,8 @@ import scala.util.Try
   *
   * @example
   *   {{{
-  * // Run the application with input path, output path and number of partitions
-  * spark-submit co-purchase-analysis.jar input.csv output/ 50
+  * // Run the application with input path and output path
+  * spark-submit co-purchase-analysis.jar input.csv output/
   *   }}}
   */
 object CoPurchaseAnalysis {
@@ -42,18 +41,14 @@ object CoPurchaseAnalysis {
   /** Validates command line arguments and checks file existence.
     *
     * @param args
-    *   Command line arguments array containing input file path, output
-    *   directory path and partitions number
+    *   Command line arguments array containing input file path and output
+    *   directory path
     * @return
     *   Some(errorMessage) if validation fails, None if validation succeeds
     */
   def checkArguments(args: Array[String]): Option[String] = {
-    if (args.length != 3) {
-      return Some("Need params: <inputPath> <outputFolder> <numPartitions>")
-    }
-
-    if (Try(args(2).toInt).isFailure) {
-      return Some(s"'${args(2)}' is not a valid integer")
+    if (args.length != 2) {
+      return Some("Need params: <inputPath> <outputFolder>")
     }
 
     return None
@@ -63,15 +58,15 @@ object CoPurchaseAnalysis {
     *
     * @param appName
     *   The name of the Spark application
-    * @param master
-    *   The Spark master URL (e.g., "local", "yarn")
     * @return
     *   Configured SparkSession instance
     */
-  def createSparkSession(appName: String, master: String): SparkSession = {
+  def createSparkSession(appName: String): SparkSession = {
     var session = SparkSession.builder
       .appName(appName)
-      .config("spark.master", master)
+      .config("spark.executor.memory", "6g")
+      .config("spark.executor.cores", "4")
+      .config("spark.driver.memory", "4g")
 
     val creds = System.getenv("GOOGLE_APPLICATION_CREDENTIALS")
     if (creds != null) {
@@ -155,14 +150,12 @@ object CoPurchaseAnalysis {
     // Configuration values should be passed as parameters
     val config = Map(
       "appName" -> "Co-Purchase Analysis",
-      "master" -> "local[*]",
       "inputPath" -> args(0),
-      "outputPath" -> args(1),
-      "partitionsNumber" -> args(2)
+      "outputPath" -> args(1)
     )
 
     // Program execution composed of pure functions
-    val spark = createSparkSession(config("appName"), config("master"))
+    val spark = createSparkSession(config("appName"))
 
     try {
       spark.sparkContext.setLogLevel("ERROR")
@@ -170,7 +163,12 @@
         .textFile(config("inputPath"))
         .map(parseLine)
 
-      val result = processData(inputRDD, config("partitionsNumber").toInt)
+      val cores = spark.conf.get("spark.executor.cores", "4").toInt
+      val nodes = spark.conf.get("spark.executor.instances", "4").toInt
+      val partitionsNumber =
+        math.max(cores * nodes * 2, spark.sparkContext.defaultParallelism * 2)
+
+      val result = processData(inputRDD, partitionsNumber)
         .saveAsTextFile(config("outputPath"))
     } finally {
       spark.stop()
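
With this change the partition count is no longer a command-line argument but is derived from the executor topology at runtime. A minimal standalone sketch of that heuristic, under the assumptions visible in the diff (the PartitionSizing object and partitionsFor method are illustrative names, not part of the commit; the "4" fallbacks mirror the hard-coded executor settings above):

import org.apache.spark.sql.SparkSession

object PartitionSizing {
  // Aim for twice the total executor cores (cores per executor times
  // executor count), but never below twice the cluster's default
  // parallelism. The "4" fallbacks are assumptions carried over from
  // the hard-coded values in Main.scala, not Spark's own defaults.
  def partitionsFor(spark: SparkSession): Int = {
    val cores = spark.conf.get("spark.executor.cores", "4").toInt
    val nodes = spark.conf.get("spark.executor.instances", "4").toInt
    math.max(cores * nodes * 2, spark.sparkContext.defaultParallelism * 2)
  }
}

Oversizing the partition count relative to the total core count (the factor of 2) is a common way to smooth out skew across co-purchase groups, at the cost of some extra scheduling overhead.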