diff --git a/__tests__/series.test.ts b/__tests__/series.test.ts index 7e4570957..bd4571188 100644 --- a/__tests__/series.test.ts +++ b/__tests__/series.test.ts @@ -1,5 +1,5 @@ /* eslint-disable newline-per-chained-call */ -import pl from "@polars"; +import pl, { DataType } from "@polars"; import Chance from "chance"; describe("from lists", () => { @@ -867,3 +867,23 @@ describe("series struct", () => { expect(actual).toEqual(expected); }); }); +describe("generics", () => { + const series = pl.Series([1, 2, 3]); + + test("dtype", () => { + expect(series.dtype).toStrictEqual(DataType.Float64); + }); + test("to array", () => { + const arr = series.toArray(); + expect(arr).toStrictEqual([1, 2, 3]); + + const arr2 = [...series]; + expect(arr2).toStrictEqual([1, 2, 3]); + }); + test("to object", () => { + const obj = series.toObject(); + expect<{ name: string; datatype: "Float64"; values: number[] }>( + obj, + ).toMatchObject({ name: "", datatype: "Float64", values: [1, 2, 3] }); + }); +}); diff --git a/polars/datatypes/datatype.ts b/polars/datatypes/datatype.ts index 71d183174..1294c631a 100644 --- a/polars/datatypes/datatype.ts +++ b/polars/datatypes/datatype.ts @@ -2,7 +2,7 @@ import { Field } from "./field"; export abstract class DataType { get variant() { - return this.constructor.name; + return this.constructor.name as DataTypeName; } protected identity = "DataType"; protected get inner(): null | any[] { @@ -406,3 +406,73 @@ export namespace DataType { return DataType[variant](...inner); } } + +export type DataTypeName = + | "Null" + | "Bool" + | "Int8" + | "Int16" + | "Int32" + | "Int64" + | "UInt8" + | "UInt16" + | "UInt32" + | "UInt64" + | "Float32" + | "Float64" + | "Decimal" + | "Date" + | "Datetime" + | "Utf8" + | "Categorical" + | "List" + | "Struct"; + +export type JsType = number | boolean | string; +export type JsToDtype = T extends number + ? DataType.Float64 + : T extends boolean + ? DataType.Bool + : T extends string + ? DataType.Utf8 + : never; +export type DTypeToJs = T extends DataType.Decimal + ? bigint + : T extends DataType.Float64 + ? number + : T extends DataType.Int64 + ? bigint + : T extends DataType.Int32 + ? number + : T extends DataType.Bool + ? boolean + : T extends DataType.Utf8 + ? string + : never; +export type DtypeToJsName = T extends DataType.Decimal + ? "Decimal" + : T extends DataType.Float64 + ? "Float64" + : T extends DataType.Float32 + ? "Float32" + : T extends DataType.Int64 + ? "Int64" + : T extends DataType.Int32 + ? "Int32" + : T extends DataType.Int16 + ? "Int16" + : T extends DataType.Int8 + ? "Int8" + : T extends DataType.UInt64 + ? "UInt64" + : T extends DataType.UInt32 + ? "UInt32" + : T extends DataType.UInt16 + ? "UInt16" + : T extends DataType.UInt8 + ? "UInt8" + : T extends DataType.Bool + ? "Bool" + : T extends DataType.Utf8 + ? "Utf8" + : never; diff --git a/polars/lazy/functions.ts b/polars/lazy/functions.ts index b5943ce56..1380871ce 100644 --- a/polars/lazy/functions.ts +++ b/polars/lazy/functions.ts @@ -100,7 +100,7 @@ import pli from "../internals/polars_internal"; */ export function col(col: string | string[] | Series | DataType): Expr { if (Series.isSeries(col)) { - col = col.toArray(); + col = col.toArray() as string[]; } if (Array.isArray(col)) { return _Expr(pli.cols(col)); diff --git a/polars/series/index.ts b/polars/series/index.ts index 8faa18e7b..8f1eaeeb7 100644 --- a/polars/series/index.ts +++ b/polars/series/index.ts @@ -12,54 +12,59 @@ import type { Comparison, Cumulative, Deserialize, + EwmOps, Rolling, Round, Sample, Serialize, - EwmOps, } from "../shared_traits"; import { col } from "../lazy/functions"; import type { InterpolationMethod, RankMethod } from "../types"; +import type { + DTypeToJs, + DtypeToJsName, + JsToDtype, + JsType, +} from "@polars/datatypes/datatype"; const inspect = Symbol.for("nodejs.util.inspect.custom"); /** * A Series represents a single column in a polars DataFrame. */ -export interface Series - extends ArrayLike, - Rolling, - Arithmetic, - Comparison, - Cumulative, - Round, - Sample, - EwmOps, +export interface Series + extends ArrayLike, + Rolling>, + Arithmetic>, + Comparison>, + Cumulative>, + Round>, + Sample>, + EwmOps>, Serialize { inner(): any; name: string; - dtype: DataType; + dtype: T; str: StringNamespace; lst: ListNamespace; struct: SeriesStructFunctions; date: SeriesDateFunctions; [inspect](): string; - [Symbol.iterator](): IterableIterator; + [Symbol.iterator](): IterableIterator>; // inner(): JsSeries - bitand(other: Series): Series; - bitor(other: Series): Series; - bitxor(other: Series): Series; + bitand(other: Series): Series; + bitor(other: Series): Series; + bitxor(other: Series): Series; /** * Take absolute values */ - abs(): Series; + abs(): Series; /** * __Rename this Series.__ * * @param name - new name * @see {@link rename} - * */ - alias(name: string): Series; + alias(name: string): Series; /** * __Append a Series to this one.__ * ___ @@ -113,43 +118,41 @@ export interface Series argMin(): Optional; /** * Get index values where Boolean Series evaluate True. - * */ - argTrue(): Series; + argTrue(): Series; /** * Get unique index as Series. */ - argUnique(): Series; + argUnique(): Series; /** * Index location of the sorted variant of this Series. * ___ * @param reverse * @return {SeriesType} indexes - Indexes that can be used to sort this array. */ - argSort(): Series; - argSort(reverse: boolean): Series; - argSort({ reverse }: { reverse: boolean }): Series; + argSort(): Series; + argSort(reverse: boolean): Series; + argSort({ reverse }: { reverse: boolean }): Series; /** * __Rename this Series.__ * * @param name - new name * @see {@link rename} {@link alias} - * */ - as(name: string): Series; + as(name: string): Series; /** * Cast between data types. */ - cast(dtype: DataType, strict?: boolean): Series; + cast(dtype: U, strict?: boolean): Series; /** * Get the length of each individual chunk */ - chunkLengths(): Array; + chunkLengths(): Array; /** * Cheap deep clones. */ - clone(): Series; - concat(other: Series): Series; + clone(): Series; + concat(other: Series): Series; /** * __Quick summary statistics of a series. __ @@ -203,14 +206,14 @@ export interface Series * @param n - number of slots to shift * @param nullBehavior - `'ignore' | 'drop'` */ - diff(n: number, nullBehavior: "ignore" | "drop"): Series; + diff(n: number, nullBehavior: "ignore" | "drop"): Series; diff({ n, nullBehavior, }: { n: number; nullBehavior: "ignore" | "drop"; - }): Series; + }): Series; /** * Compute the dot/inner product between two Series * ___ @@ -226,7 +229,7 @@ export interface Series /** * Create a new Series that copies data from this Series without null values. */ - dropNulls(): Series; + dropNulls(): Series; /** * __Explode a list or utf8 Series.__ * @@ -301,7 +304,6 @@ export interface Series /** * __Filter elements by a boolean mask.__ * @param {SeriesType} predicate - Boolean mask - * */ filter(predicate: Series): Series; filter({ predicate }: { predicate: Series }): Series; @@ -655,7 +657,6 @@ export interface Series /** * Count the null values in this Series. -- * _`undefined` values are treated as null_ - * */ nullCount(): number; /** @@ -744,7 +745,6 @@ export interface Series * - True -> pl.Int64 * - False -> pl.UInt64 * @see {@link cast} - * */ reinterpret(signed?: boolean): Series; /** @@ -1030,7 +1030,6 @@ export interface Series * ___ * @param mask - Boolean Series * @param other - Series of same type - * */ zipWith(mask: Series, other: Series): Series; @@ -1049,7 +1048,7 @@ export interface Series * true * ``` */ - toArray(): Array; + toArray(): Array>; /** * Converts series to a javascript typedArray. * @@ -1060,9 +1059,9 @@ export interface Series /** * Get dummy/indicator variables. - * @param separator: str = "_", + * @param separator: str = "_", * @param dropFirst: bool = False - * + * * @example * const s = pl.Series("a", [1, 2, 3]) >>> s.toDummies() @@ -1088,7 +1087,7 @@ export interface Series │ 1 ┆ 0 │ │ 0 ┆ 1 │ └─────┴─────┘ - * + * */ toDummies(separator?: string, dropFirst?: boolean): DataFrame; @@ -1107,7 +1106,11 @@ export interface Series * } * ``` */ - toObject(): { name: string; datatype: string; values: any[] }; + toObject(): { + name: string; + datatype: DtypeToJsName; + values: DTypeToJs[]; + }; toFrame(): DataFrame; /** compat with `JSON.stringify */ toJSON(): string; @@ -1854,6 +1857,7 @@ export interface SeriesConstructor extends Deserialize { * ] * ``` */ + (values: ArrayLike): Series>; (values: any): Series; /** * Create a new named series @@ -1871,18 +1875,33 @@ export interface SeriesConstructor extends Deserialize { * ] * ``` */ + ( + name: string, + values: ArrayLike, + ): Series>; + ( + name: string, + values: ArrayLike>, + dtype?: T2, + ): Series; (name: string, values: any[], dtype?): Series; /** * Creates an array from an array-like object. * @param arrayLike — An array-like object to convert to an array. */ + from(arrayLike: ArrayLike): Series>; from(arrayLike: ArrayLike): Series; + from( + name: string, + arrayLike: ArrayLike, + ): Series>; from(name: string, arrayLike: ArrayLike): Series; /** * Returns a new Series from a set of elements. * @param items — A set of elements to include in the new Series object. */ + of(...items: T3[]): Series>; of(...items: T3[]): Series; isSeries(arg: any): arg is Series; /**