
Spark ML Best Practices

Optimize your Spark ML workflows for performance, reliability, and production readiness.

Performance Tuning

Python — Key Configuration Settings
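from pyspark.sql import SparkSession

# Example values; adjust to your cluster size and workload.
spark = (
    SparkSession.builder
    .appName("OptimizedML")
    .config("spark.sql.shuffle.partitions", "200")   # partitions after DataFrame shuffles
    .config("spark.default.parallelism", "200")      # default partitions for RDD operations
    .config("spark.driver.memory", "8g")             # only applies if set before the driver JVM starts
    .config("spark.executor.memory", "16g")          # JVM heap per executor
    .config("spark.executor.cores", "4")             # concurrent tasks per executor
    .config("spark.memory.fraction", "0.8")          # share of (heap - 300 MiB) for execution + storage; default 0.6
    .config("spark.sql.adaptive.enabled", "true")    # AQE; on by default since Spark 3.2
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")  # merge small shuffle partitions
    .getOrCreate()
)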
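To see what `spark.executor.memory` and `spark.memory.fraction` imply in practice: the Spark tuning guide describes a unified region M = (JVM heap - 300 MiB) x `spark.memory.fraction`, shared by execution and storage, of which a fraction `spark.memory.storageFraction` (default 0.5) is cached storage R that execution cannot evict. This sketch applies that arithmetic to the settings above. It is an approximation, since the JVM usually reports slightly less usable heap than the configured value.

```python
# Approximate Spark unified memory arithmetic for one executor:
#   M = (heap - 300 MiB reserved) * spark.memory.fraction   (execution + storage)
#   R = M * spark.memory.storageFraction                     (eviction-protected cache)
RESERVED_MIB = 300  # memory Spark reserves before applying spark.memory.fraction


def unified_memory_mib(executor_memory_mib, memory_fraction=0.6, storage_fraction=0.5):
    """Return (M, R) in MiB; defaults mirror Spark's default fractions."""
    usable = executor_memory_mib - RESERVED_MIB
    m = usable * memory_fraction
    r = m * storage_fraction
    return m, r


if __name__ == "__main__":
    # Settings from the configuration above: 16g executor heap, fraction 0.8
    m, r = unified_memory_mib(16 * 1024, memory_fraction=0.8)
    print(f"M (execution + storage): {m:.1f} MiB")  # 12867.2 MiB
    print(f"R (protected storage):   {r:.1f} MiB")  # 6433.6 MiB
```

With the default fraction of 0.6, the same executor would get about 9650 MiB for M. Raising the fraction leaves less heap for user data structures and Spark's internal metadata, so increase it cautiously.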

Data Partitioning

  • Rule of thumb: 2-4 partitions per CPU core in your cluster.
  • Too few partitions: Leaves cores idle and makes each partition large, which can cause OOM errors.
  • Too many partitions: Excessive scheduling overhead, and many small shuffle blocks slow jobs down.
  • Use repartition/coalesce: df.repartition(200) to increase partitions (full shuffle), df.coalesce(50) to decrease them without a full shuffle.
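The 2-4 partitions-per-core rule above reduces to simple arithmetic. The helper below is a hypothetical sketch: `num_executors` is an assumption you supply for your own cluster, while 4 cores per executor matches `spark.executor.cores` in the configuration shown earlier.

```python
def suggested_partitions(num_executors, cores_per_executor, tasks_per_core=3):
    """Rule-of-thumb partition count: total cores x 2-4 partitions per core."""
    if not 2 <= tasks_per_core <= 4:
        raise ValueError("the rule of thumb uses 2-4 partitions per core")
    total_cores = num_executors * cores_per_executor
    return total_cores * tasks_per_core


if __name__ == "__main__":
    # Hypothetical cluster: 10 executors x 4 cores = 40 cores
    low = suggested_partitions(10, 4, tasks_per_core=2)
    high = suggested_partitions(10, 4, tasks_per_core=4)
    print(f"suggested range: {low}-{high} partitions")  # suggested range: 80-160 partitions
```

The result can be passed to `df.repartition(n)`, or applied to DataFrame shuffles with `spark.conf.set("spark.sql.shuffle.partitions", str(n))`.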
Python — Smart Partitioning for ML
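# Before training, repartition to roughly 2-4 partitions per available core
train_df = train_df.repartition(200)

# Cache DataFrames that iterative algorithms scan repeatedly
train_df.cache()
train_df.count()  # cache() is lazy; an action materializes the cache

# After training, coalesce predictions into fewer, larger output files
# (predictions is assumed to come from model.transform(...) on held-out data)
predictions.coalesce(10).write.parquet("predictions/")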
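Choosing the argument to `coalesce()` is similar arithmetic: divide the expected output size by a target file size. This sketch assumes you can estimate the output size (for example, from a previous run) and uses a 128 MiB target, matching the default of `spark.sql.files.maxPartitionBytes`; both values are assumptions to adjust.

```python
import math

TARGET_FILE_MIB = 128  # assumed target size per output file


def output_partitions(estimated_output_mib, target_file_mib=TARGET_FILE_MIB):
    """Partitions (and therefore files) to coalesce to before writing."""
    return max(1, math.ceil(estimated_output_mib / target_file_mib))


if __name__ == "__main__":
    print(output_partitions(1280))  # 10 (files of about 128 MiB each)
    print(output_partitions(5))     # 1 (tiny outputs still need one file)
```

Note that `coalesce(n)` can only lower the partition count; if `n` exceeds the current count, use `repartition(n)` instead.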

Memory Management

Python — Caching Strategy
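from pyspark import StorageLevel

# DataFrame.cache() already uses a memory-and-disk storage level:
# partitions that don't fit in memory spill to local disk instead of being recomputed.
train_df.cache()

# Or set the level explicitly instead of cache(). Use one or the other:
# persisting an already-cached DataFrame again does not change its level.
# train_df.persist(StorageLevel.MEMORY_AND_DISK)

# Check cache state; the Spark UI "Storage" tab shows per-executor usage
print(train_df.is_cached, train_df.storageLevel)

# Internal py4j API: per-executor (max, remaining) cache memory as a Scala Map
spark.sparkContext._jsc.sc().getExecutorMemoryStatus()

# Unpersist when done
train_df.unpersist()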

Production Checklist

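  • Use Parquet: Store data in a columnar format such as Parquet, so Spark reads only the columns it needs and can push filters down to the scan.
  • Enable AQE: Adaptive Query Execution (on by default since Spark 3.2) re-optimizes queries at runtime, coalescing shuffle partitions, splitting skewed join partitions, and switching join strategies.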
  • Version your models: Save pipeline models with version tags and track with MLflow.
  • Monitor with Spark UI: Check the Spark UI for skewed tasks, GC pauses, and shuffle spills.
  • Test locally first: Develop with local[*] on a sample, then scale to the cluster.
  • Handle data skew: Use salting or broadcast joins for skewed join keys.
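Salting splits one hot join key into several sub-keys on the large side and replicates the matching rows on the small side, so every sub-key still finds its partner while the hot key's rows spread across more tasks. The plain-Python sketch below models that logic on lists of `(key, value)` tuples rather than DataFrames; `NUM_SALTS = 4` is an arbitrary assumption. In PySpark the same idea is commonly expressed with `F.floor(F.rand() * N)` for the salt column and `F.explode(F.sequence(F.lit(0), F.lit(N - 1)))` on the small side.

```python
import random
from collections import defaultdict

NUM_SALTS = 4  # assumed number of sub-keys per join key


def salt_large_side(rows, num_salts=NUM_SALTS, rng=random):
    """Attach a random salt in [0, num_salts) to each large-side key."""
    return [((key, rng.randrange(num_salts)), value) for key, value in rows]


def explode_small_side(rows, num_salts=NUM_SALTS):
    """Replicate each small-side row once per salt so every salted key matches."""
    return [((key, salt), value) for key, value in rows for salt in range(num_salts)]


def hash_join(left, right):
    """Inner equi-join on the first tuple element (stand-in for a Spark join)."""
    index = defaultdict(list)
    for key, value in right:
        index[key].append(value)
    return [(key, lv, rv) for key, lv in left for rv in index.get(key, [])]


def salted_join(large, small, num_salts=NUM_SALTS, rng=random):
    joined = hash_join(salt_large_side(large, num_salts, rng),
                       explode_small_side(small, num_salts))
    # Drop the salt so the output matches an ordinary join on the original key
    return [(key, lv, rv) for (key, _salt), lv, rv in joined]


if __name__ == "__main__":
    rng = random.Random(42)  # seeded so the sketch is reproducible
    large = [("hot", i) for i in range(1000)] + [("cold", i) for i in range(10)]
    small = [("hot", "H"), ("cold", "C")]
    result = salted_join(large, small, rng=rng)
    assert sorted(result) == sorted(hash_join(large, small))  # same rows as a plain join
    print(len(result))  # 1010
```

The trade-off is that the small side grows by a factor of `NUM_SALTS`, so salting pays off only when one key dominates; if the small side fits in memory, a broadcast join avoids the shuffle altogether.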

Common Pitfalls

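  • Calling collect() on large data: Brings every row to the driver and can exhaust its memory. Use take() or limit() for inspection, or write results to storage instead.
  • Not caching training data: Iterative algorithms such as LogisticRegression and KMeans pass over the data many times; without caching, Spark re-reads and recomputes the input on every pass.
  • UDFs for simple operations: Python UDFs serialize rows between the JVM and Python workers. Use built-in Spark SQL functions whenever possible, and vectorized pandas UDFs when custom Python logic is unavoidable.
  • Ignoring data skew: One partition with 10x more data than the others becomes a bottleneck, because the whole stage waits for that one slow task. Monitor partition sizes.
  • Small file problem: Too many small input files create excessive task overhead. Compact them, for example by coalescing before writing, or use a table format such as Delta Lake that can compact files.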
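To monitor partition sizes, you can collect one row count per partition, for example with `df.groupBy(F.spark_partition_id()).count()`, and compare the largest partition with the median. This sketch does that comparison in plain Python; the factor of 5 mirrors the default of AQE's `spark.sql.adaptive.skewJoin.skewedPartitionFactor` (which compares bytes rather than rows), but here it is only an assumed threshold to tune.

```python
from statistics import median

SKEW_FACTOR = 5.0  # assumed threshold: largest partition vs. median


def skew_ratio(partition_row_counts):
    """Largest partition divided by the median partition (1.0 means balanced)."""
    counts = list(partition_row_counts)
    if not counts:
        raise ValueError("need at least one partition count")
    mid = median(counts)
    if mid == 0:
        # At least half the partitions are empty: any non-empty one is extreme skew
        return float("inf") if max(counts) > 0 else 1.0
    return max(counts) / mid


def is_skewed(partition_row_counts, factor=SKEW_FACTOR):
    """True when the largest partition is at least `factor` times the median."""
    return skew_ratio(partition_row_counts) >= factor


if __name__ == "__main__":
    counts = [1000, 1100, 950, 10500]  # one partition holds about 10x the others
    print(round(skew_ratio(counts), 2), is_skewed(counts))  # 10.0 True
```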

Frequently Asked Questions

Can I use Spark for deep learning?

Yes. Use Spark for data preparation and feature engineering, then train deep learning models with frameworks like PyTorch or TensorFlow. Libraries like Horovod and spark-tensorflow-distributor enable distributed deep learning on Spark clusters. For inference, use pandas UDFs to apply models in parallel.

When should I use Spark ML instead of scikit-learn?

Use scikit-learn when your data fits in memory on a single machine (typically under 10GB). Use Spark ML when data is too large for one machine, when you need distributed training, or when your data already lives in a Spark ecosystem (HDFS, Delta Lake, etc.).

How do I deploy a trained Spark ML model?

For batch predictions, use Spark itself to transform new data with the saved PipelineModel. For real-time serving, export the model to ONNX or PMML format, or use MLflow to serve it as a REST API. You can also extract model weights and use them in a lightweight serving framework.