..  Licensed to the Apache Software Foundation (ASF) under one
    or more contributor license agreements.  See the NOTICE file
    distributed with this work for additional information
    regarding copyright ownership.  The ASF licenses this file
    to you under the Apache License, Version 2.0 (the
    "License"); you may not use this file except in compliance
    with the License.  You may obtain a copy of the License at

..    http://www.apache.org/licenses/LICENSE-2.0

..  Unless required by applicable law or agreed to in writing,
    software distributed under the License is distributed on an
    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    KIND, either express or implied.  See the License for the
    specific language governing permissions and limitations
    under the License.

======================================================
Vectorized Python User-defined Table Functions (UDTFs)
======================================================

Spark 4.1 introduces the Vectorized Python user-defined table function (UDTF), a new type of user-defined table-valued function.
It can be used via the ``@arrow_udtf`` decorator.
Unlike scalar functions that return a single result value from each call, each UDTF is invoked in
the ``FROM`` clause of a query and returns an entire table as output.
Unlike the traditional Python UDTF, which evaluates row by row, the Vectorized Python UDTF operates directly on
Apache Arrow arrays and column batches, letting you leverage vectorized operations to improve the performance of your UDTF.

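To see what "vectorized" buys you, here is the kind of columnar work an ``eval`` method can do, shown with plain pyarrow outside of Spark so it runs standalone. The column name ``id`` and the input values are illustrative, not part of the API.

.. code-block:: python

    import pyarrow as pa
    import pyarrow.compute as pc

    # A batch shaped like what Spark hands to `eval` (illustrative data).
    batch = pa.record_batch({"id": pa.array([1, 2, 3, 4], type=pa.int64())})

    # Row-by-row, as a traditional Python UDTF effectively works:
    doubled_rows = [v.as_py() * 2 for v in batch.column("id")]

    # Vectorized: one pyarrow.compute call over the entire column.
    doubled_vec = pc.multiply(batch.column("id"), 2)

    assert doubled_vec.to_pylist() == doubled_rows == [2, 4, 6, 8]

The vectorized form avoids the per-row Python interpreter overhead, which is where the performance gain comes from.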
Vectorized Python UDTF Interface
--------------------------------

.. currentmodule:: pyspark.sql.functions

.. code-block:: python

    class NameYourArrowPythonUDTF:

        def __init__(self) -> None:
            """
            Initializes the user-defined table function (UDTF). This is optional.

            This method serves as the default constructor and is called once when the
            UDTF is instantiated on the executor side.

            Any class fields assigned in this method will be available for subsequent
            calls to the `eval`, `terminate` and `cleanup` methods.

            Notes
            -----
            - You cannot create or reference the Spark session within the UDTF. Any
              attempt to do so will result in a serialization error.
            """
            ...

        def eval(self, *args: Any) -> Iterator[pa.RecordBatch | pa.Table]:
            """
            Evaluates the function using the given input arguments.

            This method is required and must be implemented.

            Argument Mapping:
            - Each provided scalar expression maps to exactly one value in the
              `*args` list with type `pa.Array`.
            - Each provided table argument maps to a `pa.RecordBatch` object containing
              the columns in the order they appear in the provided input table,
              and with the names computed by the query analyzer.

            This method is called on every batch of input rows, and can produce zero or more
            output pyarrow record batches or pyarrow tables. Each column in an output batch
            or table corresponds to one column specified in the return type of the UDTF.

            Parameters
            ----------
            *args : Any
                Arbitrary positional arguments representing the input to the UDTF.

            Yields
            ------
            iterator
                An iterator of `pa.RecordBatch` or `pa.Table` objects representing a batch of rows
                in the UDTF result table. Yield as many times as needed to produce multiple batches.

            Notes
            -----
            - UDTFs can instead accept keyword arguments during the function call if needed.
            - The `eval` method can raise a `SkipRestOfInputTableException` to indicate that the
              UDTF wants to skip consuming all remaining rows from the current partition of the
              input table. This will cause the UDTF to proceed directly to the `terminate` method.
            - The `eval` method can raise any other exception to indicate that the UDTF should be
              aborted entirely. This will cause the UDTF to skip the `terminate` method and proceed
              directly to the `cleanup` method, and then the exception will be propagated to the
              query processor causing the invoking query to fail.

            Examples
            --------
            This `eval` method takes a table argument and yields an arrow record batch for each input batch.

            >>> def eval(self, batch: pa.RecordBatch):
            ...     yield batch

            This `eval` method takes a table argument and yields a pyarrow table for each input batch.

            >>> def eval(self, batch: pa.RecordBatch):
            ...     yield pa.table({"x": batch.column(0), "y": batch.column(1)})

            This `eval` method takes both table and scalar arguments and yields a pyarrow table for each input batch.

            >>> def eval(self, batch: pa.RecordBatch, x: pa.Array):
            ...     yield pa.table({"x": x, "y": batch.column(0)})
            """
            ...

        def terminate(self) -> Iterator[pa.RecordBatch | pa.Table]:
            """
            Called when the UDTF has successfully processed all input rows.

            This method is optional to implement and is useful for performing any
            finalization operations after the UDTF has finished processing
            all rows. It can also be used to yield additional rows if needed.
            Table functions that consume all rows in the entire input partition
            and then compute and return the entire output table can do so from
            this method as well (please be mindful of memory usage when doing
            this).

            If any exceptions occur during input row processing, this method
            won't be called.

            Yields
            ------
            iterator
                An iterator of `pa.RecordBatch` or `pa.Table` objects representing a batch of rows
                in the UDTF result table. Yield as many times as needed to produce multiple batches.

            Examples
            --------
            >>> def terminate(self) -> Iterator[pa.RecordBatch | pa.Table]:
            ...     yield pa.table({"x": pa.array([1, 2, 3])})
            """
            ...

        def cleanup(self) -> None:
            """
            Invoked after the UDTF completes processing input rows.

            This method is optional to implement and is useful for final cleanup
            regardless of whether the UDTF processed all input rows successfully
            or was aborted due to exceptions.

            Examples
            --------
            >>> def cleanup(self) -> None:
            ...     self.conn.close()
            """
            ...

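Spark drives these lifecycle methods for you on the executors, but because the class body is ordinary Python, the call order (``__init__``, then ``eval`` per batch, then ``terminate``) can be sketched by exercising a class directly with hand-built pyarrow batches. The class below is hypothetical and undecorated so it runs outside Spark; in real use you would add the ``@arrow_udtf(...)`` decorator.

.. code-block:: python

    import pyarrow as pa

    class SumPerBatchUDTF:
        """Hypothetical UDTF body; Spark would call these methods in order."""

        def __init__(self):
            self.total = 0  # state carried across eval/terminate calls

        def eval(self, batch: pa.RecordBatch):
            # One output row per input batch: that batch's column sum.
            s = sum(batch.column("x").to_pylist())
            self.total += s
            yield pa.table({"batch_sum": pa.array([s], type=pa.int64())})

        def terminate(self):
            # A final row carrying the grand total across all batches.
            yield pa.table({"batch_sum": pa.array([self.total], type=pa.int64())})

    # Simulate the executor-side call sequence on two input batches.
    udtf = SumPerBatchUDTF()
    out = []
    for b in [pa.record_batch({"x": pa.array([1, 2])}),
              pa.record_batch({"x": pa.array([3, 4])})]:
        out.extend(udtf.eval(b))
    out.extend(udtf.terminate())

    assert [t.column("batch_sum").to_pylist() for t in out] == [[3], [7], [10]]

This also shows why ``__init__`` state is useful: ``terminate`` can summarize everything ``eval`` saw.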
Defining the Output Schema
--------------------------

The return type of the UDTF defines the schema of the table it outputs.
You can specify it in the ``@arrow_udtf`` decorator.

It must be either a ``StructType``:

.. code-block:: python

    @arrow_udtf(returnType=StructType().add("c1", StringType()).add("c2", IntegerType()))
    class YourArrowPythonUDTF:
        ...

or a DDL string representing a struct type:

.. code-block:: python

    @arrow_udtf(returnType="c1 string, c2 int")
    class YourArrowPythonUDTF:
        ...

Emitting Output Rows
--------------------

The `eval` and `terminate` methods emit zero or more output batches conforming to this schema by
yielding ``pa.RecordBatch`` or ``pa.Table`` objects.

.. code-block:: python

    @arrow_udtf(returnType="c1 int, c2 int")
    class YourArrowPythonUDTF:
        def eval(self, batch: pa.RecordBatch):
            yield pa.table({"c1": batch.column(0), "c2": batch.column(1)})

You can also yield multiple pyarrow tables in the `eval` method.

.. code-block:: python

    @arrow_udtf(returnType="c1 int")
    class YourArrowPythonUDTF:
        def eval(self, batch: pa.RecordBatch):
            yield pa.table({"c1": batch.column(0)})
            yield pa.table({"c1": batch.column(1)})

You can also yield multiple pyarrow record batches in the `eval` method.

.. code-block:: python

    @arrow_udtf(returnType="c1 int")
    class YourArrowPythonUDTF:
        def eval(self, batch: pa.RecordBatch):
            yield pa.record_batch(
                {"c1": batch.column(0).slice(0, len(batch) // 2)})
            yield pa.record_batch(
                {"c1": batch.column(0).slice(len(batch) // 2)})

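When a result is built as one large table, pyarrow can split it into fixed-size record batches for you, which an ``eval`` or ``terminate`` method could yield one at a time to bound the size of each emitted batch. A minimal standalone sketch:

.. code-block:: python

    import pyarrow as pa

    # A large output table, built once...
    big = pa.table({"c1": pa.array(range(10), type=pa.int32())})

    # ...then split into smaller record batches to yield one by one.
    batches = big.to_batches(max_chunksize=4)

    assert [len(b) for b in batches] == [4, 4, 2]
    assert batches[0].column("c1").to_pylist() == [0, 1, 2, 3]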
Usage Examples
--------------

Here's how to use these UDTFs with the DataFrame API and in SQL:

.. code-block:: python

    import pyarrow as pa
    from pyspark.sql.functions import arrow_udtf

    @arrow_udtf(returnType="c1 string")
    class MyArrowPythonUDTF:
        def eval(self, batch: pa.RecordBatch):
            yield pa.table({"c1": batch.column("value")})

    df = spark.range(10).selectExpr("id", "cast(id as string) as value")
    MyArrowPythonUDTF(df.asTable()).show()

    # Register the UDTF
    spark.udtf.register("my_arrow_udtf", MyArrowPythonUDTF)

    # Use in SQL queries
    df = spark.sql("""
        SELECT * FROM my_arrow_udtf(TABLE(SELECT id, cast(id as string) as value FROM range(10)))
    """)
    df.show()
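Because the decorated class body is ordinary Python, the ``eval`` logic can also be unit-tested directly with a hand-built batch, without a Spark session. The undecorated class below mirrors the ``eval`` logic of ``MyArrowPythonUDTF`` above; its name and test data are illustrative.

.. code-block:: python

    import pyarrow as pa

    class MyArrowPythonUDTFBody:
        # Same eval logic as MyArrowPythonUDTF, without the decorator,
        # so it runs outside Spark.
        def eval(self, batch: pa.RecordBatch):
            yield pa.table({"c1": batch.column("value")})

    batch = pa.record_batch({"id": pa.array([0, 1]),
                             "value": pa.array(["0", "1"])})
    (result,) = MyArrowPythonUDTFBody().eval(batch)

    assert result.column("c1").to_pylist() == ["0", "1"]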