From 80930bb9d945b2ffee0fdda78ebd8cbe1caa4dc2 Mon Sep 17 00:00:00 2001
From: Santo Cariotti <santo@dcariotti.me>
Date: Mon, 13 Jan 2025 19:08:36 +0100
Subject: Partitions number as argument

---
 co-purchase-analysis/src/main/scala/Main.scala | 35 +++++++++++++++++---------
 1 file changed, 23 insertions(+), 12 deletions(-)

(limited to 'co-purchase-analysis/src/main/scala/Main.scala')

diff --git a/co-purchase-analysis/src/main/scala/Main.scala b/co-purchase-analysis/src/main/scala/Main.scala
index 8fd82cb..e4938d8 100644
--- a/co-purchase-analysis/src/main/scala/Main.scala
+++ b/co-purchase-analysis/src/main/scala/Main.scala
@@ -1,6 +1,7 @@
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.rdd.RDD
 import org.apache.spark.HashPartitioner
+import scala.util.Try
 
 /** A Spark application that analyzes co-purchased products.
   *
@@ -14,8 +15,8 @@ import org.apache.spark.HashPartitioner
   *
   * @example
   *   {{{
-  * // Run the application with input and output paths
-  * spark-submit co-purchase-analysis.jar input.csv output/
+  * // Run the application with input path, output path and number of partitions
+  * spark-submit co-purchase-analysis.jar input.csv output/ 50
   *   }}}
   */
 object CoPurchaseAnalysis {
@@ -41,17 +42,21 @@ object CoPurchaseAnalysis {
   /** Validates command line arguments and checks file existence.
     *
     * @param args
-    *   Command line arguments array containing input file path and output
-    *   directory path
+    *   Command line arguments array containing input file path, output
+    *   directory path and partitions number
     * @return
     *   Some(errorMessage) if validation fails, None if validation succeeds
     */
   def checkArguments(args: Array[String]): Option[String] = {
-    if (args.length != 2) {
-      Some("You must define input file and output folder.")
-    } else {
-      None
+    if (args.length != 3) {
+      return Some("Need params: <inputPath> <outputFolder> <numPartitions>")
     }
+
+    if (Try(args(2).toInt).isFailure) {
+      return Some(s"'${args(2)}' is not a valid integer")
+    }
+
+    return None
   }
 
   /** Creates and configures a SparkSession.
@@ -102,10 +107,15 @@ object CoPurchaseAnalysis {
     *
     * @param data
     *   RDD containing OrderProduct instances
+    * @param partitionsNumber
+    *   Number of partitions used by HashPartitioner
     * @return
     *   RDD containing CoPurchase instances with purchase frequency counts
     */
-  def processData(data: RDD[OrderProduct]): RDD[String] = {
+  def processData(
+      data: RDD[OrderProduct],
+      partitionsNumber: Int
+  ): RDD[String] = {
     val pairs = data
       .map(order => (order.orderId, order.productId))
       .groupByKey()
@@ -116,7 +126,7 @@ object CoPurchaseAnalysis {
           y <- products if x < y
         } yield (ProductPair(x, y), 1)
       }
-      .partitionBy(new HashPartitioner(50))
+      .partitionBy(new HashPartitioner(partitionsNumber))
 
     val coProducts = pairs.reduceByKey(_ + _)
 
@@ -143,7 +153,8 @@ object CoPurchaseAnalysis {
       "appName" -> "Co-Purchase Analysis",
       "master" -> "local[*]",
       "inputPath" -> args(0),
-      "outputPath" -> args(1)
+      "outputPath" -> args(1),
+      "partitionsNumber" -> args(2)
     )
 
     // Program execution composed of pure functions
@@ -155,7 +166,7 @@ object CoPurchaseAnalysis {
         .textFile(config("inputPath"))
         .map(parseLine)
 
-      val result = processData(inputRDD)
+      val result = processData(inputRDD, config("partitionsNumber").toInt)
         .saveAsTextFile(config("outputPath"))
     } finally {
       spark.stop()
-- 
cgit v1.2.3-18-g5258