-- |
-- Module      : Tile.Execution.Concurrent
-- Description : Actor-style concurrent interpreter for schedules.
--
-- These functions interpret divergence schedules using lightweight
-- Haskell concurrency through channels and forked threads. A divergence
-- schedule has edges directed from root toward leaves; it is the form
-- produced by 'Tile.Routing.buildSchedule'. Collectives that require
-- leaf-to-root message flow (reduce, gather) derive the convergence
-- schedule internally.
--
-- The @run*@ forms return the observed result without tracing; the
-- @run*WithTrace@ forms also report structured 'Trace' events. Their
-- correctness contract is stated by the pure functions in
-- "Tile.Execution".
--
-- Precondition shared by all functions: the schedule must be rooted at
-- the supplied @root@. Passing a disconnected schedule may leave worker
-- threads waiting for messages that never arrive.
module Tile.Execution.Concurrent
  ( Trace (..),
    runBroadcast,
    runBroadcastWithTrace,
    runGather,
    runGatherWithTrace,
    runReduce,
    runReduceWithTrace,
    runScatter,
    runScatterWithTrace,
    runAllReduce,
    runAllReduceWithTrace,
  )
where

import Control.Concurrent
import Control.Monad
import Data.Map.Strict qualified as Map
import Data.Set qualified as Set
import Tile.Schedule
import Tile.Tree (RoutedTree (..), scheduleTree, treeIndex, treeLabels)

-- | An observed event in the concurrent schedule interpreter.
--
-- In every constructor the first field is the member at which the event
-- was observed.
data Trace m msg
  = -- | The member obtained an incoming message (or value).
    Received m msg
  | -- | The first member sent the message to the second member.
    Sent m m msg
  | -- | The member finished its part of the collective; the payload is
    -- the value it finished with. In reduce and gather only the member
    -- without a parent (the root) emits this event.
    Completed m msg
  deriving (Eq, Show)

-- | Run a broadcast divergence schedule without tracing.
--
-- Same as 'runBroadcastWithTrace' with a tracer that discards every
-- event.
runBroadcast ::
  (Ord m) =>
  Schedule m ->
  m ->
  p ->
  IO (Map.Map m p)
runBroadcast schedule root payload =
  runBroadcastWithTrace (\_ -> pure ()) schedule root payload

-- | Run a broadcast divergence schedule, reporting each observed action.
--
-- Every member (the root plus every endpoint of a schedule edge) gets
-- its own inbox and worker thread. The caller's @payload@ is written to
-- the root's inbox. Each worker reads exactly one message, records it as
-- its result, forwards it to each schedule child in 'adjacencyList'
-- order, and finishes.
--
-- Returns a map from every member to the payload it received.
--
-- Events per member, in order: 'Received', one 'Sent' per child,
-- 'Completed'.
--
-- Precondition: the schedule is rooted at @root@. A member with more
-- than one incoming edge reads only the first message to arrive; any
-- further messages stay unread in its inbox.
runBroadcastWithTrace ::
  (Ord m) =>
  (Trace m p -> IO ()) ->
  Schedule m ->
  m ->
  p ->
  IO (Map.Map m p)
runBroadcastWithTrace trace schedule root payload = do
  inboxes <- Map.fromList <$> traverse (\member -> (,) member <$> newChan) members
  results <- newChan
  forM_ members (void . forkIO . relay inboxes results)
  writeChan (inboxes Map.! root) payload
  Map.fromList <$> replicateM (length members) (readChan results)
  where
    graph = adjacencyList schedule

    members =
      Set.toList . Set.fromList $
        root : Map.keys graph ++ concat (Map.elems graph)

    -- Wait for one message, publish it as this member's result, fan it
    -- out to every child, then report completion.
    relay inboxes results member = do
      msg <- readChan (inboxes Map.! member)
      trace (Received member msg)
      writeChan results (member, msg)
      forM_ (Map.findWithDefault [] member graph) $ \child -> do
        trace (Sent member child msg)
        writeChan (inboxes Map.! child) msg
      trace (Completed member msg)

-- | Count, for every member that is the target of at least one 'Step',
-- how many steps point at it. Members that no step targets are absent
-- from the result (they do not map to 0).
incomingCounts :: (Ord a) => Schedule a -> Map.Map a Int
incomingCounts schedule =
  Map.fromListWith (+) [(target, 1) | Step {to = target} <- schedule]

-- | Run a reduce divergence schedule without tracing.
--
-- Takes a divergence schedule. Leaf values flow toward the root and are
-- combined at each node in tree order, matching the fold order of the
-- pure 'Tile.Execution.reduceResult'. The combine function need not be
-- commutative.
--
-- Same as 'runReduceWithTrace' with a tracer that discards every event.
runReduce ::
  (Ord m) =>
  Schedule m ->
  Map.Map m v ->
  (v -> v -> v) ->
  m ->
  IO v
runReduce schedule initialValues combine root =
  runReduceWithTrace (\_ -> pure ()) schedule initialValues combine root

-- | Run a reduce divergence schedule, reporting each observed action.
--
-- Every member runs in its own thread. Values travel against the
-- schedule's edges:
--
-- * a member waits for one value from each of its schedule children, in
--   'adjacencyList' order;
-- * it folds those values into its own value with @combine@, left to
--   right;
-- * it sends the total to its parent.
--
-- The member with no parent emits 'Completed', and its total is
-- returned. This is the same fold order as the pure
-- 'Tile.Execution.reduceResult'. Each directed edge gets a dedicated
-- channel, so the order in which children finish does not affect the
-- result.
--
-- Events per member: one 'Received' per child value (emitted after all
-- children have been read), then either 'Sent' to the parent or
-- 'Completed'.
--
-- Precondition: the schedule is rooted at @root@. The value map must
-- contain every member reachable from @root@.
--
-- NOTE(review): workers never force @total@. With a tracer that does not
-- inspect its argument, the @combine@ work is evaluated lazily when the
-- caller demands the result, not inside the worker threads.
runReduceWithTrace ::
  (Ord m) =>
  (Trace m v -> IO ()) ->
  Schedule m ->
  Map.Map m v ->
  (v -> v -> v) ->
  m ->
  IO v
runReduceWithTrace trace schedule initialValues combine root = do
  -- One dedicated channel per directed edge, keyed by (child, parent).
  uplinks <- Map.fromList <$> traverse edgeChannel schedule
  answer <- newEmptyMVar
  forM_ members $ \member ->
    void . forkIO $ do
      fromChildren <-
        traverse
          (\child -> readChan (uplinks Map.! (child, member)))
          (Map.findWithDefault [] member childrenOf)
      mapM_ (trace . Received member) fromChildren
      -- Seeded with the member's own value, then children left to right.
      let total = foldl' combine (initialValues Map.! member) fromChildren
      case Map.lookup member parentOf of
        Nothing -> do
          trace (Completed member total)
          putMVar answer total
        Just parent -> do
          trace (Sent member parent total)
          writeChan (uplinks Map.! (member, parent)) total
  takeMVar answer
  where
    childrenOf = adjacencyList schedule

    parentOf =
      Map.fromList
        [(child, parent) | Step {from = parent, to = child} <- schedule]

    members =
      Set.toList . Set.fromList $
        root : Map.keys childrenOf ++ concat (Map.elems childrenOf)

    -- A fresh channel for the child -> parent direction of one step.
    edgeChannel Step {from = parent, to = child} = do
      ch <- newChan
      pure ((child, parent), ch)

-- | Run a gather divergence schedule without tracing.
--
-- Takes a divergence schedule. Values are collected in preorder over the
-- divergence tree, matching the pure 'Tile.Execution.gatherResult'.
--
-- Same as 'runGatherWithTrace' with a tracer that discards every event.
runGather ::
  (Ord m) =>
  Schedule m ->
  Map.Map m a ->
  m ->
  IO [(m, a)]
runGather schedule initialValues root =
  runGatherWithTrace (\_ -> pure ()) schedule initialValues root

-- | Run a gather divergence schedule, reporting each observed action.
--
-- Values are accumulated in preorder over the divergence tree. Each
-- member puts its own @(member, value)@ pair first, then the lists
-- received from its schedule children, taken in 'adjacencyList' order.
-- This matches 'Tile.Execution.gatherResult'. Each directed edge gets a
-- dedicated channel, so arrival order does not affect the result.
--
-- Events per member:
--
-- * a single 'Received' carrying the concatenated child lists (omitted
--   when nothing was received);
-- * then either 'Sent' to the parent, or 'Completed' at the member that
--   has no parent.
--
-- Precondition: the schedule is rooted at @root@. The value map must
-- contain every member reachable from @root@.
runGatherWithTrace ::
  (Ord m) =>
  (Trace m [(m, a)] -> IO ()) ->
  Schedule m ->
  Map.Map m a ->
  m ->
  IO [(m, a)]
runGatherWithTrace trace schedule initialValues root = do
  -- One dedicated channel per directed edge, keyed by (child, parent).
  uplinks <- Map.fromList <$> traverse edgeChannel schedule
  answer <- newEmptyMVar
  forM_ members $ \member ->
    void . forkIO $ do
      received <-
        concat
          <$> traverse
            (\child -> readChan (uplinks Map.! (child, member)))
            (Map.findWithDefault [] member childrenOf)
      unless (null received) $
        trace (Received member received)
      -- Preorder: this member's own entry precedes everything below it.
      let gathered = (member, initialValues Map.! member) : received
      case Map.lookup member parentOf of
        Nothing -> do
          trace (Completed member gathered)
          putMVar answer gathered
        Just parent -> do
          trace (Sent member parent gathered)
          writeChan (uplinks Map.! (member, parent)) gathered
  takeMVar answer
  where
    childrenOf = adjacencyList schedule

    parentOf =
      Map.fromList
        [(child, parent) | Step {from = parent, to = child} <- schedule]

    members =
      Set.toList . Set.fromList $
        root : Map.keys childrenOf ++ concat (Map.elems childrenOf)

    -- A fresh channel for the child -> parent direction of one step.
    edgeChannel Step {from = parent, to = child} = do
      ch <- newChan
      pure ((child, parent), ch)

-- | Run a scatter divergence schedule without tracing.
--
-- The root starts with a value for each destination. At each hop, the
-- payload is partitioned by the routed subtree below each child.
--
-- Same as 'runScatterWithTrace' with a tracer that discards every event.
runScatter ::
  (Ord m) =>
  Schedule m ->
  [(m, a)] ->
  m ->
  IO (Map.Map m a)
runScatter schedule initialValues root =
  runScatterWithTrace (\_ -> pure ()) schedule initialValues root

-- | Run a scatter divergence schedule, reporting each observed action.
--
-- The root starts with the whole @(destination, value)@ list; duplicate
-- destinations keep the last value. Each member then does the following:
--
-- * merges the payload fragments it receives;
-- * records its own value if the payload contains one for it;
-- * forwards to each schedule child only the entries whose destination
--   lies in that child's routed subtree (from 'scheduleTree' and
--   'treeIndex').
--
-- The returned map contains the destinations that are reachable from
-- @root@ in the routed tree. Other destinations are dropped. No worker is
-- started for a destination that does not occur in the schedule, so such
-- destinations produce no 'Trace' events.
--
-- Events per member, in order: 'Received' (the merged payload), one
-- 'Sent' per child (possibly with an empty fragment), 'Completed'.
--
-- Precondition: the schedule is rooted at @root@.
runScatterWithTrace ::
  (Ord m) =>
  (Trace m [(m, a)] -> IO ()) ->
  Schedule m ->
  [(m, a)] ->
  m ->
  IO (Map.Map m a)
runScatterWithTrace trace schedule initialValues root = do
  let graph = adjacencyList schedule
      incoming = incomingCounts schedule
      payloadMap = Map.fromList initialValues
      -- Only the root and the endpoints of schedule edges take part.
      -- A payload destination outside the schedule can never receive a
      -- message. Forking a worker for it would only emit spurious
      -- @Received m []@ / @Completed m []@ events for a node that never
      -- participates.
      members =
        Set.toList . Set.fromList $
          root : Map.keys graph ++ concat (Map.elems graph)
      RoutedTree routed = scheduleTree root schedule
      routedSubtrees = treeIndex routed
      reachablePayloads = Map.restrictKeys payloadMap (treeLabels routed)

  chanPairs <- forM members $ \m -> do
    ch <- newChan
    pure (m, ch)

  let chanMap = Map.fromList chanPairs
  resultChan <- newChan

  forM_ members $ \m -> do
    let inbox = chanMap Map.! m
        children = Map.findWithDefault [] m graph
        childChans = [(c, chanMap Map.! c) | c <- children]
        -- The root's single message comes from the caller. Every other
        -- member waits for one fragment per incoming schedule edge.
        expected
          | m == root = 1
          | otherwise = Map.findWithDefault 0 m incoming

    _ <- forkIO $ do
      -- Scatter schedules are normally trees, so this usually reads
      -- one payload. For a general schedule, merge all incoming
      -- payload fragments before forwarding.
      payload <- concat <$> replicateM expected (readChan inbox)
      trace (Received m payload)
      forM_ (lookup m payload) $ \value ->
        writeChan resultChan (m, value)

      forM_ childChans $ \(childName, childInbox) -> do
        let childMembers =
              maybe Set.empty treeLabels (Map.lookup childName routedSubtrees)
            childPayload =
              [ item
              | item@(dest, _) <- payload,
                dest `Set.member` childMembers
              ]
        trace (Sent m childName childPayload)
        writeChan childInbox childPayload
      trace (Completed m payload)
    pure ()

  writeChan (chanMap Map.! root) (Map.toList payloadMap)
  -- Exactly one result is written per reachable destination.
  Map.fromList <$> replicateM (Map.size reachablePayloads) (readChan resultChan)

-- | Run an all-reduce divergence schedule without tracing.
--
-- Takes a divergence schedule. Every member ends with the value
-- obtained by combining all member values with @combine@. Runs the
-- reduce phase to completion before starting the broadcast phase.
--
-- Same as 'runAllReduceWithTrace' with a tracer that discards every
-- event.
runAllReduce ::
  (Ord m) =>
  Schedule m ->
  m ->
  Map.Map m v ->
  (v -> v -> v) ->
  IO (Map.Map m v)
runAllReduce schedule root values combine =
  runAllReduceWithTrace (\_ -> pure ()) schedule root values combine

-- | Run an all-reduce divergence schedule, reporting each observed
-- action.
--
-- Implemented as 'runReduceWithTrace' followed by
-- 'runBroadcastWithTrace' of the reduced value from @root@ over the same
-- schedule. Both phases emit 'Trace' events through the same @tracer@.
-- The reduce phase completes fully before the broadcast phase begins.
--
-- Precondition: the schedule is rooted at @root@. The value map must
-- contain every member reachable from @root@.
runAllReduceWithTrace ::
  (Ord m) =>
  (Trace m v -> IO ()) ->
  Schedule m ->
  m ->
  Map.Map m v ->
  (v -> v -> v) ->
  IO (Map.Map m v)
runAllReduceWithTrace tracer schedule root values combine =
  runReduceWithTrace tracer schedule values combine root
    >>= runBroadcastWithTrace tracer schedule root