from collections.abc import Sequence
from functools import partial
import dask.bag as db
from dask.delayed import Delayed
from plateau.core import naming
from plateau.core.docs import default_docs
from plateau.core.factory import DatasetFactory, _ensure_factory
from plateau.core.typing import StoreInput
from plateau.core.utils import lazy_store
from plateau.core.uuid import gen_uuid
from plateau.io.dask._utils import (
_cast_categorical_to_index_cat,
_get_data,
_maybe_get_categoricals_from_index,
)
from plateau.io_components.index import update_indices_from_partitions
from plateau.io_components.metapartition import (
SINGLE_TABLE,
MetaPartition,
parse_input_to_metapartition,
)
from plateau.io_components.read import dispatch_metapartitions_from_factory
from plateau.io_components.utils import normalize_args, raise_if_indices_overlap
from plateau.io_components.write import (
raise_if_dataset_exists,
store_dataset_from_partitions,
)
# Public names exported by ``from <this module> import *``.
# Note: ``read_dataset_as_metapartitions_bag`` is defined in this module but
# is not listed here.
__all__ = (
    "read_dataset_as_dataframe_bag",
    "store_bag_as_dataset",
    "build_dataset_indices__bag",
)
def _store_dataset_from_partitions_flat(mpss, *args, **kwargs):
    """Flatten nested metapartition lists and pass them to
    ``store_dataset_from_partitions``.

    ``mpss`` is an iterable of iterables of metapartitions. In this module it
    is the list of per-partition lists produced by
    ``Bag.reduction(perpartition=list, ...)``. All remaining positional and
    keyword arguments are forwarded unchanged.
    """
    flattened = []
    for group in mpss:
        flattened.extend(group)
    return store_dataset_from_partitions(flattened, *args, **kwargs)
def _load_and_concat_metapartitions_inner(mps, *args, **kwargs):
    """Load every metapartition in ``mps`` and merge them into one.

    Each element's ``load_dataframes`` is called with the forwarded
    ``*args``/``**kwargs``. The loaded metapartitions are then combined with
    ``MetaPartition.concat_metapartitions``.
    """
    loaded = [
        metapartition.load_dataframes(*args, **kwargs) for metapartition in mps
    ]
    return MetaPartition.concat_metapartitions(loaded)
@default_docs
def read_dataset_as_metapartitions_bag(
    dataset_uuid=None,
    store=None,
    columns=None,
    predicate_pushdown_to_io=True,
    categoricals=None,
    dates_as_object: bool = True,
    predicates=None,
    factory=None,
    dispatch_by=None,
    partition_size=None,
):
    """Retrieve dataset as `dask.bag.Bag` of `MetaPartition` objects.

    Metapartitions are dispatched from the dataset factory, honouring
    ``predicates`` and ``dispatch_by``, and their dataframes are loaded lazily
    inside the bag. When ``dispatch_by`` is given, each dispatched group of
    metapartitions is loaded and concatenated into a single metapartition.

    Parameters
    ----------

    Returns
    -------
    dask.bag.Bag:
        A dask.bag object containing the metapartitions.
    """
    ds_factory = _ensure_factory(
        dataset_uuid=dataset_uuid,
        store=store,
        factory=factory,
    )
    # Loading always goes through the factory's store, whichever of
    # ``store``/``factory`` the caller supplied.
    store_factory = ds_factory.store_factory

    dispatched = dispatch_metapartitions_from_factory(
        dataset_factory=ds_factory,
        predicates=predicates,
        dispatch_by=dispatch_by,
    )

    # With ``dispatch_by`` every dispatched element is a group of
    # metapartitions that must be loaded and concatenated. Otherwise each
    # element is a single metapartition that can be loaded directly.
    if dispatch_by is None:
        loader = MetaPartition.load_dataframes
    else:
        loader = _load_and_concat_metapartitions_inner

    mp_bag = db.from_sequence(dispatched, partition_size=partition_size).map(
        loader,
        store=store_factory,
        columns=columns,
        categoricals=categoricals,
        predicate_pushdown_to_io=predicate_pushdown_to_io,
        dates_as_object=dates_as_object,
        predicates=predicates,
    )

    index_categories = _maybe_get_categoricals_from_index(ds_factory, categoricals)
    if not index_categories:
        return mp_bag

    # Align categorical columns with the categories reported from the index.
    return mp_bag.map(
        MetaPartition.apply,
        func=partial(_cast_categorical_to_index_cat, categories=index_categories),
        type_safe=True,
    )
[docs]
@default_docs
def read_dataset_as_dataframe_bag(
    dataset_uuid=None,
    store=None,
    columns=None,
    predicate_pushdown_to_io=True,
    categoricals=None,
    dates_as_object: bool = True,
    predicates=None,
    factory=None,
    dispatch_by=None,
    partition_size=None,
):
    """Retrieve data as dataframes from a :class:`dask.bag.Bag` of
    `MetaPartition` objects.

    All arguments are forwarded unchanged to
    ``read_dataset_as_metapartitions_bag``.

    Parameters
    ----------

    Returns
    -------
    dask.bag.Bag
        A dask.bag.Bag whose elements are the result of applying ``_get_data``
        to each loaded metapartition.
    """
    metapartition_bag = read_dataset_as_metapartitions_bag(
        dataset_uuid=dataset_uuid,
        store=store,
        columns=columns,
        predicate_pushdown_to_io=predicate_pushdown_to_io,
        categoricals=categoricals,
        dates_as_object=dates_as_object,
        predicates=predicates,
        factory=factory,
        dispatch_by=dispatch_by,
        partition_size=partition_size,
    )
    # Replace every metapartition with the data extracted from it.
    return metapartition_bag.map(_get_data)
[docs]
@default_docs
@normalize_args
def store_bag_as_dataset(
    bag,
    store,
    dataset_uuid=None,
    metadata=None,
    df_serializer=None,
    overwrite=False,
    metadata_merger=None,
    metadata_version=naming.DEFAULT_METADATA_VERSION,
    partition_on=None,
    metadata_storage_format=naming.DEFAULT_METADATA_STORAGE_FORMAT,
    secondary_indices=None,
    table_name: str = SINGLE_TABLE,
):
    """Transform and store a dask.bag of dictionaries containing dataframes to
    a plateau dataset in store.

    This is the dask.bag-equivalent of
    :func:`~plateau.io.dask.delayed.store_delayed_as_dataset`. See there
    for more detailed documentation on the different possible input types.

    Each bag element is parsed into a metapartition. It is then optionally
    partitioned and indexed, and its dataframes are written to the store. A
    final reduction writes the dataset metadata from all stored partitions.

    Parameters
    ----------
    bag: dask.bag.Bag
        A dask bag containing dictionaries of dataframes or dataframes.

    Returns
    -------
    A lazy dask object created by ``Bag.reduction``. Computing it runs
    ``store_dataset_from_partitions`` on all stored metapartitions.
    """
    store_factory = lazy_store(store)
    ds_uuid = gen_uuid() if dataset_uuid is None else dataset_uuid

    # Validate eagerly, before any dask graph is built.
    if not overwrite:
        raise_if_dataset_exists(dataset_uuid=ds_uuid, store=store_factory)
    raise_if_indices_overlap(partition_on, secondary_indices)

    mps = bag.map(
        partial(
            parse_input_to_metapartition,
            metadata_version=metadata_version,
            table_name=table_name,
        )
    )
    if partition_on:
        mps = mps.map(MetaPartition.partition_on, partition_on=partition_on)
    if secondary_indices:
        mps = mps.map(MetaPartition.build_indices, columns=secondary_indices)

    stored = mps.map(
        MetaPartition.store_dataframes,
        store=store_factory,
        df_serializer=df_serializer,
        dataset_uuid=ds_uuid,
    )

    # Every partition yields a list of metapartitions. The single aggregate
    # step (split_every=False) flattens those lists and writes the metadata.
    write_metadata = partial(
        _store_dataset_from_partitions_flat,
        dataset_uuid=ds_uuid,
        store=store_factory,
        dataset_metadata=metadata,
        metadata_merger=metadata_merger,
        metadata_storage_format=metadata_storage_format,
    )
    return stored.reduction(
        perpartition=list, aggregate=write_metadata, split_every=False
    )
[docs]
@default_docs
def build_dataset_indices__bag(
    store: StoreInput | None,
    dataset_uuid: str | None,
    columns: Sequence[str],
    partition_size: int | None = None,
    factory: DatasetFactory | None = None,
) -> db.Bag:
    """Function which builds a
    :class:`~plateau.core.index.ExplicitSecondaryIndex`.

    This function loads the dataset, computes the requested indices and writes
    the indices to the dataset. The dataset partitions themselves are not
    mutated.

    Parameters
    ----------

    Returns
    -------
    dask.bag.Bag
        A lazy, single-partition bag. Nothing is loaded or written until it is
        computed. Computing it calls ``update_indices_from_partitions`` with
        the list of all index-carrying metapartitions.
    """
    ds_factory = _ensure_factory(
        dataset_uuid=dataset_uuid,
        store=store,
        factory=factory,
    )

    # NOTE(review): presumably also narrows ``schema`` from Optional for type
    # checkers. Kept as an assert so the raised exception type is unchanged.
    assert ds_factory.schema is not None
    # Only load the requested columns that the dataset schema knows about.
    cols_to_load = set(columns) & set(ds_factory.schema.names)

    mps = dispatch_metapartitions_from_factory(ds_factory)

    return (
        db.from_sequence(seq=mps, partition_size=partition_size)
        .map(
            MetaPartition.load_dataframes,
            store=ds_factory.store_factory,
            columns=cols_to_load,
        )
        .map(MetaPartition.build_indices, columns=columns)
        # The indices are built now. Drop the loaded dataframes before the
        # metapartitions are gathered in the reduction below.
        .map(MetaPartition.remove_dataframes)
        # Gather every partition into one (split_every=False) and keep the
        # result as a Bag rather than an Item.
        .reduction(list, list, split_every=False, out_type=db.Bag)
        # Un-nest the per-partition lists back into individual metapartitions.
        .flatten()
        # Materialize the single partition as a list for the index update.
        .map_partitions(list)
        .map_partitions(
            update_indices_from_partitions, dataset_metadata_factory=ds_factory
        )
    )