From bb2dfdd4a8aa2b2e918f72899fdd62d1540c95cf Mon Sep 17 00:00:00 2001 From: Alva Bandy Date: Fri, 21 Jun 2024 19:14:07 -0400 Subject: [PATCH] GH-42245: [Swift] Ensure map behavior is the same for all key types --- swift/Arrow/Sources/Arrow/ArrowDecoder.swift | 68 +++++++++++++------ .../Arrow/Tests/ArrowTests/CodableTests.swift | 43 +++++++----- 2 files changed, 71 insertions(+), 40 deletions(-) diff --git a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift index 518b4e9c32970..61becdc73fbf3 100644 --- a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift +++ b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift @@ -19,6 +19,7 @@ import Foundation public class ArrowDecoder: Decoder { var rbIndex: UInt = 0 + var singleRBCol: Int = 0 public var codingPath: [CodingKey] = [] public var userInfo: [CodingUserInfoKey: Any] = [:] public let rb: RecordBatch @@ -47,6 +48,25 @@ public class ArrowDecoder: Decoder { self.nameToCol = colMapping } + public func decode(_ type: [T: U].Type) throws -> [T: U] { + var output = [T: U]() + if rb.columnCount != 2 { + throw ArrowError.invalid("RecordBatch column count of 2 is required to decode to map") + } + + for index in 0..(_ type: T.Type) throws -> [T] { var output = [T]() for index in 0..() throws -> T { + let array: AnyArray = try self.getCol(self.singleRBCol) + return array.asAny(self.rbIndex) as! T // swiftlint:disable:this force_cast + } + func isNull(_ key: CodingKey) throws -> Bool { let array: AnyArray = try self.getCol(key.stringValue) return array.asAny(self.rbIndex) == nil @@ -114,6 +139,11 @@ public class ArrowDecoder: Decoder { let array: AnyArray = try self.getCol(col) return array.asAny(self.rbIndex) == nil } + + func isNullSingleValue() throws -> Bool { + let array: AnyArray = try self.getCol(self.singleRBCol) + return array.asAny(self.rbIndex) == nil + } } private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer { @@ -252,7 +282,7 @@ private struct ArrowKeyedDecoding: KeyedDecodingContainerProtoco } func decode(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable { - if type == Date.self { + if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self { return try self.decoder.doDecode(key)! } else { throw ArrowError.invalid("Type \(type) is currently not supported") @@ -290,26 +320,26 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer { func decodeNil() -> Bool { do { - return try self.decoder.isNull(0) + return try self.decoder.isNullSingleValue() } catch { return false } } func decode(_ type: Bool.Type) throws -> Bool { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue() } func decode(_ type: String.Type) throws -> String { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue() } func decode(_ type: Double.Type) throws -> Double { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue() } func decode(_ type: Float.Type) throws -> Float { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue() } func decode(_ type: Int.Type) throws -> Int { @@ -318,19 +348,19 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer { } func decode(_ type: Int8.Type) throws -> Int8 { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue() } func decode(_ type: Int16.Type) throws -> Int16 { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue() } func decode(_ type: Int32.Type) throws -> Int32 { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue() } func decode(_ type: Int64.Type) throws -> Int64 { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue()! } func decode(_ type: UInt.Type) throws -> UInt { @@ -339,30 +369,24 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer { } func decode(_ type: UInt8.Type) throws -> UInt8 { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue() } func decode(_ type: UInt16.Type) throws -> UInt16 { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue() } func decode(_ type: UInt32.Type) throws -> UInt32 { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue() } func decode(_ type: UInt64.Type) throws -> UInt64 { - return try self.decoder.doDecode(0)! + return try self.decoder.doDecodeSingleValue() } func decode(_ type: T.Type) throws -> T where T: Decodable { - if type == Int8.self || type == Int16.self || - type == Int32.self || type == Int64.self || - type == UInt8.self || type == UInt16.self || - type == UInt32.self || type == UInt64.self || - type == String.self || type == Double.self || - type == Float.self || type == Date.self || - type == Bool.self { - return try self.decoder.doDecode(0)! + if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self { + return try self.decoder.doDecodeSingleValue() } else { throw ArrowError.invalid("Type \(type) is currently not supported") } diff --git a/swift/Arrow/Tests/ArrowTests/CodableTests.swift b/swift/Arrow/Tests/ArrowTests/CodableTests.swift index 160beea17c9fa..400faa9f2907f 100644 --- a/swift/Arrow/Tests/ArrowTests/CodableTests.swift +++ b/swift/Arrow/Tests/ArrowTests/CodableTests.swift @@ -166,35 +166,45 @@ final class CodableTests: XCTestCase { } } - func testArrowUnkeyedDecoderWithoutNull() throws { + func testArrowMapDecoderWithoutNull() throws { let int8Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder() int8Builder.append(10, 11, 12, 13) - stringBuilder.append("test0", "test1", "test2", "test3") - let result = RecordBatch.Builder() + stringBuilder.append("test10", "test11", "test12", "test13") + switch RecordBatch.Builder() .addColumn("propInt8", arrowArray: try int8Builder.toHolder()) .addColumn("propString", arrowArray: try stringBuilder.toHolder()) - .finish() - switch result { + .finish() { case .success(let rb): let decoder = ArrowDecoder(rb) let testData = try decoder.decode([Int8: String].self) - var index: Int8 = 0 for data in testData { - let str = data[10 + index] - XCTAssertEqual(str, "test\(index)") - index += 1 + XCTAssertEqual("test\(data.key)", data.value) + } + case .failure(let err): + throw err + } + + switch RecordBatch.Builder() + .addColumn("propString", arrowArray: try stringBuilder.toHolder()) + .addColumn("propInt8", arrowArray: try int8Builder.toHolder()) + .finish() { + case .success(let rb): + let decoder = ArrowDecoder(rb) + let testData = try decoder.decode([String: Int8].self) + for data in testData { + XCTAssertEqual("test\(data.value)", data.key) } case .failure(let err): throw err } } - func testArrowUnkeyedDecoderWithNull() throws { + func testArrowMapDecoderWithNull() throws { let int8Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() let stringWNilBuilder = try ArrowArrayBuilders.loadStringArrayBuilder() int8Builder.append(10, 11, 12, 13) - stringWNilBuilder.append(nil, "test1", nil, "test3") + stringWNilBuilder.append(nil, "test11", nil, "test13") let resultWNil = RecordBatch.Builder() .addColumn("propInt8", arrowArray: try int8Builder.toHolder()) .addColumn("propString", arrowArray: try stringWNilBuilder.toHolder()) @@ -203,19 +213,16 @@ final class CodableTests: XCTestCase { case .success(let rb): let decoder = ArrowDecoder(rb) let testData = try decoder.decode([Int8: String?].self) - var index: Int8 = 0 for data in testData { - let str = data[10 + index] - if index % 2 == 0 { - XCTAssertNil(str!) + let str = data.value + if data.key % 2 == 0 { + XCTAssertNil(str) } else { - XCTAssertEqual(str, "test\(index)") + XCTAssertEqual(str, "test\(data.key)") } - index += 1 } case .failure(let err): throw err } - } }