Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion python/datafusion/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@
from datafusion._internal import DataFrame as DataFrameInternal
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
from datafusion.expr import Expr, SortExpr, sort_or_default
from datafusion.expr import Expr, SortExpr, sort_or_default, Window
from datafusion.plan import ExecutionPlan, LogicalPlan
from datafusion.record_batch import RecordBatchStream
from datafusion.functions import col, nvl, last_value
from datafusion.common import NullTreatment

if TYPE_CHECKING:
import pathlib
Expand Down Expand Up @@ -360,6 +362,9 @@ def describe(self) -> DataFrame:
"""
return DataFrame(self.df.describe())

@deprecated(
"schema() is deprecated. Use :py:meth:`~DataFrame.get_schema` instead"
)
def schema(self) -> pa.Schema:
"""Return the :py:class:`pyarrow.Schema` of this DataFrame.

Expand All @@ -370,6 +375,39 @@ def schema(self) -> pa.Schema:
Describing schema of the DataFrame
"""
return self.df.schema()

def to_batches(self) -> list[pa.RecordBatch]:
"""Convert DataFrame to list of RecordBatches."""
return self.collect() # delegate to existing method

def interpolate(self, method: str = "forward_fill", **kwargs) -> DataFrame:
"""Interpolate missing values per column.

Args:
method: Interpolation method ('linear', 'forward_fill', 'backward_fill')

Returns:
DataFrame with interpolated values

Raises:
NotImplementedError: Linear interpolation not yet supported
"""
if method == "forward_fill":
exprs = []
for field in self.schema():
window = Window(order_by=col(field.name))
expr = nvl(col(field.name),last_value(col(field.name)).over(window)).alias(field.name)
exprs.append(expr)
return self.select(*exprs)

elif method == "backward_fill":
raise NotImplementedError("backward_fill not yet implemented")

elif method == "linear":
raise NotImplementedError("Linear interpolation requires complex window function logic")

else:
raise ValueError(f"Unknown interpolation method: {method}")

@deprecated(
"select_columns() is deprecated. Use :py:meth:`~DataFrame.select` instead"
Expand Down Expand Up @@ -592,6 +630,9 @@ def tail(self, n: int = 5) -> DataFrame:
"""
return DataFrame(self.df.limit(n, max(0, self.count() - n)))

@deprecated(
"collect() returning RecordBatch list is deprecated. Use to_batches() for RecordBatch list or collect() will return DataFrame in future versions"
)
def collect(self) -> list[pa.RecordBatch]:
"""Execute this :py:class:`DataFrame` and collect results into memory.

Expand Down
41 changes: 41 additions & 0 deletions python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,47 @@ def get_header_style(self) -> str:
"padding: 10px; border: 1px solid #3367d6;"
)

def test_to_batches(df):
"""Test to_batches method returns list of RecordBatches."""
batches = df.to_batches()
assert isinstance(batches, list)
assert len(batches) > 0
assert all(isinstance(batch, pa.RecordBatch) for batch in batches)


collect_batches = df.collect()
assert len(batches) == len(collect_batches)
for i, batch in enumerate(batches):
assert batch.equals(collect_batches[i])


def test_interpolate_forward_fill(ctx):
"""Test interpolate method with forward_fill."""

batch = pa.RecordBatch.from_arrays(
[pa.array([1, None, 3, None]), pa.array([4.0, None, 6.0, None])],
names=["int_col", "float_col"],
)
df = ctx.create_dataframe([[batch]])

result = df.interpolate("forward_fill")

assert isinstance(result, DataFrame)


def test_interpolate_unsupported_method(ctx):
"""Test interpolate with unsupported method raises error."""
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3])], names=["a"]
)
df = ctx.create_dataframe([[batch]])

with pytest.raises(NotImplementedError, match="requires complex window"):
df.interpolate("linear")

with pytest.raises(ValueError, match="Unknown interpolation method"):
df.interpolate("unknown")


def count_table_rows(html_content: str) -> int:
"""Count the number of table rows in HTML content.
Expand Down
Loading