Skip to content

Commit 4218a15

Browse files
committed
Major improvement to the performance of filter.
Rather than building a set (which is inefficient in both memory and compute), use findIndices to identify all the indexes that satisfy the predicate, then select those indexes in the filter.
1 parent 80ebd0a commit 4218a15

4 files changed

Lines changed: 31 additions & 5 deletions

File tree

benchmark/Main.hs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ import qualified Data.Vector.Unboxed as VU
66

77
import Control.Monad (replicateM)
88
import Criterion.Main
9+
import Data.Time
910
import System.Random (randomRIO)
1011

1112
stats :: Int -> IO ()
1213
stats n = do
14+
startTime <- getCurrentTime
1315
ns <- do
1416
ns' <- VU.replicateM n (randomRIO (-20.0 :: Double, 20.0))
1517
pure $ replicate 3 ns'
@@ -19,6 +21,9 @@ stats n = do
1921
print $ D.variance "1" df
2022
print $ D.correlation "1" "2" df
2123
print $ D.filter "0" (>= (19.9 :: Double)) df D.|> D.take 10
24+
endTime <- getCurrentTime
25+
let diff = diffUTCTime endTime startTime
26+
putStrLn $ "Execution Time: " ++ (show diff)
2227

2328
main = defaultMain [
2429
bgroup "stats" [ bench "300_000" $ nfIO (stats 100_000)

dataframe.cabal

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ source-repository head
2222
location: https://github.com/mchav/dataframe
2323

2424
library
25-
default-extensions: StrictData
25+
-- default-extensions: StrictData
2626
exposed-modules: DataFrame,
2727
DataFrame.Lazy
2828
other-modules: DataFrame.Internal.Types,
@@ -113,6 +113,7 @@ benchmark dataframe-benchmark
113113
build-depends: base >= 4.17.2.0 && < 4.22,
114114
criterion >= 1 && <= 1.6.4.0,
115115
text >= 2.0 && <= 2.1.2,
116+
time >= 1.12,
116117
random >= 1 && <= 1.3.1,
117118
vector ^>= 0.13,
118119
dataframe

src/DataFrame/Internal/Column.hs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,26 @@ getIndicesUnboxed :: (VU.Unbox a) => VU.Vector Int -> VU.Vector a -> VU.Vector a
378378
getIndicesUnboxed indices xs = VU.generate (VU.length indices) (\i -> xs VU.! (indices VU.! i))
379379
{-# INLINE getIndicesUnboxed #-}
380380

381+
findIndices :: forall a. (Columnable a)
382+
=> (a -> Bool)
383+
-> Column
384+
-> Maybe (VU.Vector Int)
385+
findIndices pred (BoxedColumn (column :: VB.Vector b)) = do
386+
Refl <- testEquality (typeRep @a) (typeRep @b)
387+
pure $ VG.convert (VG.findIndices pred column)
388+
findIndices pred (UnboxedColumn (column :: VU.Vector b)) = do
389+
Refl <- testEquality (typeRep @a) (typeRep @b)
390+
pure $ VG.findIndices pred column
391+
findIndices pred (OptionalColumn (column :: VB.Vector (Maybe b))) = do
392+
Refl <- testEquality (typeRep @a) (typeRep @(Maybe b))
393+
pure $ VG.convert (VG.findIndices pred column)
394+
findIndices pred (GroupedBoxedColumn (column :: VB.Vector b)) = do
395+
Refl <- testEquality (typeRep @a) (typeRep @b)
396+
pure $ VG.convert (VG.findIndices pred column)
397+
findIndices pred (GroupedUnboxedColumn (column :: VB.Vector b)) = do
398+
Refl <- testEquality (typeRep @a) (typeRep @b)
399+
pure $ VG.convert (VG.findIndices pred column)
400+
381401
-- | An internal function that returns a vector of how indexes change after a column is sorted.
382402
sortedIndexes :: Bool -> Column -> VU.Vector Int
383403
sortedIndexes asc (BoxedColumn column ) = runST $ do

src/DataFrame/Operations/Subset.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,16 @@ filter ::
7878
DataFrame
7979
filter filterColumnName condition df = case getColumn filterColumnName df of
8080
Nothing -> throw $ ColumnNotFoundException filterColumnName "filter" (map fst $ M.toList $ columnIndices df)
81-
Just column -> case ifoldlColumn (\s i v -> if condition v then S.insert i s else s) S.empty column of
81+
Just column -> case findIndices condition column of
8282
Nothing -> throw $ TypeMismatchException (MkTypeErrorContext
8383
{ userType = Right $ typeRep @a
8484
, expectedType = Left (columnTypeString column) :: Either String (TypeRep ())
8585
, errorColumnName = Just (T.unpack filterColumnName)
8686
, callingFunctionName = Just "filter"})
8787
Just indexes -> let
8888
c' = snd $ dataframeDimensions df
89-
pick idxs col = atIndices idxs col
90-
in df {columns = V.map (pick indexes) (columns df), dataframeDimensions = (S.size indexes, c')}
89+
pick idxs col = atIndicesStable idxs col
90+
in df {columns = V.map (pick indexes) (columns df), dataframeDimensions = (VG.length indexes, c')}
9191

9292
-- | O(k) a version of filter where the predicate comes first.
9393
--
@@ -102,7 +102,7 @@ filterBy = flip filter
102102
filterWhere :: Expr Bool -> DataFrame -> DataFrame
103103
filterWhere expr df = let
104104
(TColumn col) = interpret @Bool df expr
105-
(Just indexes) = VU.convert . V.map (fromMaybe 0) . V.filter isJust . toVector @(Maybe Int) <$> imapColumn (\i satisfied -> if satisfied then Just i else Nothing) col
105+
(Just indexes) = findIndices (==True) col
106106
c' = snd $ dataframeDimensions df
107107
pick idxs col = atIndicesStable idxs col
108108
in df {columns = V.map (pick indexes) (columns df), dataframeDimensions = (VU.length indexes, c')}

0 commit comments

Comments
 (0)