跳转至

Syntax cheat sheet(语法速查表)

A quick reference guide to the most commonly used patterns and functions in Java Spark SQL.

See the Java Spark Official Documentation ↗ for additional information on Java Spark.

Common Patterns

Logging Output

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MyClass {
    private static final Logger LOG = LoggerFactory.getLogger(MyClass.class);
    LOG.info("example log output");
}

Importing Datasets, Rows, Functions, and Types

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.types.DataTypes;

Filtering

// Filter on equals condition
df = df.filter(col("is_adult").equalTo("Y"));

// Filter on >, <, >=, <= condition
df = df.filter(col("age").gt(25));
df = df.filter(col("age").lt(25));
df = df.filter(col("age").geq(25));
df = df.filter(col("age").leq(25));

// Multiple conditions
df = df.filter(col("age").gt(25).and(col("is_adult").equalTo("Y")));

// Compare against a list of allowed values
df = df.filter(col("first_name").isin(List.of(3, 4, 7)));

// Sort results
df = df.orderBy(col("age").asc());
df = df.orderBy(col("age").desc());

Joins

// Left join in another dataset
df = df.join(personLookupTable, col("person_id"), "left");

// Left anti-join in another dataset (return unmatched rows in left dataframe)
df = df.join(personLookupTable, col("person_id"), "leftanti");

// Match on different columns in left & right datasets
df = df.join(otherTable, col("id").equalTo(col("person_id")), "left");

// Match on multiple columns
df = df.join(otherTable, col("first_name").equalTo(col("name")).and(
             col("last_name").equalTo(col("family_name"))), "left");

// Useful for one-liner lookup code joins
public Dataset<Row> lookupAndReplace(Dataset<Row> df1, Dataset<Row> df2,
    String df1Key, String df2Key, String df2Value) {

    return df1.join(df2.select(col(df2Key), col(df2Value)),
                    col(df1Key).equalTo(col(df2Key)), "left")
              .withColumn(df1Key, coalesce(col(df2Value), col(df1Key)))
              .drop(df2Key, df2Value);
}

Dataset<Row> df = lookupAndReplace(people, payCodes, id, payCodeId, payCodeDesc);

Column Operations

// Add a new static column
df = df.withColumn("status", lit("PASS"));

// Construct a new dynamic column
df = df.withColumn("full_name", when(
    (col("fname").isNotNull().or(col("lname").isNotNull())),
    concat_ws(" ", col("fname"), col("lname")))
    .otherwise(lit(null)));

// Pick which columns to keep, optionally rename some
df = df.select(
    col("name"),
    col("age"),
    col("dob").alias("date_of_birth")
);

// Remove columns
df = df.drop("mod_dt", "mod_username");

// Rename a column
df = df.withColumnRenamed("dob", "date_of_birth");

// Keep all the columns which also occur in another dataset
import org.apache.spark.sql.Column;
import scala.collection.JavaConversions;

List<Column> columnsToSelect = new ArrayList<Column>();
List<String> df2Columns = Arrays.asList(df2.columns());
for (String c : df.columns()) {
  if (df2Columns.contains(c)) {
    columnsToSelect.add(col(c));
  }
}
df = df.select(JavaConversions.asScalaBuffer(columnsToSelect));

// Batch Rename/Clean Columns
for (String c in df.columns()) {
  df = df.withColumnRenamed(c, c.toLowerCase().replace(" ", "_").replace("-", "_"));
}

Casting & Coalescing Null Values & Duplicates

import org.apache.spark.sql.types.DataTypes;

// Cast a column to a different type (both ways equivalent)
df = df.withColumn("price", col("price").cast(DataTypes.DoubleType));
df = df.withColumn("price", col("price").cast("double"));

// Create an all null column and cast to specific type
df = df.withColumn("name", lit(null).cast(DataTypes.StringType));
df = df.withColumn("name", lit(null).cast("string"));

// Replace all nulls with a specific value
df = df.na().fill(ImmutableMap.of("first_name", "Tom", "age", 0));

// Take the first value that is not null
df = df.withColumn("last_name", coalesce(col("last_name"), col("surname"), lit("N/A")));

// Drop duplicate rows in a dataset (same as distinct())
df = df.dropDuplicates()

// Drop duplicate rows, but consider only specific columns
df = df.dropDuplicates("name", "height");

String Operations

String Filters

// Contains - col.contains(string)
df = df.filter(col("name").contains("o"));

//Starts With - col.startsWith(string)
df = df.filter(col("name").startsWith("Al"));

// Ends With - col.endsWith(string)
df = df.filter(col("name").endsWith("ice"));

// Is Null - col.isNull()
df = df.filter(col("is_adult").isNull());

// Is Not Null - col.isNotNull()
df = df.filter(col("first_name").isNotNull());

// Like - col.like(string_with_sql_wildcards)
df = df.filter(col("name").like("Al%"));

// Regex Like - col.rlike(regex)
df = df.filter(col("name").rlike("[A-Z]*ice$"));

// Is In List - col.isin(Object... values)
df = df.filter(col("name").isin("Bob", "Mike"));

String Functions

// Substring - col.substr(startPos, length) (1-based indexing)
df = df.withColumn("short_id", col("id").substr(1, 10));

// Trim - trim(col)
df = df.withColumn("name", trim(col("name")));

// Left Pad - lpad(col, len, pad)
// Right Pad - rpad(col, len, pad)
df = df.withColumn("id", lpad(col("id"), 4, "0"));

// Left Trim - ltrim(col)
// Right Trim - rtrim(col)
df = df.withColumn("id", ltrim(col("id")));

// Concatenate - concat(Column... cols) (null if any column null)
df = df.withColumn("full_name", concat(col("fname"), lit(" "), col("lname")));

// Concatenate with Delimiter - concat_ws(delim, Column... cols) (ignores nulls)
df = df.withColumn("full_name", concat_ws("-", "fname", "lname"));

// Regex Replace - regexp_replace(col, pattern, replacement)
df = df.withColumn("id", regexp_replace(col("id"), "0F1(.*)", "1F1-$1"));

// Regex Extract - regexp_extract(str, pattern, idx)
df = df.withColumn("id", regexp_extract(col("id"), "[0-9]*", 0));

Number Operations

// Round - round(col, scale=0)
df = df.withColumn("price", round(col("price"), 0));

// Floor - floor(col)
df = df.withColumn("price", floor(col("price")));

// Ceiling - ceil(col)
df = df.withColumn("price", ceil(col("price")));

// Absolute Value - abs(col)
df = df.withColumn("price", abs(col("price")));

// X raised to power Y – pow(X, Y)
df = df.withColumn("exponential_growth", pow(col("x"), 2.0));

// Select smallest value out of multiple columns – least(Column... cols)
df = df.withColumn("least", least(col("subtotal"), col("total")));

// Select largest value out of multiple columns – greatest(Column... cols)
df = df.withColumn("greatest", greatest(col("subtotal"), col("total")));

Date & Timestamp Operations

// Convert a string of known format to a date (excludes time information)
df = df.withColumn("date_of_birth", to_date(col("date_of_birth"), "yyyy-MM-dd"));

// Convert a string of known format to a timestamp (includes time information)
df = df.withColumn("time_of_birth", to_timestamp(col("time_of_birth"), "yyyy-MM-dd HH:mm:ss"));

// Get year from date:       year(col)
// Get month from date:      month(col)
// Get day from date:        dayofmonth(col)
// Get hour from date:       hour(col)
// Get minute from date:     minute(col)
// Get second from date:     second(col)
df = df.filter(year(col("date_of_birth")).equalTo("2017"));

// Add & subtract days
df = df.withColumn("three_days_after", date_add(col("date_of_birth"), 3));
df = df.withColumn("three_days_before", date_sub(col("date_of_birth"), 3));

// Add & subtract months
df = df.withColumn("next_month", add_months(col("date_of_birth"), 1));
df = df.withColumn("previous_month", add_months(col("date_of_birth"), -1));

// Get number of days between two dates
df = df.withColumn("days_between", datediff(col("end"), col("start")));

// Get number of months between two dates
df = df.withColumn("months_between", months_between(col("end"), col("start")));

// Keep only rows where date_of_birth is between 2017-05-10 and 2018-07-21
df = df.filter(
    (col("date_of_birth").geq("2017-05-10")).and(
    (col("date_of_birth").leq("2018-07-21")))
);

Array and Struct Operations

// Array or Struct column from existing columns
df = df.withColumn("guardians", array(col("guardian_1"), col("guardian_2")));
df = df.withColumn("properties", struct(col("hair_color"), col("eye_color")));

// Extract from Array or Struct column by index or key (null if invalid)
df = df.withColumn("primary_guardian", col("guardians").getItem(0));
df = df.withColumn("hair_color", col("properties").getItem("hair_color"));

// Explode Array or Struct column into multiple rows
df = df.select(col("child_name"), explode(col("guardians")));
df = df.select(col("child_name"), explode(col("properties")));

// Explode Struct column into multiple columns
df = df.select(col("child_name"), col("properties.*"));

Aggregation Operations

// Row Count:                count(col), countDistinct(col)
// Sum of Rows in Group:     sum(col)
// Mean of Rows in Group:    mean(col)
// Max of Rows in Group:     max(col)
// Min of Rows in Group:     min(col)
// First Row in Group:       first(col, ignoreNulls)

df = df.groupBy(col("address")).agg(
  count(col("uuid")).alias("num_residents"),
  max(col("age")).alias("oldest_age"),
  first(col("city"), true).alias("city")
);

// Collect a Set of all Rows in Group:       collect_set(col)
// Collect a List of all Rows in Group:      collect_list(col)
df = df.groupBy(col("address")).agg(collect_set("name").alias("resident_names"));

Advanced Operations

Repartitioning

// Repartition – df.repartition(num_output_partitions)
df = df.repartition(1);

UDFs (User Defined Functions)

import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DecimalType;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.udf;
import java.math.BigDecimal;

/**
 * The following example creates a java UDF which adds two columns of Spark type "Decimal".
 * Spark represents DecimalType with java.math.BigDecimal, hence why we use this class.
 * Other Spark types will be represented by other Java types as defined in https://spark.apache.org/docs/3.0.0/sql-ref-datatypes.html
 */

public final class HighLevelAutoTransform {
    @Compute
    @Output("...")
    public Dataset<Row> myComputeFunction(@Input("...") Dataset<Row> df) {
        UserDefinedFunction addsUDF = udf((BigDecimal i, BigDecimal j) -> {
            if (i == null || j == null) { // always handle null cases
                return null;
            }

            return i.add(j);
        }, new DecimalType()); // this is the Spark data type of the result of the UDF

        return df.withColumn("a_plus_b", addsUDF.apply(col("a"), col("b")));
    }
}

中文翻译


语法速查表

这是Java Spark SQL中最常用的模式与函数的快速参考指南。

更多Java Spark相关信息请参考Java Spark Official Documentation ↗

常用模式

日志输出

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MyClass {
    private static final Logger LOG = LoggerFactory.getLogger(MyClass.class);
    LOG.info("example log output");
}

导入数据集(Dataset)、行(Row)、函数与类型

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.types.DataTypes;

过滤

// Filter on equals condition
df = df.filter(col("is_adult").equalTo("Y"));

// Filter on >, <, >=, <= condition
df = df.filter(col("age").gt(25));
df = df.filter(col("age").lt(25));
df = df.filter(col("age").geq(25));
df = df.filter(col("age").leq(25));

// Multiple conditions
df = df.filter(col("age").gt(25).and(col("is_adult").equalTo("Y")));

// Compare against a list of allowed values
df = df.filter(col("first_name").isin(List.of(3, 4, 7)));

// Sort results
df = df.orderBy(col("age").asc());
df = df.orderBy(col("age").desc());

关联(Join)

// Left join in another dataset
df = df.join(personLookupTable, col("person_id"), "left");

// Left anti-join in another dataset (return unmatched rows in left dataframe)
df = df.join(personLookupTable, col("person_id"), "leftanti");

// Match on different columns in left & right datasets
df = df.join(otherTable, col("id").equalTo(col("person_id")), "left");

// Match on multiple columns
df = df.join(otherTable, col("first_name").equalTo(col("name")).and(
             col("last_name").equalTo(col("family_name"))), "left");

// Useful for one-liner lookup code joins
public Dataset<Row> lookupAndReplace(Dataset<Row> df1, Dataset<Row> df2,
    String df1Key, String df2Key, String df2Value) {

    return df1.join(df2.select(col(df2Key), col(df2Value)),
                    col(df1Key).equalTo(col(df2Key)), "left")
              .withColumn(df1Key, coalesce(col(df2Value), col(df1Key)))
              .drop(df2Key, df2Value);
}

Dataset<Row> df = lookupAndReplace(people, payCodes, id, payCodeId, payCodeDesc);

列操作

// Add a new static column
df = df.withColumn("status", lit("PASS"));

// Construct a new dynamic column
df = df.withColumn("full_name", when(
    (col("fname").isNotNull().or(col("lname").isNotNull())),
    concat_ws(" ", col("fname"), col("lname")))
    .otherwise(lit(null)));

// Pick which columns to keep, optionally rename some
df = df.select(
    col("name"),
    col("age"),
    col("dob").alias("date_of_birth")
);

// Remove columns
df = df.drop("mod_dt", "mod_username");

// Rename a column
df = df.withColumnRenamed("dob", "date_of_birth");

// Keep all the columns which also occur in another dataset
import org.apache.spark.sql.Column;
import scala.collection.JavaConversions;

List<Column> columnsToSelect = new ArrayList<Column>();
List<String> df2Columns = Arrays.asList(df2.columns());
for (String c : df.columns()) {
  if (df2Columns.contains(c)) {
    columnsToSelect.add(col(c));
  }
}
df = df.select(JavaConversions.asScalaBuffer(columnsToSelect));

// Batch Rename/Clean Columns
for (String c in df.columns()) {
  df = df.withColumnRenamed(c, c.toLowerCase().replace(" ", "_").replace("-", "_"));
}

类型转换(Cast)、空值合并(Coalesce)、空值(Null)处理与去重

import org.apache.spark.sql.types.DataTypes;

// Cast a column to a different type (both ways equivalent)
df = df.withColumn("price", col("price").cast(DataTypes.DoubleType));
df = df.withColumn("price", col("price").cast("double"));

// Create an all null column and cast to specific type
df = df.withColumn("name", lit(null).cast(DataTypes.StringType));
df = df.withColumn("name", lit(null).cast("string"));

// Replace all nulls with a specific value
df = df.na().fill(ImmutableMap.of("first_name", "Tom", "age", 0));

// Take the first value that is not null
df = df.withColumn("last_name", coalesce(col("last_name"), col("surname"), lit("N/A")));

// Drop duplicate rows in a dataset (same as distinct())
df = df.dropDuplicates()

// Drop duplicate rows, but consider only specific columns
df = df.dropDuplicates("name", "height");

字符串操作

字符串过滤

// Contains - col.contains(string)
df = df.filter(col("name").contains("o"));

//Starts With - col.startsWith(string)
df = df.filter(col("name").startsWith("Al"));

// Ends With - col.endsWith(string)
df = df.filter(col("name").endsWith("ice"));

// Is Null - col.isNull()
df = df.filter(col("is_adult").isNull());

// Is Not Null - col.isNotNull()
df = df.filter(col("first_name").isNotNull());

// Like - col.like(string_with_sql_wildcards)
df = df.filter(col("name").like("Al%"));

// Regex Like - col.rlike(regex)
df = df.filter(col("name").rlike("[A-Z]*ice$"));

// Is In List - col.isin(Object... values)
df = df.filter(col("name").isin("Bob", "Mike"));

字符串函数

// Substring - col.substr(startPos, length) (1-based indexing)
df = df.withColumn("short_id", col("id").substr(1, 10));

// Trim - trim(col)
df = df.withColumn("name", trim(col("name")));

// Left Pad - lpad(col, len, pad)
// Right Pad - rpad(col, len, pad)
df = df.withColumn("id", lpad(col("id"), 4, "0"));

// Left Trim - ltrim(col)
// Right Trim - rtrim(col)
df = df.withColumn("id", ltrim(col("id")));

// Concatenate - concat(Column... cols) (null if any column null)
df = df.withColumn("full_name", concat(col("fname"), lit(" "), col("lname")));

// Concatenate with Delimiter - concat_ws(delim, Column... cols) (ignores nulls)
df = df.withColumn("full_name", concat_ws("-", "fname", "lname"));

// Regex Replace - regexp_replace(col, pattern, replacement)
df = df.withColumn("id", regexp_replace(col("id"), "0F1(.*)", "1F1-$1"));

// Regex Extract - regexp_extract(str, pattern, idx)
df = df.withColumn("id", regexp_extract(col("id"), "[0-9]*", 0));

数值操作

// Round - round(col, scale=0)
df = df.withColumn("price", round(col("price"), 0));

// Floor - floor(col)
df = df.withColumn("price", floor(col("price")));

// Ceiling - ceil(col)
df = df.withColumn("price", ceil(col("price")));

// Absolute Value - abs(col)
df = df.withColumn("price", abs(col("price")));

// X raised to power Y – pow(X, Y)
df = df.withColumn("exponential_growth", pow(col("x"), 2.0));

// Select smallest value out of multiple columns – least(Column... cols)
df = df.withColumn("least", least(col("subtotal"), col("total")));

// Select largest value out of multiple columns – greatest(Column... cols)
df = df.withColumn("greatest", greatest(col("subtotal"), col("total")));

日期与时间戳(Timestamp)操作

// Convert a string of known format to a date (excludes time information)
df = df.withColumn("date_of_birth", to_date(col("date_of_birth"), "yyyy-MM-dd"));

// Convert a string of known format to a timestamp (includes time information)
df = df.withColumn("time_of_birth", to_timestamp(col("time_of_birth"), "yyyy-MM-dd HH:mm:ss"));

// Get year from date:       year(col)
// Get month from date:      month(col)
// Get day from date:        dayofmonth(col)
// Get hour from date:       hour(col)
// Get minute from date:     minute(col)
// Get second from date:     second(col)
df = df.filter(year(col("date_of_birth")).equalTo("2017"));

// Add & subtract days
df = df.withColumn("three_days_after", date_add(col("date_of_birth"), 3));
df = df.withColumn("three_days_before", date_sub(col("date_of_birth"), 3));

// Add & subtract months
df = df.withColumn("next_month", add_months(col("date_of_birth"), 1));
df = df.withColumn("previous_month", add_months(col("date_of_birth"), -1));

// Get number of days between two dates
df = df.withColumn("days_between", datediff(col("end"), col("start")));

// Get number of months between two dates
df = df.withColumn("months_between", months_between(col("end"), col("start")));

// Keep only rows where date_of_birth is between 2017-05-10 and 2018-07-21
df = df.filter(
    (col("date_of_birth").geq("2017-05-10")).and(
    (col("date_of_birth").leq("2018-07-21")))
);

数组(Array)与结构体(Struct)操作

// Array or Struct column from existing columns
df = df.withColumn("guardians", array(col("guardian_1"), col("guardian_2")));
df = df.withColumn("properties", struct(col("hair_color"), col("eye_color")));

// Extract from Array or Struct column by index or key (null if invalid)
df = df.withColumn("primary_guardian", col("guardians").getItem(0));
df = df.withColumn("hair_color", col("properties").getItem("hair_color"));

// Explode Array or Struct column into multiple rows
df = df.select(col("child_name"), explode(col("guardians")));
df = df.select(col("child_name"), explode(col("properties")));

// Explode Struct column into multiple columns
df = df.select(col("child_name"), col("properties.*"));

聚合(Aggregation)操作

// Row Count:                count(col), countDistinct(col)
// Sum of Rows in Group:     sum(col)
// Mean of Rows in Group:    mean(col)
// Max of Rows in Group:     max(col)
// Min of Rows in Group:     min(col)
// First Row in Group:       first(col, ignoreNulls)

df = df.groupBy(col("address")).agg(
  count(col("uuid")).alias("num_residents"),
  max(col("age")).alias("oldest_age"),
  first(col("city"), true).alias("city")
);

// Collect a Set of all Rows in Group:       collect_set(col)
// Collect a List of all Rows in Group:      collect_list(col)
df = df.groupBy(col("address")).agg(collect_set("name").alias("resident_names"));

高级操作

重分区(Repartition)

// Repartition – df.repartition(num_output_partitions)
df = df.repartition(1);

用户自定义函数(UDF, User Defined Functions)

import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DecimalType;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.udf;
import java.math.BigDecimal;

/**
 * The following example creates a java UDF which adds two columns of Spark type "Decimal".
 * Spark represents DecimalType with java.math.BigDecimal, hence why we use this class.
 * Other Spark types will be represented by other Java types as defined in https://spark.apache.org/docs/3.0.0/sql-ref-datatypes.html
 */

public final class HighLevelAutoTransform {
    @Compute
    @Output("...")
    public Dataset<Row> myComputeFunction(@Input("...") Dataset<Row> df) {
        UserDefinedFunction addsUDF = udf((BigDecimal i, BigDecimal j) -> {
            if (i == null || j == null) { // always handle null cases
                return null;
            }

            return i.add(j);
        }, new DecimalType()); // this is the Spark data type of the result of the UDF

        return df.withColumn("a_plus_b", addsUDF.apply(col("a"), col("b")));
    }
}