Skip to content

Commit d8b9f59

Browse files
committed
Slight improvement to the performance and readability of groupBy
1 parent 533f561 commit d8b9f59

1 file changed

Lines changed: 18 additions & 55 deletions

File tree

src/DataFrame/Operations/Aggregation.hs

Lines changed: 18 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import qualified Data.Vector.Generic as VG
1717
import qualified Data.Vector as V
1818
import qualified Data.Vector.Mutable as VM
1919
import qualified Data.Vector.Unboxed as VU
20+
import qualified Data.Vector.Algorithms.Merge as VA
2021
import qualified Statistics.Quantile as SS
2122
import qualified Statistics.Sample as SS
2223

@@ -32,6 +33,7 @@ import DataFrame.Operations.Core
3233
import DataFrame.Operations.Subset
3334
import Data.Function ((&))
3435
import Data.Hashable
36+
import qualified Data.HashTable.ST.Basic as H
3537
import Data.Maybe
3638
import Data.Type.Equality (type (:~:)(Refl), TestEquality(..))
3739
import Type.Reflection (typeRep, typeOf)
@@ -46,76 +48,37 @@ groupBy names df
4648
| any (`notElem` columnNames df) names = throw $ ColumnNotFoundException (T.pack $ show $ names L.\\ columnNames df) "groupBy" (columnNames df)
4749
| otherwise = L.foldl' insertColumns initDf groupingColumns
4850
where
49-
insertOrAdjust k v m = if MS.notMember k m then MS.insert k [v] m else MS.adjust (appendWithFrontMin v) k m
50-
-- Create a string representation of each row.
51-
values = V.generate (fst (dimensions df)) (mkRowRep df (S.fromList names))
52-
-- Create a mapping from the row representation to the list of indices that
53-
-- have that row representation. This will allow us sortedIndexesto combine the indexes
54-
-- where the rows are the same.
55-
valueIndices = V.ifoldl' (\m index rowRep -> insertOrAdjust rowRep index m) M.empty values
56-
-- Since the min is at the head this allows us to get the min in constant time and sort by it
57-
-- That way we can recover the original order of the rows.
58-
-- valueIndicesInitOrder = L.sortBy (compare `on` snd) $! MS.toList $ MS.map VU.head valueIndices
59-
valueIndicesInitOrder = runST $ do
60-
v <- VM.new (MS.size valueIndices)
61-
foldM_ (\i idxs -> VM.write v i (VU.fromList idxs) >> return (i + 1)) 0 valueIndices
62-
V.unsafeFreeze v
51+
indicesToGroup = M.elems $ M.filterWithKey (\k _ -> k `elem` names) (columnIndices df)
52+
rowRepresentations = VU.generate (fst (dimensions df)) (mkRowRep indicesToGroup df)
53+
54+
valueIndices = V.fromList $ map (VG.map fst) $ VG.groupBy (\a b -> snd a == snd b) (runST $ do
55+
withIndexes <- VG.thaw $ VG.indexed rowRepresentations
56+
VA.sortBy (\(a, b) (a', b') -> compare b b') withIndexes
57+
VG.unsafeFreeze withIndexes)
6358

6459
-- These are the indexes of the grouping/key rows i.e the minimum elements
6560
-- of the list.
66-
keyIndices = VU.generate (VG.length valueIndicesInitOrder) (\i -> VG.head $ valueIndicesInitOrder VG.! i)
61+
keyIndices = VU.generate (VG.length valueIndices) (\i -> VG.minimum $ valueIndices VG.! i)
6762
-- this will be our main worker function in the fold that takes all
6863
-- indices and replaces each value in a column with a list of
6964
-- the elements with the indices where the grouped row
7065
-- values are the same.
71-
insertColumns = groupColumns valueIndicesInitOrder df
66+
insertColumns = groupColumns valueIndices df
7267
-- Out initial DF will just be all the grouped rows added to an
7368
-- empty dataframe. The entries are dedued and are in their
7469
-- initial order.
7570
initDf = L.foldl' (mkGroupedColumns keyIndices df) empty names
7671
-- All the rest of the columns that we are grouping by.
7772
groupingColumns = columnNames df L.\\ names
7873

79-
mkRowRep :: DataFrame -> S.Set T.Text -> Int -> Int
80-
mkRowRep df names i = hash $ V.ifoldl' go [] (columns df)
74+
mkRowRep :: [Int] -> DataFrame -> Int -> Int
75+
mkRowRep groupColumnIndices df i = hash (map mkHash groupColumnIndices)
8176
where
82-
indexMap = M.fromList (map (\(a, b) -> (b, a)) $ M.toList (columnIndices df))
83-
go acc k (BoxedColumn (c :: V.Vector a)) =
84-
if S.notMember (indexMap M.! k) names
85-
then acc
86-
else case c V.!? i of
87-
Just e -> hash' @a e : acc
88-
Nothing ->
89-
error $
90-
"Column "
91-
++ T.unpack (indexMap M.! k)
92-
++ " has less items than "
93-
++ "the other columns at index "
94-
++ show i
95-
go acc k (OptionalColumn (c :: V.Vector (Maybe a))) =
96-
if S.notMember (indexMap M.! k) names
97-
then acc
98-
else case c V.!? i of
99-
Just e -> hash' @(Maybe a) e : acc
100-
Nothing ->
101-
error $
102-
"Column "
103-
++ T.unpack (indexMap M.! k)
104-
++ " has less items than "
105-
++ "the other columns at index "
106-
++ show i
107-
go acc k (UnboxedColumn (c :: VU.Vector a)) =
108-
if S.notMember (indexMap M.! k) names
109-
then acc
110-
else case c VU.!? i of
111-
Just e -> hash' @a e : acc
112-
Nothing ->
113-
error $
114-
"Column "
115-
++ T.unpack (indexMap M.! k)
116-
++ " has less items than "
117-
++ "the other columns at index "
118-
++ show i
77+
getHashedElem (BoxedColumn (c :: V.Vector a)) j = hash' @a (c V.! j)
78+
getHashedElem (UnboxedColumn (c :: VU.Vector a)) j = hash' @a (c VU.! j)
79+
getHashedElem (OptionalColumn (c :: V.Vector a)) j = hash' @a (c V.! j)
80+
getHashedElem _ _ = 0
81+
mkHash j = getHashedElem ((V.!) (columns df) j) i
11982

12083
-- | This hash function returns the hash when given a non numeric type but
12184
-- the value when given a numeric.

0 commit comments

Comments
 (0)