-- | Mutable one-dimensional packed bit arrays in the IO monad.

module Data.BitArray.IO
  ( IOBitArray 
  , getBitArrayBounds
  , newBitArray
  , readBit
  , writeBit
  , flipBit
  , unsafeReadBit
  , unsafeWriteBit
  , unsafeFlipBit
  
  , thawBitArray
  , unsafeThawBitArray
  , freezeBitArray
  , unsafeFreezeBitArray
  )
  where

--------------------------------------------------------------------------------

import Data.Word
import Data.Bits

import Data.Array.IO
import Data.Array.Unsafe

import Data.BitArray.Immutable

--------------------------------------------------------------------------------

-- | A mutable, packed bit array in the 'IO' monad.
--
-- Valid indices form the inclusive range @[_first, _last]@. Bit @j@ is
-- stored at offset @j - _first@ within the 64-bit words of '_words'.
data IOBitArray = IOA
  { _first :: {-# UNPACK #-} !Int                    -- ^ lowest valid index
  , _last  :: {-# UNPACK #-} !Int                    -- ^ highest valid index
  , _words :: {-# UNPACK #-} !(IOUArray Int Word64)  -- ^ packed bit storage
  }
  
--------------------------------------------------------------------------------

-- | The inclusive @(first, last)@ index bounds of the array.
getBitArrayBounds :: IOBitArray -> IO (Int, Int)
getBitArrayBounds (IOA s t _) = pure (s, t)

-- | Allocate a new bit array covering the inclusive index range @(s, t)@,
-- with every bit initialised to the given value.
--
-- Calls 'error' when @t < s@.
--
-- When the initial value is 'True', every bit of every storage word is set,
-- including the unused high bits of the last word.
-- NOTE(review): confirm nothing in "Data.BitArray.Immutable" relies on
-- those padding bits being zero after a freeze.
newBitArray :: (Int, Int) -> Bool -> IO IOBitArray
newBitArray (s, t) b
  | t < s     = error "IOBitArray/newBitArray: empty range"
  | otherwise = IOA s t <$> newArray (0, k - 1) w
  where
    -- Number of 64-bit words for (t - s + 1) bits, rounded up:
    -- (n + 63) `div` 64 with n = t - s + 1.
    k :: Int
    k = (t - s + 64) `shiftR` 6

    -- Fill pattern for every word: all zeros or all ones.
    w :: Word64
    w = if b then complement 0 else 0

-- | Read the bit at index @j@.
--
-- Calls 'error' when @j@ lies outside the array's bounds.
readBit :: IOBitArray -> Int -> IO Bool
readBit ar@(IOA s t _) j
  | j < s || j > t = error "IOBitArray/readBit: index out of range"
  | otherwise      = unsafeReadBit ar j

-- | Read the bit at index @j@ without checking @j@ against the array's
-- logical bounds.
--
-- The word access still goes through 'readArray', which checks the word
-- index against the storage array. An index just past '_last' that still
-- falls inside the final word reads a padding bit instead of failing.
unsafeReadBit :: IOBitArray -> Int -> IO Bool
unsafeReadBit (IOA s _ a) j = do
  -- 'ind' comes from "Data.BitArray.Immutable"; presumably it returns
  -- (word index, bit offset within word). Verify against its definition.
  let (k, l) = ind (j - s)
  (`testBit` l) <$> readArray a k

-- | Set the bit at index @j@ to the given value.
--
-- Calls 'error' when @j@ lies outside the array's bounds.
writeBit :: IOBitArray -> Int -> Bool -> IO ()
writeBit ar@(IOA s t _) j b
  | j < s || j > t = error "IOBitArray/writeBit: index out of range"
  | otherwise      = unsafeWriteBit ar j b

-- | Set the bit at index @j@ to the given value, without checking @j@
-- against the array's logical bounds.
--
-- The containing word is read, modified and written back with the
-- bounds-checked 'readArray' / 'writeArray'.
unsafeWriteBit :: IOBitArray -> Int -> Bool -> IO ()
unsafeWriteBit (IOA s _ a) j b = do
  -- 'ind' (from "Data.BitArray.Immutable") presumably yields
  -- (word index, bit offset within word).
  let (k, l)  = ind (j - s)
      update  = if b then setBit else clearBit
  w <- readArray a k
  writeArray a k (update w l)

-- | Flip the bit at index @j@ and return its /old/ value.
--
-- Calls 'error' when @j@ lies outside the array's bounds.
flipBit :: IOBitArray -> Int -> IO Bool
flipBit ar@(IOA s t _) j
  | j < s || j > t = error "IOBitArray/flipBit: index out of range"
  | otherwise      = unsafeFlipBit ar j

-- | Flip the bit at index @j@ and return its /old/ value, without checking
-- @j@ against the array's logical bounds.
unsafeFlipBit :: IOBitArray -> Int -> IO Bool
unsafeFlipBit (IOA s _ a) j = do
  -- 'ind' (from "Data.BitArray.Immutable") presumably yields
  -- (word index, bit offset within word).
  let (k, l) = ind (j - s)
  w <- readArray a k
  writeArray a k (w `complementBit` l)
  -- report the value the bit had before flipping
  return (w `testBit` l)
    
--------------------------------------------------------------------------------
    
-- | Create a mutable bit array from an immutable one, with the same bounds,
-- by copying its storage words with 'thaw'.
thawBitArray :: BitArray -> IO IOBitArray
thawBitArray (A s t x) = IOA s t <$> thaw x

-- | Like 'thawBitArray', but uses 'unsafeThaw', which may share storage with
-- the input instead of copying it (see "Data.Array.Unsafe"). The immutable
-- array should not be used after this call.
unsafeThawBitArray :: BitArray -> IO IOBitArray
unsafeThawBitArray (A s t x) = IOA s t <$> unsafeThaw x

-- | Take an immutable snapshot of a mutable bit array, with the same bounds,
-- by copying its storage words with 'freeze'.
freezeBitArray :: IOBitArray -> IO BitArray
freezeBitArray (IOA s t x) = A s t <$> freeze x

-- | Like 'freezeBitArray', but uses 'unsafeFreeze', which may share storage
-- with the mutable array instead of copying it (see "Data.Array.Unsafe").
-- The mutable array should not be modified after this call.
unsafeFreezeBitArray :: IOBitArray -> IO BitArray
unsafeFreezeBitArray (IOA s t x) = A s t <$> unsafeFreeze x

--------------------------------------------------------------------------------