diff options
Diffstat (limited to 'co-purchase-analysis/src/main/scala/Main.scala')
-rw-r--r-- | co-purchase-analysis/src/main/scala/Main.scala | 38 |
1 file changed, 18 insertions, 20 deletions
diff --git a/co-purchase-analysis/src/main/scala/Main.scala b/co-purchase-analysis/src/main/scala/Main.scala index d3679b2..a711136 100644 --- a/co-purchase-analysis/src/main/scala/Main.scala +++ b/co-purchase-analysis/src/main/scala/Main.scala @@ -1,7 +1,6 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.rdd.RDD import org.apache.spark.HashPartitioner -import scala.util.Try /** A Spark application that analyzes co-purchased products. * @@ -15,8 +14,8 @@ import scala.util.Try * * @example * {{{ - * // Run the application with input path, output path and number of partitions - * spark-submit co-purchase-analysis.jar input.csv output/ 50 + * // Run the application with input path and output path + * spark-submit co-purchase-analysis.jar input.csv output/ * }}} */ object CoPurchaseAnalysis { @@ -42,18 +41,14 @@ object CoPurchaseAnalysis { /** Validates command line arguments and checks file existence. * * @param args - * Command line arguments array containing input file path, output - * directory path and partitions number + * Command line arguments array containing input file path and output + * directory path * @return * Some(errorMessage) if validation fails, None if validation succeeds */ def checkArguments(args: Array[String]): Option[String] = { - if (args.length != 3) { - return Some("Need params: <inputPath> <outputFolder> <numPartitions>") - } - - if (Try(args(2).toInt).isFailure) { - return Some(s"'${args(2)}' is not a valid integer") + if (args.length != 2) { + return Some("Need params: <inputPath> <outputFolder>") } return None @@ -63,15 +58,15 @@ object CoPurchaseAnalysis { * * @param appName * The name of the Spark application - * @param master - * The Spark master URL (e.g., "local", "yarn") * @return * Configured SparkSession instance */ - def createSparkSession(appName: String, master: String): SparkSession = { + def createSparkSession(appName: String): SparkSession = { var session = SparkSession.builder .appName(appName) - 
.config("spark.master", master) + .config("spark.executor.memory", "6g") + .config("spark.executor.cores", "4") + .config("spark.driver.memory", "4g") val creds = System.getenv("GOOGLE_APPLICATION_CREDENTIALS") if (creds != null) { @@ -155,14 +150,12 @@ object CoPurchaseAnalysis { // Configuration values should be passed as parameters val config = Map( "appName" -> "Co-Purchase Analysis", - "master" -> "local[*]", "inputPath" -> args(0), - "outputPath" -> args(1), - "partitionsNumber" -> args(2) + "outputPath" -> args(1) ) // Program execution composed of pure functions - val spark = createSparkSession(config("appName"), config("master")) + val spark = createSparkSession(config("appName")) try { spark.sparkContext.setLogLevel("ERROR") @@ -170,7 +163,12 @@ object CoPurchaseAnalysis { .textFile(config("inputPath")) .map(parseLine) - val result = processData(inputRDD, config("partitionsNumber").toInt) + val cores = spark.conf.get("spark.executor.cores", "4").toInt + val nodes = spark.conf.get("spark.executor.instances", "4").toInt + val partitionsNumber = + math.max(cores * nodes * 2, spark.sparkContext.defaultParallelism * 2) + + val result = processData(inputRDD, partitionsNumber) .saveAsTextFile(config("outputPath")) } finally { spark.stop() |