Current File : //usr/local/lib64/python3.6/site-packages/pandas/core/window/common.py |
"""Common utility functions for rolling operations"""
from collections import defaultdict
from typing import Callable, Optional
import warnings
import numpy as np
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
from pandas.core.generic import _shared_docs
from pandas.core.groupby.base import GroupByMixin
from pandas.core.indexes.api import MultiIndex
_shared_docs = dict(**_shared_docs)
_doc_template = """
Returns
-------
Series or DataFrame
Return type is determined by the caller.
See Also
--------
pandas.Series.%(name)s : Calling object with Series data.
pandas.DataFrame.%(name)s : Calling object with DataFrame data.
pandas.Series.%(func_name)s : Similar method for Series.
pandas.DataFrame.%(func_name)s : Similar method for DataFrame.
"""
def _dispatch(name: str, *args, **kwargs):
"""
Dispatch to apply.
"""
def outer(self, *args, **kwargs):
def f(x):
x = self._shallow_copy(x, groupby=self._groupby)
return getattr(x, name)(*args, **kwargs)
return self._groupby.apply(f)
outer.__name__ = name
return outer
class WindowGroupByMixin(GroupByMixin):
"""
Provide the groupby facilities.
"""
def __init__(self, obj, *args, **kwargs):
kwargs.pop("parent", None)
groupby = kwargs.pop("groupby", None)
if groupby is None:
groupby, obj = obj, obj._selected_obj
self._groupby = groupby
self._groupby.mutated = True
self._groupby.grouper.mutated = True
super().__init__(obj, *args, **kwargs)
count = _dispatch("count")
corr = _dispatch("corr", other=None, pairwise=None)
cov = _dispatch("cov", other=None, pairwise=None)
def _apply(
self,
func: Callable,
center: bool,
require_min_periods: int = 0,
floor: int = 1,
is_weighted: bool = False,
name: Optional[str] = None,
use_numba_cache: bool = False,
**kwargs,
):
"""
Dispatch to apply; we are stripping all of the _apply kwargs and
performing the original function call on the grouped object.
"""
kwargs.pop("floor", None)
kwargs.pop("original_func", None)
# TODO: can we de-duplicate with _dispatch?
def f(x, name=name, *args):
x = self._shallow_copy(x)
if isinstance(name, str):
return getattr(x, name)(*args, **kwargs)
return x.apply(name, *args, **kwargs)
return self._groupby.apply(f)
def _flex_binary_moment(arg1, arg2, f, pairwise=False):
if not (
isinstance(arg1, (np.ndarray, ABCSeries, ABCDataFrame))
and isinstance(arg2, (np.ndarray, ABCSeries, ABCDataFrame))
):
raise TypeError(
"arguments to moment function must be of type np.ndarray/Series/DataFrame"
)
if isinstance(arg1, (np.ndarray, ABCSeries)) and isinstance(
arg2, (np.ndarray, ABCSeries)
):
X, Y = prep_binary(arg1, arg2)
return f(X, Y)
elif isinstance(arg1, ABCDataFrame):
from pandas import DataFrame
def dataframe_from_int_dict(data, frame_template):
result = DataFrame(data, index=frame_template.index)
if len(result.columns) > 0:
result.columns = frame_template.columns[result.columns]
return result
results = {}
if isinstance(arg2, ABCDataFrame):
if pairwise is False:
if arg1 is arg2:
# special case in order to handle duplicate column names
for i, col in enumerate(arg1.columns):
results[i] = f(arg1.iloc[:, i], arg2.iloc[:, i])
return dataframe_from_int_dict(results, arg1)
else:
if not arg1.columns.is_unique:
raise ValueError("'arg1' columns are not unique")
if not arg2.columns.is_unique:
raise ValueError("'arg2' columns are not unique")
with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore", RuntimeWarning)
X, Y = arg1.align(arg2, join="outer")
X = X + 0 * Y
Y = Y + 0 * X
with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore", RuntimeWarning)
res_columns = arg1.columns.union(arg2.columns)
for col in res_columns:
if col in X and col in Y:
results[col] = f(X[col], Y[col])
return DataFrame(results, index=X.index, columns=res_columns)
elif pairwise is True:
results = defaultdict(dict)
for i, k1 in enumerate(arg1.columns):
for j, k2 in enumerate(arg2.columns):
if j < i and arg2 is arg1:
# Symmetric case
results[i][j] = results[j][i]
else:
results[i][j] = f(
*prep_binary(arg1.iloc[:, i], arg2.iloc[:, j])
)
from pandas import concat
result_index = arg1.index.union(arg2.index)
if len(result_index):
# construct result frame
result = concat(
[
concat(
[results[i][j] for j, c in enumerate(arg2.columns)],
ignore_index=True,
)
for i, c in enumerate(arg1.columns)
],
ignore_index=True,
axis=1,
)
result.columns = arg1.columns
# set the index and reorder
if arg2.columns.nlevels > 1:
result.index = MultiIndex.from_product(
arg2.columns.levels + [result_index]
)
# GH 34440
num_levels = len(result.index.levels)
new_order = [num_levels - 1] + list(range(num_levels - 1))
result = result.reorder_levels(new_order).sort_index()
else:
result.index = MultiIndex.from_product(
[range(len(arg2.columns)), range(len(result_index))]
)
result = result.swaplevel(1, 0).sort_index()
result.index = MultiIndex.from_product(
[result_index] + [arg2.columns]
)
else:
# empty result
result = DataFrame(
index=MultiIndex(
levels=[arg1.index, arg2.columns], codes=[[], []]
),
columns=arg2.columns,
dtype="float64",
)
# reset our index names to arg1 names
# reset our column names to arg2 names
# careful not to mutate the original names
result.columns = result.columns.set_names(arg1.columns.names)
result.index = result.index.set_names(
result_index.names + arg2.columns.names
)
return result
else:
raise ValueError("'pairwise' is not True/False")
else:
results = {
i: f(*prep_binary(arg1.iloc[:, i], arg2))
for i, col in enumerate(arg1.columns)
}
return dataframe_from_int_dict(results, arg1)
else:
return _flex_binary_moment(arg2, arg1, f)
def zsqrt(x):
with np.errstate(all="ignore"):
result = np.sqrt(x)
mask = x < 0
if isinstance(x, ABCDataFrame):
if mask._values.any():
result[mask] = 0
else:
if mask.any():
result[mask] = 0
return result
def prep_binary(arg1, arg2):
if not isinstance(arg2, type(arg1)):
raise Exception("Input arrays must be of the same type!")
# mask out values, this also makes a common index...
X = arg1 + 0 * arg2
Y = arg2 + 0 * arg1
return X, Y