From db051bc8ab89f0c5d1f43a3e891e844a942deaa5 Mon Sep 17 00:00:00 2001
From: Edgar Rodriguez
Date: Fri, 25 Sep 2020 18:55:33 -0400
Subject: [PATCH] Spark: Follow name mapping when importing ORC tables (#1399)

---
 .../org/apache/iceberg/orc/OrcMetrics.java    |  39 +-
 .../apache/iceberg/spark/SparkTableUtil.java  |   8 +-
 .../spark/source/TestSparkTableUtil.java      | 519 ++++++++++--------
 3 files changed, 319 insertions(+), 247 deletions(-)

diff --git a/orc/src/main/java/org/apache/iceberg/orc/OrcMetrics.java b/orc/src/main/java/org/apache/iceberg/orc/OrcMetrics.java
index c37731a0f..215fbc2ad 100644
--- a/orc/src/main/java/org/apache/iceberg/orc/OrcMetrics.java
+++ b/orc/src/main/java/org/apache/iceberg/orc/OrcMetrics.java
@@ -40,6 +40,7 @@ import org.apache.iceberg.expressions.Literal;
 import org.apache.iceberg.hadoop.HadoopInputFile;
 import org.apache.iceberg.io.InputFile;
+import org.apache.iceberg.mapping.NameMapping;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.types.Conversions;
@@ -73,15 +74,19 @@ public static Metrics fromInputFile(InputFile file) {
   }
 
   public static Metrics fromInputFile(InputFile file, MetricsConfig metricsConfig) {
+    return fromInputFile(file, metricsConfig, null);
+  }
+
+  public static Metrics fromInputFile(InputFile file, MetricsConfig metricsConfig, NameMapping mapping) {
     final Configuration config = (file instanceof HadoopInputFile) ?
         ((HadoopInputFile) file).getConf() : new Configuration();
-    return fromInputFile(file, config, metricsConfig);
+    return fromInputFile(file, config, metricsConfig, mapping);
   }
 
-  static Metrics fromInputFile(InputFile file, Configuration config, MetricsConfig metricsConfig) {
+  static Metrics fromInputFile(InputFile file, Configuration config, MetricsConfig metricsConfig, NameMapping mapping) {
     try (Reader orcReader = ORC.newFileReader(file, config)) {
       return buildOrcMetrics(orcReader.getNumberOfRows(), orcReader.getSchema(), orcReader.getStatistics(),
-          metricsConfig);
+          metricsConfig, mapping);
     } catch (IOException ioe) {
       throw new RuntimeIOException(ioe, "Failed to open file: %s", file.location());
     }
@@ -89,27 +94,40 @@ static Metrics fromInputFile(InputFile file, Configuration config, MetricsConfig
 
   static Metrics fromWriter(Writer writer, MetricsConfig metricsConfig) {
     try {
-      return buildOrcMetrics(writer.getNumberOfRows(), writer.getSchema(), writer.getStatistics(), metricsConfig);
+      return buildOrcMetrics(writer.getNumberOfRows(), writer.getSchema(), writer.getStatistics(), metricsConfig, null);
     } catch (IOException ioe) {
       throw new RuntimeIOException(ioe, "Failed to get statistics from writer");
     }
   }
 
-  private static Metrics buildOrcMetrics(long numOfRows, TypeDescription orcSchema,
-                                         ColumnStatistics[] colStats, MetricsConfig metricsConfig) {
-    final Schema schema = ORCSchemaUtil.convert(orcSchema);
-    final Set<Integer> statsColumns = statsColumns(orcSchema);
+  private static Metrics buildOrcMetrics(final long numOfRows, final TypeDescription orcSchema,
+                                         final ColumnStatistics[] colStats, final MetricsConfig metricsConfig,
+                                         final NameMapping mapping) {
+    final TypeDescription orcSchemaWithIds = (!ORCSchemaUtil.hasIds(orcSchema) && mapping != null) ?
+        ORCSchemaUtil.applyNameMapping(orcSchema, mapping) : orcSchema;
+    final Set<Integer> statsColumns = statsColumns(orcSchemaWithIds);
     final MetricsConfig effectiveMetricsConfig = Optional.ofNullable(metricsConfig)
         .orElseGet(MetricsConfig::getDefault);
     Map<Integer, Long> columnSizes = Maps.newHashMapWithExpectedSize(colStats.length);
     Map<Integer, Long> valueCounts = Maps.newHashMapWithExpectedSize(colStats.length);
     Map<Integer, Long> nullCounts = Maps.newHashMapWithExpectedSize(colStats.length);
+
+    if (!ORCSchemaUtil.hasIds(orcSchemaWithIds)) {
+      return new Metrics(numOfRows,
+          columnSizes,
+          valueCounts,
+          nullCounts,
+          null,
+          null);
+    }
+
+    final Schema schema = ORCSchemaUtil.convert(orcSchemaWithIds);
     Map<Integer, ByteBuffer> lowerBounds = Maps.newHashMap();
     Map<Integer, ByteBuffer> upperBounds = Maps.newHashMap();
 
     for (int i = 0; i < colStats.length; i++) {
       final ColumnStatistics colStat = colStats[i];
-      final TypeDescription orcCol = orcSchema.findSubtype(i);
+      final TypeDescription orcCol = orcSchemaWithIds.findSubtype(i);
       final Optional<Types.NestedField> icebergColOpt = ORCSchemaUtil.icebergID(orcCol)
           .map(schema::findField);
 
@@ -261,7 +279,8 @@ private static class StatsColumnsVisitor extends OrcSchemaVisitor<Set<Integer>> {
     public Set<Integer> record(TypeDescription record, List<String> names, List<Set<Integer>> fields) {
       ImmutableSet.Builder<Integer> result = ImmutableSet.builder();
       fields.stream().filter(Objects::nonNull).forEach(result::addAll);
-      record.getChildren().stream().map(ORCSchemaUtil::fieldId).forEach(result::add);
+      record.getChildren().stream().map(ORCSchemaUtil::icebergID).filter(Optional::isPresent)
+          .map(Optional::get).forEach(result::add);
       return result.build();
     }
   }
diff --git a/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
index 69e9cdca3..0a8b591e1 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
@@ -322,8 +322,7 @@ public static List<DataFile> listPartition(Map<String, String> partition, String
     } else if (format.contains("parquet")) {
       return listParquetPartition(partition, uri, spec, conf, metricsConfig, mapping);
     } else if (format.contains("orc")) {
-      // TODO: use NameMapping in listOrcPartition
-      return listOrcPartition(partition, uri, spec, conf, metricsConfig);
+      return listOrcPartition(partition, uri, spec, conf, metricsConfig, mapping);
     } else {
       throw new UnsupportedOperationException("Unknown partition format: " + format);
     }
@@ -396,7 +395,7 @@ private static List<DataFile> listParquetPartition(Map<String, String> partition
   private static List<DataFile> listOrcPartition(Map<String, String> partitionPath, String partitionUri,
                                                  PartitionSpec spec, Configuration conf,
-                                                 MetricsConfig metricsSpec) {
+                                                 MetricsConfig metricsSpec, NameMapping mapping) {
     try {
       Path partition = new Path(partitionUri);
       FileSystem fs = partition.getFileSystem(conf);
 
@@ -404,7 +403,8 @@ private static List<DataFile> listOrcPartition(Map<String, String> partitionPath
       return Arrays.stream(fs.listStatus(partition, HIDDEN_PATH_FILTER))
           .filter(FileStatus::isFile)
           .map(stat -> {
-            Metrics metrics = OrcMetrics.fromInputFile(HadoopInputFile.fromPath(stat.getPath(), conf), metricsSpec);
+            Metrics metrics = OrcMetrics.fromInputFile(HadoopInputFile.fromPath(stat.getPath(), conf),
+                metricsSpec, mapping);
             String partitionKey = spec.fields().stream()
                 .map(PartitionField::name)
                 .map(name -> String.format("%s=%s", name, partitionPath.get(name)))
diff --git a/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java b/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java
index 3c43c4250..fe948a4dd 100644
--- a/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java
+++ b/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java
@@ -22,9 +22,9 @@ import java.io.File;
 import java.io.IOException;
 import java.util.List;
-import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.iceberg.FileFormat;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.hadoop.HadoopTables;
@@ -47,32 +47,31 @@ import org.junit.After;
 import org.junit.AfterClass;
 import org.junit.Assert;
+import org.junit.Assume;
 import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
 import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import static org.apache.iceberg.TableProperties.DEFAULT_NAME_MAPPING;
 import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED;
 import static org.apache.iceberg.types.Types.NestedField.optional;
 
+@RunWith(Enclosed.class)
 public class TestSparkTableUtil extends HiveTableBaseTest {
-  private static final Configuration CONF = HiveTableBaseTest.hiveConf;
-  private static final String tableName = "hive_table";
-  private static final String dbName = HiveTableBaseTest.DB_NAME;
-  private static final String qualifiedTableName = String.format("%s.%s", dbName, tableName);
-  private static final Path tableLocationPath = HiveTableBaseTest.getTableLocationPath(tableName);
-  private static final String tableLocationStr = tableLocationPath.toString();
+  private static final String TABLE_NAME = "hive_table";
+  private static final String QUALIFIED_TABLE_NAME = String.format("%s.%s", HiveTableBaseTest.DB_NAME, TABLE_NAME);
+  private static final Path TABLE_LOCATION_PATH = HiveTableBaseTest.getTableLocationPath(TABLE_NAME);
+  private static final String TABLE_LOCATION_STR = TABLE_LOCATION_PATH.toString();
   private static SparkSession spark = null;
 
-  @Rule
-  public TemporaryFolder temp = new TemporaryFolder();
-
   @BeforeClass
   public static void startSpark() {
-    String metastoreURI = CONF.get(HiveConf.ConfVars.METASTOREURIS.varname);
+    String metastoreURI = HiveTableBaseTest.hiveConf.get(HiveConf.ConfVars.METASTOREURIS.varname);
 
     // Create a spark session.
     TestSparkTableUtil.spark = SparkSession.builder().master("local[2]")
@@ -92,9 +91,7 @@ public static void stopSpark() {
     currentSpark.stop();
   }
 
-  @Before
-  public void before() {
-
+  static void loadData(FileFormat fileFormat) {
    // Create a hive table.
    SQLContext sc = new SQLContext(TestSparkTableUtil.spark);
 
@@ -102,8 +99,9 @@
         "CREATE TABLE %s (\n" +
             "  id int COMMENT 'unique id'\n" +
             ")\n" +
-            " PARTITIONED BY (data string)\n" +
-            " LOCATION '%s'", qualifiedTableName, tableLocationStr)
+            "PARTITIONED BY (data string)\n" +
+            "STORED AS %s\n" +
+            "LOCATION '%s'", QUALIFIED_TABLE_NAME, fileFormat, TABLE_LOCATION_STR)
     );
 
     List<SimpleRecord> expected = Lists.newArrayList(
@@ -116,237 +114,292 @@
 
     df.select("id", "data").orderBy("data").write()
         .mode("append")
-        .insertInto(qualifiedTableName);
+        .insertInto(QUALIFIED_TABLE_NAME);
   }
 
-  @After
-  public void after() throws IOException {
+  static void cleanupData() throws IOException {
     // Drop the hive table.
     SQLContext sc = new SQLContext(TestSparkTableUtil.spark);
-    sc.sql(String.format("DROP TABLE IF EXISTS %s", qualifiedTableName));
+    sc.sql(String.format("DROP TABLE IF EXISTS %s", QUALIFIED_TABLE_NAME));
 
     // Delete the data corresponding to the table.
-    tableLocationPath.getFileSystem(CONF).delete(tableLocationPath, true);
-  }
-
-  @Test
-  public void testPartitionScan() {
-    List<SparkTableUtil.SparkPartition> partitions = SparkTableUtil.getPartitions(spark, qualifiedTableName);
-    Assert.assertEquals("There should be 3 partitions", 3, partitions.size());
-
-    Dataset<Row> partitionDF = SparkTableUtil.partitionDF(spark, qualifiedTableName);
-    Assert.assertEquals("There should be 3 partitions", 3, partitionDF.count());
-  }
-
-  @Test
-  public void testPartitionScanByFilter() {
-    List<SparkTableUtil.SparkPartition> partitions = SparkTableUtil.getPartitionsByFilter(spark, qualifiedTableName, "data = 'a'");
-    Assert.assertEquals("There should be 1 matching partition", 1, partitions.size());
-
-    Dataset<Row> partitionDF = SparkTableUtil.partitionDFByFilter(spark, qualifiedTableName, "data = 'a'");
-    Assert.assertEquals("There should be 1 matching partition", 1, partitionDF.count());
-  }
-
-  @Test
-  public void testImportPartitionedTable() throws Exception {
-    File location = temp.newFolder("partitioned_table");
-    spark.table(qualifiedTableName).write().mode("overwrite").partitionBy("data").format("parquet")
-        .saveAsTable("test_partitioned_table");
-    TableIdentifier source = spark.sessionState().sqlParser()
-        .parseTableIdentifier("test_partitioned_table");
-    HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
-    Table table = tables.create(SparkSchemaUtil.schemaForTable(spark, qualifiedTableName),
-        SparkSchemaUtil.specForTable(spark, qualifiedTableName),
-        ImmutableMap.of(),
-        location.getCanonicalPath());
-    File stagingDir = temp.newFolder("staging-dir");
-    SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
-    long count = spark.read().format("iceberg").load(location.toString()).count();
-    Assert.assertEquals("three values ", 3, count);
-  }
-
-  @Test
-  public void testImportUnpartitionedTable() throws Exception {
-    File location = temp.newFolder("unpartitioned_table");
-    spark.table(qualifiedTableName).write().mode("overwrite").format("parquet")
-        .saveAsTable("test_unpartitioned_table");
-    TableIdentifier source = spark.sessionState().sqlParser()
-        .parseTableIdentifier("test_unpartitioned_table");
-    HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
-    Table table = tables.create(SparkSchemaUtil.schemaForTable(spark, qualifiedTableName),
-        SparkSchemaUtil.specForTable(spark, qualifiedTableName),
-        ImmutableMap.of(),
-        location.getCanonicalPath());
-    File stagingDir = temp.newFolder("staging-dir");
-    SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
-    long count = spark.read().format("iceberg").load(location.toString()).count();
-    Assert.assertEquals("three values ", 3, count);
-  }
-
-  @Test
-  public void testImportAsHiveTable() throws Exception {
-    spark.table(qualifiedTableName).write().mode("overwrite").format("parquet")
-        .saveAsTable("unpartitioned_table");
-    TableIdentifier source = new TableIdentifier("unpartitioned_table");
-    Table table = catalog.createTable(
-        org.apache.iceberg.catalog.TableIdentifier.of(DB_NAME, "test_unpartitioned_table"),
-        SparkSchemaUtil.schemaForTable(spark, "unpartitioned_table"),
-        SparkSchemaUtil.specForTable(spark, "unpartitioned_table"));
-    File stagingDir = temp.newFolder("staging-dir");
-    SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
-    long count1 = spark.read().format("iceberg").load(DB_NAME + ".test_unpartitioned_table").count();
-    Assert.assertEquals("three values ", 3, count1);
-
-    spark.table(qualifiedTableName).write().mode("overwrite").partitionBy("data").format("parquet")
-        .saveAsTable("partitioned_table");
-    source = new TableIdentifier("partitioned_table");
-    table = catalog.createTable(
-        org.apache.iceberg.catalog.TableIdentifier.of(DB_NAME, "test_partitioned_table"),
-        SparkSchemaUtil.schemaForTable(spark, "partitioned_table"),
-        SparkSchemaUtil.specForTable(spark, "partitioned_table"));
-
-    SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
-    long count2 = spark.read().format("iceberg").load(DB_NAME + ".test_partitioned_table").count();
-    Assert.assertEquals("three values ", 3, count2);
-  }
-
-  @Test
-  public void testImportWithNameMapping() throws Exception {
-    spark.table(qualifiedTableName).write().mode("overwrite").format("parquet")
-        .saveAsTable("original_table");
-
-    // The field is different so that it will project with name mapping
-    Schema filteredSchema = new Schema(
-        optional(1, "data", Types.StringType.get())
-    );
-
-    NameMapping nameMapping = MappingUtil.create(filteredSchema);
-
-    TableIdentifier source = new TableIdentifier("original_table");
-    Table table = catalog.createTable(
-        org.apache.iceberg.catalog.TableIdentifier.of(DB_NAME, "target_table"),
-        filteredSchema,
-        SparkSchemaUtil.specForTable(spark, "original_table"));
-
-    table.updateProperties().set(DEFAULT_NAME_MAPPING, NameMappingParser.toJson(nameMapping)).commit();
-
-    File stagingDir = temp.newFolder("staging-dir");
-    SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
-
-    // The filter invoke the metric/dictionary row group filter in which it project schema
-    // with name mapping again to match the metric read from footer.
-    List<String> actual = spark.read().format("iceberg").load(DB_NAME + ".target_table")
-        .select("data")
-        .sort("data")
-        .filter("data >= 'b'")
-        .as(Encoders.STRING())
-        .collectAsList();
-
-    List<String> expected = Lists.newArrayList("b", "c");
-
-    Assert.assertEquals(expected, actual);
-  }
-
-  @Test
-  public void testImportWithNameMappingForVectorizedParquetReader() throws Exception {
-    spark.table(qualifiedTableName).write().mode("overwrite").format("parquet")
-        .saveAsTable("original_table");
-
-    // The field is different so that it will project with name mapping
-    Schema filteredSchema = new Schema(
-        optional(1, "data", Types.StringType.get())
-    );
-
-    NameMapping nameMapping = MappingUtil.create(filteredSchema);
-
-    TableIdentifier source = new TableIdentifier("original_table");
-    Table table = catalog.createTable(
-        org.apache.iceberg.catalog.TableIdentifier.of(DB_NAME, "target_table_for_vectorization"),
-        filteredSchema,
-        SparkSchemaUtil.specForTable(spark, "original_table"));
-
-    table.updateProperties()
-        .set(DEFAULT_NAME_MAPPING, NameMappingParser.toJson(nameMapping))
-        .set(PARQUET_VECTORIZATION_ENABLED, "true")
-        .commit();
-
-    File stagingDir = temp.newFolder("staging-dir");
-    SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
-
-    // The filter invoke the metric/dictionary row group filter in which it project schema
-    // with name mapping again to match the metric read from footer.
-    List<String> actual = spark.read().format("iceberg")
-        .load(DB_NAME + ".target_table_for_vectorization")
-        .select("data")
-        .sort("data")
-        .filter("data >= 'b'")
-        .as(Encoders.STRING())
-        .collectAsList();
-
-    List<String> expected = Lists.newArrayList("b", "c");
-
-    Assert.assertEquals(expected, actual);
+    TABLE_LOCATION_PATH.getFileSystem(HiveTableBaseTest.hiveConf).delete(TABLE_LOCATION_PATH, true);
   }
 
-  @Test
-  public void testImportPartitionedWithWhitespace() throws Exception {
-    String partitionCol = "dAtA sPaced";
-    String spacedTableName = "whitespacetable";
-    String whiteSpaceKey = "some key value";
-
-    List<SimpleRecord> spacedRecords = Lists.newArrayList(new SimpleRecord(1, whiteSpaceKey));
-
-    File icebergLocation = temp.newFolder("partitioned_table");
-
-    spark.createDataFrame(spacedRecords, SimpleRecord.class)
-        .withColumnRenamed("data", partitionCol)
-        .write().mode("overwrite").partitionBy(partitionCol).format("parquet")
-        .saveAsTable(spacedTableName);
-
-    TableIdentifier source = spark.sessionState().sqlParser()
-        .parseTableIdentifier(spacedTableName);
-    HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
-    Table table = tables.create(SparkSchemaUtil.schemaForTable(spark, spacedTableName),
-        SparkSchemaUtil.specForTable(spark, spacedTableName),
-        ImmutableMap.of(),
-        icebergLocation.getCanonicalPath());
-    File stagingDir = temp.newFolder("staging-dir");
-    SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
-    List<SimpleRecord> results = spark.read().format("iceberg").load(icebergLocation.toString())
-        .withColumnRenamed(partitionCol, "data")
-        .as(Encoders.bean(SimpleRecord.class))
-        .collectAsList();
-
-    Assert.assertEquals("Data should match", spacedRecords, results);
+  @RunWith(Parameterized.class)
+  public static class TableImport {
+
+    private final FileFormat format;
+
+    @Rule
+    public TemporaryFolder temp = new TemporaryFolder();
+
+    @Parameterized.Parameters
+    public static Object[][] parameters() {
+      return new Object[][] {
+          new Object[] { "parquet" },
+          new Object[] { "orc" }
+      };
+    }
+
+    public TableImport(String format) {
+      this.format = FileFormat.valueOf(format.toUpperCase());
+    }
+
+    @Before
+    public void before() {
+      loadData(format);
+    }
+
+    @After
+    public void after() throws IOException {
+      cleanupData();
+    }
+
+    @Test
+    public void testImportPartitionedTable() throws Exception {
+      File location = temp.newFolder("partitioned_table");
+      spark.table(QUALIFIED_TABLE_NAME).write().mode("overwrite").partitionBy("data").format(format.toString())
+          .saveAsTable("test_partitioned_table");
+      TableIdentifier source = spark.sessionState().sqlParser()
+          .parseTableIdentifier("test_partitioned_table");
+      HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
+      Schema tableSchema = SparkSchemaUtil.schemaForTable(spark, QUALIFIED_TABLE_NAME);
+      Table table = tables.create(tableSchema,
+          SparkSchemaUtil.specForTable(spark, QUALIFIED_TABLE_NAME),
+          ImmutableMap.of(),
+          location.getCanonicalPath());
+      File stagingDir = temp.newFolder("staging-dir");
+      SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
+      long count = spark.read().format("iceberg").load(location.toString()).count();
+      Assert.assertEquals("three values ", 3, count);
+    }
+
+    @Test
+    public void testImportUnpartitionedTable() throws Exception {
+      File location = temp.newFolder("unpartitioned_table");
+      spark.table(QUALIFIED_TABLE_NAME).write().mode("overwrite").format(format.toString())
+          .saveAsTable("test_unpartitioned_table");
+      TableIdentifier source = spark.sessionState().sqlParser()
+          .parseTableIdentifier("test_unpartitioned_table");
+      HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
+      Table table = tables.create(SparkSchemaUtil.schemaForTable(spark, QUALIFIED_TABLE_NAME),
+          SparkSchemaUtil.specForTable(spark, QUALIFIED_TABLE_NAME),
+          ImmutableMap.of(),
+          location.getCanonicalPath());
+      File stagingDir = temp.newFolder("staging-dir");
+      SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
+      long count = spark.read().format("iceberg").load(location.toString()).count();
+      Assert.assertEquals("three values ", 3, count);
+    }
+
+    @Test
+    public void testImportAsHiveTable() throws Exception {
+      spark.table(QUALIFIED_TABLE_NAME).write().mode("overwrite").format(format.toString())
+          .saveAsTable("unpartitioned_table");
+      TableIdentifier source = new TableIdentifier("unpartitioned_table");
+      org.apache.iceberg.catalog.TableIdentifier testUnpartitionedTableId =
+          org.apache.iceberg.catalog.TableIdentifier.of(DB_NAME, "test_unpartitioned_table_" + format);
+      File stagingDir = temp.newFolder("staging-dir");
+      Table table = catalog.createTable(
+          testUnpartitionedTableId,
+          SparkSchemaUtil.schemaForTable(spark, "unpartitioned_table"),
+          SparkSchemaUtil.specForTable(spark, "unpartitioned_table"));
+
+      SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
+      long count1 = spark.read().format("iceberg").load(testUnpartitionedTableId.toString()).count();
+      Assert.assertEquals("three values ", 3, count1);
+
+      spark.table(QUALIFIED_TABLE_NAME).write().mode("overwrite").partitionBy("data").format(format.toString())
+          .saveAsTable("partitioned_table");
+
+      source = new TableIdentifier("partitioned_table");
+      org.apache.iceberg.catalog.TableIdentifier testPartitionedTableId =
+          org.apache.iceberg.catalog.TableIdentifier.of(DB_NAME, "test_partitioned_table_" + format);
+      table = catalog.createTable(
+          testPartitionedTableId,
+          SparkSchemaUtil.schemaForTable(spark, "partitioned_table"),
+          SparkSchemaUtil.specForTable(spark, "partitioned_table"));
+
+      SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
+      long count2 = spark.read().format("iceberg").load(testPartitionedTableId.toString()).count();
+      Assert.assertEquals("three values ", 3, count2);
+    }
+
+    @Test
+    public void testImportWithNameMapping() throws Exception {
+      spark.table(QUALIFIED_TABLE_NAME).write().mode("overwrite").format(format.toString())
+          .saveAsTable("original_table");
+
+      // The field is different so that it will project with name mapping
+      Schema filteredSchema = new Schema(
+          optional(1, "data", Types.StringType.get())
+      );
+
+      NameMapping nameMapping = MappingUtil.create(filteredSchema);
+
+      String targetTableName = "target_table_" + format;
+      TableIdentifier source = new TableIdentifier("original_table");
+      org.apache.iceberg.catalog.TableIdentifier targetTable =
+          org.apache.iceberg.catalog.TableIdentifier.of(DB_NAME, targetTableName);
+      Table table = catalog.createTable(
+          targetTable,
+          filteredSchema,
+          SparkSchemaUtil.specForTable(spark, "original_table"));
+
+      table.updateProperties().set(DEFAULT_NAME_MAPPING, NameMappingParser.toJson(nameMapping)).commit();
+
+      File stagingDir = temp.newFolder("staging-dir");
+      SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
+
+      // The filter invoke the metric/dictionary row group filter in which it project schema
+      // with name mapping again to match the metric read from footer.
+      List<String> actual = spark.read().format("iceberg").load(targetTable.toString())
+          .select("data")
+          .sort("data")
+          .filter("data >= 'b'")
+          .as(Encoders.STRING())
+          .collectAsList();
+
+      List<String> expected = Lists.newArrayList("b", "c");
+
+      Assert.assertEquals(expected, actual);
+    }
+
+    @Test
+    public void testImportWithNameMappingForVectorizedParquetReader() throws Exception {
+      Assume.assumeTrue("Applies only to parquet format.",
+          FileFormat.PARQUET == format);
+      spark.table(QUALIFIED_TABLE_NAME).write().mode("overwrite").format(format.toString())
+          .saveAsTable("original_table");
+
+      // The field is different so that it will project with name mapping
+      Schema filteredSchema = new Schema(
+          optional(1, "data", Types.StringType.get())
+      );
+
+      NameMapping nameMapping = MappingUtil.create(filteredSchema);
+
+      TableIdentifier source = new TableIdentifier("original_table");
+      Table table = catalog.createTable(
+          org.apache.iceberg.catalog.TableIdentifier.of(DB_NAME, "target_table_for_vectorization"),
+          filteredSchema,
+          SparkSchemaUtil.specForTable(spark, "original_table"));
+
+      table.updateProperties()
+          .set(DEFAULT_NAME_MAPPING, NameMappingParser.toJson(nameMapping))
+          .set(PARQUET_VECTORIZATION_ENABLED, "true")
+          .commit();
+
+      File stagingDir = temp.newFolder("staging-dir");
+      SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
+
+      // The filter invoke the metric/dictionary row group filter in which it project schema
+      // with name mapping again to match the metric read from footer.
+      List<String> actual = spark.read().format("iceberg")
+          .load(DB_NAME + ".target_table_for_vectorization")
+          .select("data")
+          .sort("data")
+          .filter("data >= 'b'")
+          .as(Encoders.STRING())
+          .collectAsList();
+
+      List<String> expected = Lists.newArrayList("b", "c");
+
+      Assert.assertEquals(expected, actual);
+    }
+
+    @Test
+    public void testImportPartitionedWithWhitespace() throws Exception {
+      String partitionCol = "dAtA sPaced";
+      String spacedTableName = "whitespacetable";
+      String whiteSpaceKey = "some key value";
+
+      List<SimpleRecord> spacedRecords = Lists.newArrayList(new SimpleRecord(1, whiteSpaceKey));
+
+      File icebergLocation = temp.newFolder("partitioned_table");
+
+      spark.createDataFrame(spacedRecords, SimpleRecord.class)
+          .withColumnRenamed("data", partitionCol)
+          .write().mode("overwrite").partitionBy(partitionCol).format(format.toString())
+          .saveAsTable(spacedTableName);
+
+      TableIdentifier source = spark.sessionState().sqlParser()
+          .parseTableIdentifier(spacedTableName);
+      HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
+      Table table = tables.create(SparkSchemaUtil.schemaForTable(spark, spacedTableName),
+          SparkSchemaUtil.specForTable(spark, spacedTableName),
+          ImmutableMap.of(),
+          icebergLocation.getCanonicalPath());
+      File stagingDir = temp.newFolder("staging-dir");
+      SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
+      List<SimpleRecord> results = spark.read().format("iceberg").load(icebergLocation.toString())
+          .withColumnRenamed(partitionCol, "data")
+          .as(Encoders.bean(SimpleRecord.class))
+          .collectAsList();
+
+      Assert.assertEquals("Data should match", spacedRecords, results);
+    }
+
+    @Test
+    public void testImportUnpartitionedWithWhitespace() throws Exception {
+      String spacedTableName = "whitespacetable_" + format;
+      String whiteSpaceKey = "some key value";
+
+      List<SimpleRecord> spacedRecords = Lists.newArrayList(new SimpleRecord(1, whiteSpaceKey));
+
+      File whiteSpaceOldLocation = temp.newFolder("white space location");
+      File icebergLocation = temp.newFolder("partitioned_table");
+
+      spark.createDataFrame(spacedRecords, SimpleRecord.class)
+          .write().mode("overwrite").format(format.toString()).save(whiteSpaceOldLocation.getPath());
+
+      spark.catalog().createExternalTable(spacedTableName, whiteSpaceOldLocation.getPath(), format.toString());
+
+      TableIdentifier source = spark.sessionState().sqlParser()
+          .parseTableIdentifier(spacedTableName);
+      HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
+      Table table = tables.create(SparkSchemaUtil.schemaForTable(spark, spacedTableName),
+          SparkSchemaUtil.specForTable(spark, spacedTableName),
+          ImmutableMap.of(),
+          icebergLocation.getCanonicalPath());
+      File stagingDir = temp.newFolder("staging-dir");
+      SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
+      List<SimpleRecord> results = spark.read().format("iceberg").load(icebergLocation.toString())
+          .as(Encoders.bean(SimpleRecord.class)).collectAsList();
+
+      Assert.assertEquals("Data should match", spacedRecords, results);
+    }
   }
 
-  @Test
-  public void testImportUnpartitionedWithWhitespace() throws Exception {
-    String spacedTableName = "whitespacetable";
-    String whiteSpaceKey = "some key value";
+  public static class PartitionScan {
 
-    List<SimpleRecord> spacedRecords = Lists.newArrayList(new SimpleRecord(1, whiteSpaceKey));
+    @Before
+    public void before() {
+      loadData(FileFormat.PARQUET);
+    }
 
-    File whiteSpaceOldLocation = temp.newFolder("white space location");
-    File icebergLocation = temp.newFolder("partitioned_table");
+    @After
+    public void after() throws IOException {
+      cleanupData();
+    }
 
-    spark.createDataFrame(spacedRecords, SimpleRecord.class)
-        .write().mode("overwrite").parquet(whiteSpaceOldLocation.getPath());
+    @Test
+    public void testPartitionScan() {
+      List<SparkTableUtil.SparkPartition> partitions = SparkTableUtil.getPartitions(spark, QUALIFIED_TABLE_NAME);
+      Assert.assertEquals("There should be 3 partitions", 3, partitions.size());
 
-    spark.catalog().createExternalTable(spacedTableName, whiteSpaceOldLocation.getPath());
+      Dataset<Row> partitionDF = SparkTableUtil.partitionDF(spark, QUALIFIED_TABLE_NAME);
+      Assert.assertEquals("There should be 3 partitions", 3, partitionDF.count());
+    }
 
-    TableIdentifier source = spark.sessionState().sqlParser()
-        .parseTableIdentifier(spacedTableName);
-    HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
-    Table table = tables.create(SparkSchemaUtil.schemaForTable(spark, spacedTableName),
-        SparkSchemaUtil.specForTable(spark, spacedTableName),
-        ImmutableMap.of(),
-        icebergLocation.getCanonicalPath());
-    File stagingDir = temp.newFolder("staging-dir");
-    SparkTableUtil.importSparkTable(spark, source, table, stagingDir.toString());
-    List<SimpleRecord> results = spark.read().format("iceberg").load(icebergLocation.toString())
-        .as(Encoders.bean(SimpleRecord.class)).collectAsList();
+    @Test
+    public void testPartitionScanByFilter() {
+      List<SparkTableUtil.SparkPartition> partitions = SparkTableUtil.getPartitionsByFilter(spark, QUALIFIED_TABLE_NAME, "data = 'a'");
+      Assert.assertEquals("There should be 1 matching partition", 1, partitions.size());
 
-    Assert.assertEquals("Data should match", spacedRecords, results);
+      Dataset<Row> partitionDF = SparkTableUtil.partitionDFByFilter(spark, QUALIFIED_TABLE_NAME, "data = 'a'");
+      Assert.assertEquals("There should be 1 matching partition", 1, partitionDF.count());
+    }
   }
 }
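
Not part of the patch: a minimal usage sketch, in Java, of the import path this change exercises, for anyone trying it locally. It mirrors the flow in TestSparkTableUtil.testImportWithNameMapping; the source table name (db.orc_source_table), the warehouse location and the staging directory below are illustrative placeholders, not values defined by this commit.

  // Sketch only: import an ORC-backed Hive table into an Iceberg table that carries a name
  // mapping, so OrcMetrics can resolve ORC columns by name when the files have no field IDs.
  import org.apache.iceberg.PartitionSpec;
  import org.apache.iceberg.Schema;
  import org.apache.iceberg.Table;
  import org.apache.iceberg.TableProperties;
  import org.apache.iceberg.hadoop.HadoopTables;
  import org.apache.iceberg.mapping.MappingUtil;
  import org.apache.iceberg.mapping.NameMapping;
  import org.apache.iceberg.mapping.NameMappingParser;
  import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
  import org.apache.iceberg.spark.SparkSchemaUtil;
  import org.apache.iceberg.spark.SparkTableUtil;
  import org.apache.spark.sql.SparkSession;
  import org.apache.spark.sql.catalyst.TableIdentifier;

  public class OrcImportWithNameMappingExample {
    public static void main(String[] args) throws Exception {
      SparkSession spark = SparkSession.builder()
          .master("local[2]")
          .enableHiveSupport()
          .getOrCreate();

      // Source: an existing Hive table stored as ORC; placeholder name.
      String sourceName = "db.orc_source_table";
      TableIdentifier source = spark.sessionState().sqlParser().parseTableIdentifier(sourceName);

      // Target: an Iceberg table with the same schema and spec, at a placeholder location.
      Schema schema = SparkSchemaUtil.schemaForTable(spark, sourceName);
      PartitionSpec spec = SparkSchemaUtil.specForTable(spark, sourceName);
      HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
      Table target = tables.create(schema, spec, ImmutableMap.of(), "/tmp/warehouse/orc_target");

      // Attach a name mapping so metrics and reads resolve ORC columns by name.
      NameMapping mapping = MappingUtil.create(schema);
      target.updateProperties()
          .set(TableProperties.DEFAULT_NAME_MAPPING, NameMappingParser.toJson(mapping))
          .commit();

      // With this patch, listOrcPartition/OrcMetrics follow the mapping while collecting metrics.
      SparkTableUtil.importSparkTable(spark, source, target, "/tmp/staging");
    }
  }

If no mapping is set and the ORC files carry no field IDs, buildOrcMetrics now falls back to row-count-only metrics rather than deriving column metrics from a schema without IDs.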