-- Tree module
-- By Gregory W. Schwartz

-- | Collects all functions pertaining to trees

{-# LANGUAGE BangPatterns #-}

module Math.TreeFun.Tree where

-- Built-in
import Data.List
import Data.Tree
import qualified Data.Map as M
import Data.Maybe
import qualified Data.Sequence as S
import Control.Applicative
import qualified Data.Foldable as F
import Control.Monad.State

-- Local
import Math.TreeFun.Types

-- | Convert a 'Bool' to an 'Int': 'True' maps to 1 and 'False' maps to 0.
boolToInt :: Bool -> Int
boolToInt True  = 1
boolToInt False = 0

-- | Find out whether a node is a leaf, i.e. whether its 'subForest' is
-- empty.
isLeaf :: Tree a -> Bool
isLeaf = null . subForest

-- | Return the labels of the leaves of the tree, in left-to-right order.
-- A tree consisting of a single node yields that node's label.
leaves :: Tree a -> [a]
leaves (Node { rootLabel = x, subForest = [] }) = [x]
leaves (Node { subForest = xs })                = concatMap leaves xs

-- | Return the labels of the leaves of the tree with their relative heights
-- from the root. The 'Int' argument is the height assigned to the node that
-- is passed in (it should almost always start at 0); every step down the
-- tree adds one.
--
-- The per-subtree results are merged with 'M.unions', so when a label
-- occurs on several leaves only one entry survives.
leavesHeight :: (Ord a) => Int -> Tree a -> M.Map a Int
leavesHeight !h (Node { rootLabel = x, subForest = [] }) = M.singleton x h
leavesHeight !h (Node { subForest = xs }) =
    M.unions . map (leavesHeight (h + 1)) $ xs

-- | Return the labels of the leaves of the tree with their relative heights
-- from the root (the 'Int' argument is the height given to the node passed
-- in and should almost always start at 0), paired with a group label.
--
-- The group label comes from a counter that starts at 0 and is bumped once
-- for every non-leaf node visited, right after that node's leaf children
-- have been recorded. Leaves that are direct children of the same node
-- therefore share a label, while leaves under different parents do not.
--
-- Results are merged with 'M.unions', with the sub-tree results listed
-- before the direct leaf children.
leavesCommonHeight :: (Ord a) => Int -> Tree a -> M.Map a (Int, Int)
leavesCommonHeight startHeight tree = evalState (iter startHeight tree) 0
  where
    iter :: (Ord k) => Int -> Tree k -> State Int (M.Map k (Int, Int))
    iter !h (Node { rootLabel = x, subForest = [] }) = do
        label <- get
        return $ M.singleton x (h, label)
    iter !h (Node { subForest = xs }) = do
        let (leafChildren, innerChildren) = partition isLeaf xs

        -- Leaf children all receive the current label
        ls <- mapM (iter (h + 1)) leafChildren

        -- Move to a fresh label before descending any further
        modify (+ 1)

        -- Recurse into the remaining (non-leaf) children
        ts <- mapM (iter (h + 1)) innerChildren

        -- Combine the results, sub-tree maps first
        return . M.unions $ ts ++ ls

-- | Return the labels of the leaves of the tree with a weight and a
-- distance. The weight of a leaf is the starting weight multiplied by the
-- number of children of each node on the way down to it; the distance is
-- the starting distance plus one per step down. 'Double' is used for more
-- precision.
--
-- The first argument is the starting weight and the second the starting
-- distance of the node passed in. Results are merged with 'M.unions'.
leavesParentMult :: (Ord a) => Double
                            -> Double
                            -> Tree a
                            -> M.Map a (Double, Double)
leavesParentMult !w !d (Node { rootLabel = x, subForest = [] }) =
    M.singleton x (w, d)
leavesParentMult !w !d (Node { subForest = xs }) =
    M.unions . map (leavesParentMult childWeight (d + 1)) $ xs
  where
    -- Every child inherits the weight scaled by the number of siblings
    -- (itself included)
    childWeight = w * genericLength xs

-- | Return the labels of the leaves of the tree with their weights,
-- determined by the product of the number of children of their parents all
-- the way up to the root (scaled by the starting multiplier given as the
-- first argument), paired with a group label.
--
-- The group label comes from a counter that starts at 0 and is bumped once
-- for every non-leaf node visited, right after that node's leaf children
-- have been recorded, so leaves that are direct children of the same node
-- share a label.
--
-- Results are merged with 'M.unions', with the sub-tree results listed
-- before the direct leaf children.
leavesCommonParentMult :: (Ord a) => Int -> Tree a -> M.Map a (Int, Int)
leavesCommonParentMult numChildren tree = evalState (iter numChildren tree) 0
  where
    iter :: (Ord k) => Int -> Tree k -> State Int (M.Map k (Int, Int))
    iter multChildren (Node { rootLabel = x, subForest = [] }) = do
        label <- get
        return $ M.singleton x (multChildren, label)
    iter multChildren (Node { subForest = xs }) = do
        let (leafChildren, innerChildren) = partition isLeaf xs
            -- Multiplier handed to every child of this node
            childMult = multChildren * length xs

        -- Leaf children all receive the current label
        ls <- mapM (iter childMult) leafChildren

        -- Move to a fresh label before descending any further
        modify (+ 1)

        -- Recurse into the remaining (non-leaf) children
        ts <- mapM (iter childMult) innerChildren

        -- Combine the results, sub-tree maps first
        return . M.unions $ ts ++ ls

-- | Return the labels of the leaves of the tree with their relative heights
-- from the root (the 'Int' argument is the height given to the node passed
-- in and should almost always start at 0). Unlike 'leavesHeight' this
-- version does not require 'Ord' and returns an association list, in
-- left-to-right leaf order, instead of a map.
leavesHeightList :: Int -> Tree a -> [(a, Int)]
leavesHeightList h (Node { rootLabel = x, subForest = [] }) = [(x, h)]
leavesHeightList h (Node { subForest = xs }) =
    concatMap (leavesHeightList (h + 1)) xs

-- | Return the labels of the inner (non-leaf) nodes of the tree in
-- pre-order. A tree consisting of a single node has no inner nodes.
innerNodes :: Tree a -> [a]
innerNodes (Node { subForest = [] }) = []
innerNodes (Node { rootLabel = x, subForest = xs }) =
    x : concatMap innerNodes xs

-- | Return the number of leaves in a tree, i.e. the length of 'leaves'.
numLeaves :: (Num b) => Tree a -> b
numLeaves = genericLength . leaves

-- | Return the number of inner nodes of a tree, i.e. the length of
-- 'innerNodes'.
numInner :: (Num b) => Tree a -> b
numInner = genericLength . innerNodes

-- | Return 'True' if at least one direct child of the root of the given
-- tree is a leaf.
hasRootLeaf :: Tree a -> Bool
hasRootLeaf (Node { subForest = ts }) = any isLeaf ts

-- | Return the labels of the root leaves, i.e. the direct children of the
-- root that are themselves leaves, in their original order.
getRootLeaves :: Tree a -> [a]
getRootLeaves (Node { subForest = ts }) = map rootLabel . filter isLeaf $ ts

-- | Return the list of distinct properties in a property map for a tree.
-- All the sequences stored as values of the map are concatenated and
-- duplicates are dropped with 'nub', keeping the first occurrence.
getProperties :: (Eq b) => PropertyMap a b -> [b]
getProperties = nub . F.toList . F.foldl' (S.><) S.empty . M.elems

-- | Remove leaves from a tree. Only the direct children of the given node
-- are examined: children that are leaves are dropped, deeper levels are
-- left untouched.
filterLeaves :: Tree a -> Tree a
filterLeaves tree = tree { subForest = filter (not . isLeaf) (subForest tree) }

-- | Remove the leaves attached directly to the root of the tree; all other
-- children are kept as they are.
filterRootLeaves :: Tree a -> Tree a
filterRootLeaves root@(Node { subForest = ts }) =
    root { subForest = filter (not . isLeaf) ts }

-- | Return the map of distances from each leaf to every leaf. For every
-- ordered pair of leaves @(x, y)@ the label @y@ is stored under key @x@ at
-- the distance given by 'getDistance'; a leaf is recorded at distance 0
-- from itself. Leaves at the same distance from @x@ are collected in one
-- sequence.
getDistanceMap :: (Eq a, Ord a) => Tree a -> DistanceMap a
getDistanceMap tree = M.fromListWith (M.unionWith (S.><))
                    $ pairEntry <$> ls <*> ls
  where
    ls = leaves tree
    -- One map entry for the ordered pair (x, y)
    pairEntry x y
        | x == y    = (x, M.singleton 0 (S.singleton y))
        | otherwise = ( x
                      , M.singleton (getDistance tree x y) (S.singleton y) )

-- | Find the distance between two leaves in a tree. The result is the
-- number of nodes (leaves included) whose set of leaf labels contains
-- exactly one of the two given labels. Nodes that contain both labels or
-- neither of them contribute nothing.
--
-- Note that a leaf compared with itself yields 1 here, because the leaf
-- node matches the label; 'getDistanceMap' special-cases equal labels.
getDistance :: (Eq a) => Tree a -> a -> a -> Int
getDistance (Node { rootLabel = l, subForest = [] }) x y =
    boolToInt $ l `elem` [x, y]
getDistance n@(Node { subForest = xs }) x y
    | none      = 0
    | otherwise = sum
                . (:) (boolToInt notShared)
                . map (\t -> getDistance t x y)
                $ xs
  where
    ls   = leaves n
    hasX = x `elem` ls
    hasY = y `elem` ls
    -- Only count nodes that have one or the other, not shared or empty.
    -- The disjunction must be parenthesised: '&&' binds tighter than '||',
    -- so without the parentheses a node holding both labels would count.
    notShared = (hasX || hasY) && not (hasX && hasY)
    none      = not (hasX || hasY)

-- | Return the map of distances from each leaf to every leaf, for a
-- 'SuperNode' tree. The leaves are the keys of the root's 'myLeaves' map.
-- For every ordered pair of leaves @(x, y)@ the label @y@ is stored under
-- key @x@ at the distance given by 'getDistanceSuperNode'; a leaf is
-- recorded at distance 0 from itself.
getDistanceMapSuperNode :: (Eq a, Ord a) => Tree (SuperNode a) -> DistanceMap a
getDistanceMapSuperNode tree = M.fromListWith (M.unionWith (S.><))
                             $ pairEntry <$> allLeaves <*> allLeaves
  where
    -- One map entry for the ordered pair (x, y)
    pairEntry x y
        | x == y    = (x, M.singleton 0 (S.singleton y))
        | otherwise = ( x
                      , M.singleton
                        (getDistanceSuperNode tree x y)
                        (S.singleton y) )
    allLeaves = M.keys . myLeaves . rootLabel $ tree

-- | Find the distance between two leaves in a 'SuperNode' tree.
--
-- * If the node's 'myLeaves' contains both labels, the answer is the first
--   non-zero result among its children (children holding neither label
--   return 0 and are skipped). If no child yields a non-zero result, for
--   instance when both labels are the same leaf, the result is 0.
-- * If the node contains exactly one of the labels, the answer is the sum
--   of the heights recorded for both labels in the 'myLeaves' of the
--   node's 'myParent'. Both labels must be present there, otherwise an
--   error is raised.
-- * If the node contains neither label the result is 0.
getDistanceSuperNode :: (Eq a, Ord a) => Tree (SuperNode a) -> a -> a -> Int
getDistanceSuperNode (Node { rootLabel = SuperNode { myLeaves = ls
                                                   , myParent = p }
                           , subForest = ts } ) x y
    | shared ls    = fromMaybe 0
                   . listToMaybe
                   . filter (/= 0)
                   . map (\t -> getDistanceSuperNode t x y)
                   $ ts
    | notShared ls = getParentLeafDist x p + getParentLeafDist y p
    | otherwise    = 0
  where
    -- Only count nodes that have one or the other, not shared or empty
    notShared xs = (M.member x xs || M.member y xs)
                && not (M.member x xs && M.member y xs)
    shared xs    = M.member x xs && M.member y xs
    -- Height of leaf a as recorded in SuperNode b
    getParentLeafDist a b =
        maybe (error "getDistanceSuperNode: leaf not found in parent SuperNode")
              fst
      . M.lookup a
      . myLeaves
      $ b

-- | Get the sum of all the labels of a tree with numeric labels, using a
-- strict left fold.
sumTree :: (Num a) => Tree a -> a
sumTree = F.foldl' (+) 0

-- | Convert a tree to the 'SuperNode' tree data structure (the leaves are
-- stored in the nodes). Each node of the result keeps the original label
-- in 'myRootLabel', the result of @'leavesCommonHeight' 0@ on its own
-- subtree in 'myLeaves', and the 'SuperNode' of its parent in 'myParent'.
-- The first argument is used as the parent of the root.
toSuperNodeTree :: (Ord a) => SuperNode a -> Tree a -> Tree (SuperNode a)
toSuperNodeTree p n@(Node { rootLabel = x, subForest = xs }) =
    Node { rootLabel = newNode
         , subForest = map (toSuperNodeTree newNode) xs }
  where
    -- The freshly built node doubles as the parent of all its children
    newNode = SuperNode { myRootLabel = x
                        , myLeaves    = leavesCommonHeight 0 n
                        , myParent    = p }