-- |
-- Module      : Tile.Affine
-- Description : Affine rank spaces and slicing.
--
-- An 'AffineRankSpace' maps logical coordinates to integer ranks via
-- an offset and per-dimension strides. The type is general: any
-- combination of offset, sizes, and strides is representable.
--
-- Functions divide into two classes:
--
-- [Unconditional]
--   'rankOf' and 'rankOfMaybe' compute
--   @offset + sum (zipWith (*) coord strides)@ for any strides.
--
-- [Row-major invariant required]
--   'pointOf' and 'pointOfMaybe' recover coordinates by mixed-radix
--   division and require @strides[k] = product(sizes[k+1..])@ at
--   every level. 'rowMajor' establishes this invariant; 'select' and
--   'fixDim' change one size or stride in isolation and therefore do
--   not preserve it in general.
module Tile.Affine
  ( -- * Affine rank spaces
    AffineRankSpace (..),
    Point,
    rowMajor,

    -- * Coordinate/rank conversion
    rankOf,
    rankOfMaybe,
    pointOf,
    pointOfMaybe,

    -- * Queries
    spaceExtent,
    points,
    ranks,

    -- * Slicing
    select,
    fixDim,
  )
where

import Control.Monad (guard)
import Tile.Shape (Shape)

-- | Affine mapping from logical coordinates onto integer ranks.
--
-- The rank assigned to a coordinate @coord@ is
--
-- @
-- offset + sum (zipWith (*) coord strides)
-- @
--
-- The constructor and all fields are exported, so values can be built
-- by hand; nothing in the type forces 'sizes' and 'strides' to have the
-- same length.
data AffineRankSpace = AffineRankSpace
  { -- | Rank assigned to the all-zero coordinate.
    offset :: Int,
    -- | Extent of each logical dimension (the shape of the space).
    sizes :: [Int],
    -- | Rank increment per unit step along each logical dimension.
    strides :: [Int]
  }
  deriving (Eq, Show)

-- | A logical coordinate in an affine rank space, one component per
-- dimension.
type Point = [Int]

-- | Build the dense row-major affine rank space over a shape.
--
-- The offset is @0@ and the sizes are the shape itself. Each stride is
-- the product of all sizes to its right, i.e.
-- @strides[k] = product(sizes[k+1..])@, so the last dimension (if any)
-- has stride @1@.
rowMajor :: Shape -> AffineRankSpace
rowMajor shape =
  AffineRankSpace
    { offset = 0,
      sizes = shape,
      strides = suffixProducts
    }
  where
    -- @scanr@ yields every suffix product with the full product at the
    -- head; dropping that head leaves exactly one stride per dimension.
    suffixProducts = drop 1 (scanr (*) 1 shape)

-- | Convert a coordinate to a rank.
--
-- Unconditional: no relationship between sizes and strides is assumed.
-- Calls 'error' whenever 'rankOfMaybe' rejects the coordinate, that is,
-- when it has the wrong number of dimensions or lies out of bounds. Use
-- 'rankOfMaybe' for a total variant.
rankOf :: AffineRankSpace -> Point -> Int
rankOf space coord =
  case rankOfMaybe space coord of
    Just rank -> rank
    -- 'rankOfMaybe' does not report which check failed, so the message
    -- names both possibilities instead of blaming the dimension alone.
    Nothing -> error "rankOf: coordinate dimension mismatch or out of bounds"

-- | Convert a coordinate to a rank, returning 'Nothing' for invalid
-- coordinates.
--
-- A coordinate is valid when it has exactly one component per entry of
-- both 'sizes' and 'strides', and every component @c@ satisfies
-- @0 <= c < size@ for its dimension. A space whose 'sizes' and
-- 'strides' differ in length therefore has no valid coordinates.
--
-- Unconditional: no relationship between the size and stride values is
-- assumed.
rankOfMaybe :: AffineRankSpace -> Point -> Maybe Int
rankOfMaybe space coord
  | arityMatches && withinBounds = Just (offset space + weightedSum)
  | otherwise = Nothing
  where
    arity = length coord
    -- Comparing against both lists keeps the @zipWith@ calls below from
    -- silently truncating when the space is malformed.
    arityMatches =
      arity == length (sizes space) && arity == length (strides space)
    withinBounds = and (zipWith inBounds coord (sizes space))
    weightedSum = sum (zipWith (*) coord (strides space))
    inBounds coordinate size = 0 <= coordinate && coordinate < size

-- | Convert a rank to a coordinate, failing loudly on unknown ranks.
--
-- Thin wrapper around 'pointOfMaybe': a 'Just' result is returned
-- unchanged, and 'Nothing' becomes a call to 'error'. Whatever
-- requirements 'pointOfMaybe' places on the space apply here as well.
-- Use 'pointOfMaybe' directly for a total variant.
pointOf :: AffineRankSpace -> Int -> Point
pointOf space rank = maybe outsideSpace id (pointOfMaybe space rank)
  where
    outsideSpace = error "pointOf: rank outside affine rank space"

-- | Convert a rank to a coordinate, returning 'Nothing' for ranks
-- outside the affine rank space.
--
-- Unconditional: the coordinate is found by searching the grid given by
-- 'points' for one whose 'rankOfMaybe' equals the requested rank, so a
-- result @p@ always satisfies @rankOfMaybe space p == Just rank@. This
-- holds for dense row-major spaces, for spaces sliced with 'select',
-- and for hand-written strides alike. When several coordinates share a
-- rank, the first one in row-major scan order is returned.
--
-- The search enumerates the points of the space one by one.
pointOfMaybe :: AffineRankSpace -> Int -> Maybe Point
pointOfMaybe space rank = lookup (Just rank) rankedPoints
  where
    -- Keys are 'Maybe' ranks: a coordinate the space rejects carries
    -- 'Nothing' and so can never match the @Just rank@ being looked up.
    rankedPoints =
      [ (rankOfMaybe space point, point)
      | point <- points (sizes space)
      ]

-- | Count the logical points of an affine rank space: the product of
-- its 'sizes' (which is @1@ for a space with no dimensions).
spaceExtent :: AffineRankSpace -> Int
spaceExtent = product . sizes

-- | Enumerate every logical coordinate of a shape.
--
-- The scan is row-major: the last dimension varies fastest. The empty
-- shape has exactly one coordinate, @[]@, and a shape containing a
-- non-positive extent has none. 'ranks' maps over this same grid.
points :: Shape -> [Point]
points = traverse indices
  where
    -- All valid indices along one dimension of extent @n@.
    indices n = [0 .. n - 1]

-- | Select a strided interval along one dimension.
--
-- @select space dim begin end step@ keeps the indices
-- @begin, begin + step, ...@ of dimension @dim@ that are below @end@.
-- The dimension stays present: its size becomes the number of kept
-- indices, its stride is multiplied by @step@, and the offset advances
-- by @begin * stride@. All other sizes and strides are left untouched.
--
-- Returns 'Nothing' when @dim@ does not index both 'sizes' and
-- 'strides', when @step <= 0@, or unless @0 <= begin < end <= extent@
-- for the extent of the selected dimension.
--
-- The result is generally /not/ dense row-major: shrinking a size or
-- scaling a stride can break @strides[k] = product(sizes[k+1..])@, so
-- do not rely on that relation for sliced spaces.
select :: AffineRankSpace -> Int -> Int -> Int -> Int -> Maybe AffineRankSpace
select space dim begin end step = do
  guard (dim >= 0)
  guard (step > 0)

  -- Splitting at @dim@ is a total replacement for @!!@: the pattern
  -- fails, yielding 'Nothing', when a list has no element at @dim@.
  (outerSizes, extent : innerSizes) <- pure (splitAt dim (sizes space))
  (outerStrides, stride : innerStrides) <- pure (splitAt dim (strides space))

  guard (0 <= begin && begin < extent)
  guard (begin < end && end <= extent)

  let newOffset = offset space + begin * stride
      -- Ceiling division: how many of begin, begin + step, ... lie
      -- below end.
      newSize = (end - begin + step - 1) `div` step
      newStride = stride * step

  pure
    AffineRankSpace
      { offset = newOffset,
        sizes = outerSizes ++ [newSize] ++ innerSizes,
        strides = outerStrides ++ [newStride] ++ innerStrides
      }

-- | Pin one dimension to a single index.
--
-- Equivalent to selecting the unit-step interval @[i, i + 1)@ with
-- 'select': the dimension is kept with extent @1@, and the result is
-- 'Nothing' exactly when 'select' rejects that interval.
fixDim :: AffineRankSpace -> Int -> Int -> Maybe AffineRankSpace
fixDim space dim i = select space dim i upper 1
  where
    -- Exclusive upper bound of the one-element interval.
    upper = i + 1

-- | List the rank of every point, in logical coordinate order.
--
-- Ranks follow the row-major scan that 'points' produces for the
-- space's 'sizes'. Each rank is computed with 'rankOf', so its error
-- surfaces here if 'rankOf' rejects one of those coordinates.
ranks :: AffineRankSpace -> [Int]
ranks space = [rankOf space point | point <- points (sizes space)]